Merge pull request #290 from red-hat-storage/sync_us--devel

Syncing latest changes from upstream devel for ceph-csi
This commit is contained in:
openshift-merge-bot[bot] 2024-04-11 08:23:51 +00:00 committed by GitHub
commit 8e62c9face
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
102 changed files with 1560 additions and 707 deletions

View File

@ -31,6 +31,7 @@ rules:
- - build
- cephfs
- ci
- csi-addons
- cleanup
- deploy
- doc

View File

@ -26,7 +26,7 @@ GO111MODULE=on
COMMITLINT_VERSION=latest
# static checks and linters
GOLANGCI_VERSION=v1.54.1
GOLANGCI_VERSION=v1.57.2
# external snapshotter version
# Refer: https://github.com/kubernetes-csi/external-snapshotter/releases
@ -41,7 +41,7 @@ SNAPSHOT_VERSION=v7.0.1
HELM_SCRIPT=https://raw.githubusercontent.com/helm/helm/main/scripts/get-helm-3
# helm chart generation, testing and publishing
HELM_VERSION=v3.14.1
HELM_VERSION=v3.14.3
# minikube settings
MINIKUBE_VERSION=v1.32.0

View File

@ -23,7 +23,7 @@ import (
"k8s.io/kubernetes/test/e2e/framework"
)
// #nosec because of the word `Secret`
//nolint:gosec // secret for test
const (
// ceph user names.
keyringRBDProvisionerUsername = "cephcsi-rbd-provisioner"
@ -110,7 +110,7 @@ func createCephUser(f *framework.Framework, user string, caps []string) (string,
}
func deleteCephUser(f *framework.Framework, user string) error {
cmd := fmt.Sprintf("ceph auth del client.%s", user)
cmd := "ceph auth del client." + user
_, _, err := execCommandInToolBoxPod(f, cmd, rookNamespace)
return err

View File

@ -23,7 +23,7 @@ import (
"sync"
snapapi "github.com/kubernetes-csi/external-snapshotter/client/v7/apis/volumesnapshot/v1"
. "github.com/onsi/ginkgo/v2" //nolint:golint // e2e uses By() and other Ginkgo functions
. "github.com/onsi/ginkgo/v2"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@ -735,7 +735,7 @@ var _ = Describe(cephfsType, func() {
framework.Failf("failed to create PVC and Deployment: %v", err)
}
deplPods, err := listPods(f, depl.Namespace, &metav1.ListOptions{
LabelSelector: fmt.Sprintf("app=%s", depl.Labels["app"]),
LabelSelector: "app=" + depl.Labels["app"],
})
if err != nil {
framework.Failf("failed to list pods for Deployment: %v", err)
@ -744,7 +744,7 @@ var _ = Describe(cephfsType, func() {
doStat := func(podName string) (string, error) {
_, stdErr, execErr := execCommandInContainerByPodName(
f,
fmt.Sprintf("stat %s", depl.Spec.Template.Spec.Containers[0].VolumeMounts[0].MountPath),
"stat "+depl.Spec.Template.Spec.Containers[0].VolumeMounts[0].MountPath,
depl.Namespace,
podName,
depl.Spec.Template.Spec.Containers[0].Name,
@ -808,7 +808,7 @@ var _ = Describe(cephfsType, func() {
}
// List Deployment's pods again to get name of the new pod.
deplPods, err = listPods(f, depl.Namespace, &metav1.ListOptions{
LabelSelector: fmt.Sprintf("app=%s", depl.Labels["app"]),
LabelSelector: "app=" + depl.Labels["app"],
})
if err != nil {
framework.Failf("failed to list pods for Deployment: %v", err)
@ -1074,13 +1074,13 @@ var _ = Describe(cephfsType, func() {
}
opt := metav1.ListOptions{
LabelSelector: fmt.Sprintf("app=%s", app.Name),
LabelSelector: "app=" + app.Name,
}
filePath := app.Spec.Containers[0].VolumeMounts[0].MountPath + "/test"
_, stdErr := execCommandInPodAndAllowFail(
f,
fmt.Sprintf("echo 'Hello World' > %s", filePath),
"echo 'Hello World' >"+filePath,
app.Namespace,
&opt)
readOnlyErr := fmt.Sprintf("cannot create %s: Read-only file system", filePath)
@ -2407,13 +2407,13 @@ var _ = Describe(cephfsType, func() {
}
opt := metav1.ListOptions{
LabelSelector: fmt.Sprintf("app=%s", app.Name),
LabelSelector: "app=" + app.Name,
}
filePath := app.Spec.Containers[0].VolumeMounts[0].MountPath + "/test"
_, stdErr := execCommandInPodAndAllowFail(
f,
fmt.Sprintf("echo 'Hello World' > %s", filePath),
"echo 'Hello World' > "+filePath,
app.Namespace,
&opt)
readOnlyErr := fmt.Sprintf("cannot create %s: Read-only file system", filePath)

View File

@ -338,7 +338,7 @@ func getSnapName(snapNamespace, snapName string) (string, error) {
}
snapIDRegex := regexp.MustCompile(`(\w+\-?){5}$`)
snapID := snapIDRegex.FindString(*sc.Status.SnapshotHandle)
snapshotName := fmt.Sprintf("csi-snap-%s", snapID)
snapshotName := "csi-snap-" + snapID
framework.Logf("snapshotName= %s", snapshotName)
return snapshotName, nil
@ -392,10 +392,10 @@ func validateEncryptedCephfs(f *framework.Framework, pvName, appName string) err
LabelSelector: selector,
}
cmd := fmt.Sprintf("getfattr --name=ceph.fscrypt.auth --only-values %s", volumeMountPath)
cmd := "getfattr --name=ceph.fscrypt.auth --only-values " + volumeMountPath
_, _, err = execCommandInContainer(f, cmd, cephCSINamespace, "csi-cephfsplugin", &opt)
if err != nil {
cmd = fmt.Sprintf("getfattr --recursive --dump %s", volumeMountPath)
cmd = "getfattr --recursive --dump " + volumeMountPath
stdOut, stdErr, listErr := execCommandInContainer(f, cmd, cephCSINamespace, "csi-cephfsplugin", &opt)
if listErr == nil {
return fmt.Errorf("error checking for cephfs fscrypt xattr on %q. listing: %s %s",

View File

@ -21,7 +21,7 @@ import (
"fmt"
"strings"
. "github.com/onsi/gomega" //nolint:golint // e2e uses Expect() and other Gomega functions
. "github.com/onsi/gomega"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
"k8s.io/kubernetes/test/e2e/framework"

View File

@ -95,8 +95,8 @@ func replaceNamespaceInTemplate(filePath string) (string, error) {
}
// template can contain "default" as namespace, with or without ".
templ := strings.ReplaceAll(string(read), "namespace: default", fmt.Sprintf("namespace: %s", cephCSINamespace))
templ = strings.ReplaceAll(templ, "namespace: \"default\"", fmt.Sprintf("namespace: %s", cephCSINamespace))
templ := strings.ReplaceAll(string(read), "namespace: default", "namespace: "+cephCSINamespace)
templ = strings.ReplaceAll(templ, "namespace: \"default\"", "namespace: "+cephCSINamespace)
return templ, nil
}

View File

@ -24,7 +24,7 @@ import (
"time"
snapapi "github.com/kubernetes-csi/external-snapshotter/client/v7/apis/volumesnapshot/v1"
. "github.com/onsi/ginkgo/v2" //nolint:golint // e2e uses By() and other Ginkgo functions
. "github.com/onsi/ginkgo/v2"
v1 "k8s.io/api/core/v1"
apierrs "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@ -603,13 +603,13 @@ var _ = Describe("nfs", func() {
}
opt := metav1.ListOptions{
LabelSelector: fmt.Sprintf("app=%s", app.Name),
LabelSelector: "app=" + app.Name,
}
filePath := app.Spec.Containers[0].VolumeMounts[0].MountPath + "/test"
_, stdErr := execCommandInPodAndAllowFail(
f,
fmt.Sprintf("echo 'Hello World' > %s", filePath),
"echo 'Hello World' > "+filePath,
app.Namespace,
&opt)
readOnlyErr := fmt.Sprintf("cannot create %s: Read-only file system", filePath)

View File

@ -382,7 +382,7 @@ func waitForPodInRunningState(name, ns string, c kubernetes.Interface, t int, ex
case v1.PodPending:
if expectedError != "" {
events, err := c.CoreV1().Events(ns).List(ctx, metav1.ListOptions{
FieldSelector: fmt.Sprintf("involvedObject.name=%s", name),
FieldSelector: "involvedObject.name=" + name,
})
if err != nil {
return false, err
@ -452,7 +452,7 @@ func deletePodWithLabel(label, ns string, skipNotFound bool) error {
// calculateSHA512sum returns the sha512sum of a file inside a pod.
func calculateSHA512sum(f *framework.Framework, app *v1.Pod, filePath string, opt *metav1.ListOptions) (string, error) {
cmd := fmt.Sprintf("sha512sum %s", filePath)
cmd := "sha512sum " + filePath
sha512sumOut, stdErr, err := execCommandInPod(f, cmd, app.Namespace, opt)
if err != nil {
return "", err

View File

@ -19,12 +19,13 @@ package e2e
import (
"context"
"fmt"
"strconv"
"strings"
"time"
"github.com/ceph/ceph-csi/internal/util"
. "github.com/onsi/ginkgo/v2" //nolint:golint // e2e uses By() and other Ginkgo functions
. "github.com/onsi/ginkgo/v2"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
@ -536,7 +537,7 @@ var _ = Describe("RBD", func() {
})
By("reattach the old PV to a new PVC and check if PVC metadata is updated on RBD image", func() {
reattachPVCNamespace := fmt.Sprintf("%s-2", f.Namespace.Name)
reattachPVCNamespace := f.Namespace.Name + "-2"
pvc, err := loadPVC(pvcPath)
if err != nil {
framework.Failf("failed to load PVC: %v", err)
@ -1512,7 +1513,7 @@ var _ = Describe("RBD", func() {
cmd := fmt.Sprintf("dd if=/dev/zero of=%s bs=1M count=10", devPath)
opt := metav1.ListOptions{
LabelSelector: fmt.Sprintf("app=%s", app.Name),
LabelSelector: "app=" + app.Name,
}
podList, err := e2epod.PodClientNS(f, app.Namespace).List(context.TODO(), opt)
if err != nil {
@ -1630,10 +1631,10 @@ var _ = Describe("RBD", func() {
validateOmapCount(f, 2, rbdType, defaultRBDPool, volumesType)
filePath := appClone.Spec.Template.Spec.Containers[0].VolumeMounts[0].MountPath + "/test"
cmd := fmt.Sprintf("echo 'Hello World' > %s", filePath)
cmd := "echo 'Hello World' > " + filePath
opt := metav1.ListOptions{
LabelSelector: fmt.Sprintf("app=%s", appClone.Name),
LabelSelector: "app=" + appClone.Name,
}
podList, err := e2epod.PodClientNS(f, appClone.Namespace).List(context.TODO(), opt)
if err != nil {
@ -1766,7 +1767,7 @@ var _ = Describe("RBD", func() {
cmd := fmt.Sprintf("dd if=/dev/zero of=%s bs=1M count=10", devPath)
opt := metav1.ListOptions{
LabelSelector: fmt.Sprintf("app=%s", appClone.Name),
LabelSelector: "app=" + appClone.Name,
}
podList, err := e2epod.PodClientNS(f, appClone.Namespace).List(context.TODO(), opt)
if err != nil {
@ -1862,14 +1863,14 @@ var _ = Describe("RBD", func() {
}
appOpt := metav1.ListOptions{
LabelSelector: fmt.Sprintf("app=%s", app.Name),
LabelSelector: "app=" + app.Name,
}
// TODO: Remove this once we ensure that rbd-nbd can sync data
// from Filesystem layer to backend rbd image as part of its
// detach or SIGTERM signal handler
_, stdErr, err := execCommandInPod(
f,
fmt.Sprintf("sync %s", app.Spec.Containers[0].VolumeMounts[0].MountPath),
"sync "+app.Spec.Containers[0].VolumeMounts[0].MountPath,
app.Namespace,
&appOpt)
if err != nil || stdErr != "" {
@ -1956,7 +1957,7 @@ var _ = Describe("RBD", func() {
filePath := app.Spec.Containers[0].VolumeMounts[0].MountPath + "/test"
_, stdErr, err = execCommandInPod(
f,
fmt.Sprintf("echo 'Hello World' > %s", filePath),
"echo 'Hello World' > "+filePath,
app.Namespace,
&appOpt)
if err != nil || stdErr != "" {
@ -3331,13 +3332,13 @@ var _ = Describe("RBD", func() {
for i := 0; i < totalCount; i++ {
name := fmt.Sprintf("%s%d", f.UniqueName, i)
opt := metav1.ListOptions{
LabelSelector: fmt.Sprintf("app=%s", name),
LabelSelector: "app=" + name,
}
filePath := appClone.Spec.Containers[0].VolumeMounts[0].MountPath + "/test"
_, stdErr := execCommandInPodAndAllowFail(
f,
fmt.Sprintf("echo 'Hello World' > %s", filePath),
"echo 'Hello World' > "+filePath,
appClone.Namespace,
&opt)
readOnlyErr := fmt.Sprintf("cannot create %s: Read-only file system", filePath)
@ -3961,13 +3962,13 @@ var _ = Describe("RBD", func() {
validateOmapCount(f, 1, rbdType, defaultRBDPool, volumesType)
opt := metav1.ListOptions{
LabelSelector: fmt.Sprintf("app=%s", app.Name),
LabelSelector: "app=" + app.Name,
}
filePath := app.Spec.Containers[0].VolumeMounts[0].MountPath + "/test"
_, stdErr := execCommandInPodAndAllowFail(
f,
fmt.Sprintf("echo 'Hello World' > %s", filePath),
"echo 'Hello World' > "+filePath,
app.Namespace,
&opt)
readOnlyErr := fmt.Sprintf("cannot create %s: Read-only file system", filePath)
@ -4350,9 +4351,9 @@ var _ = Describe("RBD", func() {
defaultSCName,
nil,
map[string]string{
"stripeUnit": fmt.Sprintf("%d", stripeUnit),
"stripeCount": fmt.Sprintf("%d", stripeCount),
"objectSize": fmt.Sprintf("%d", objectSize),
"stripeUnit": strconv.Itoa(stripeUnit),
"stripeCount": strconv.Itoa(stripeCount),
"objectSize": strconv.Itoa(objectSize),
},
deletePolicy)
if err != nil {

View File

@ -118,7 +118,7 @@ func createRBDStorageClass(
scOptions, parameters map[string]string,
policy v1.PersistentVolumeReclaimPolicy,
) error {
scPath := fmt.Sprintf("%s/%s", rbdExamplePath, "storageclass.yaml")
scPath := rbdExamplePath + "/" + "storageclass.yaml"
sc, err := getStorageClass(scPath)
if err != nil {
return fmt.Errorf("failed to get sc: %w", err)
@ -184,7 +184,7 @@ func createRBDStorageClass(
func createRadosNamespace(f *framework.Framework) error {
stdOut, stdErr, err := execCommandInToolBoxPod(f,
fmt.Sprintf("rbd namespace ls --pool=%s", defaultRBDPool), rookNamespace)
"rbd namespace ls --pool="+defaultRBDPool, rookNamespace)
if err != nil {
return err
}
@ -193,7 +193,7 @@ func createRadosNamespace(f *framework.Framework) error {
}
if !strings.Contains(stdOut, radosNamespace) {
_, stdErr, err = execCommandInToolBoxPod(f,
fmt.Sprintf("rbd namespace create %s", rbdOptions(defaultRBDPool)), rookNamespace)
"rbd namespace create "+rbdOptions(defaultRBDPool), rookNamespace)
if err != nil {
return err
}
@ -202,7 +202,7 @@ func createRadosNamespace(f *framework.Framework) error {
}
}
stdOut, stdErr, err = execCommandInToolBoxPod(f,
fmt.Sprintf("rbd namespace ls --pool=%s", rbdTopologyPool), rookNamespace)
"rbd namespace ls --pool="+rbdTopologyPool, rookNamespace)
if err != nil {
return err
}
@ -212,7 +212,7 @@ func createRadosNamespace(f *framework.Framework) error {
if !strings.Contains(stdOut, radosNamespace) {
_, stdErr, err = execCommandInToolBoxPod(f,
fmt.Sprintf("rbd namespace create %s", rbdOptions(rbdTopologyPool)), rookNamespace)
"rbd namespace create "+rbdOptions(rbdTopologyPool), rookNamespace)
if err != nil {
return err
}
@ -269,7 +269,7 @@ func getImageInfoFromPVC(pvcNamespace, pvcName string, f *framework.Framework) (
imageData = imageInfoFromPVC{
imageID: imageID,
imageName: fmt.Sprintf("csi-vol-%s", imageID),
imageName: "csi-vol-" + imageID,
csiVolumeHandle: pv.Spec.CSI.VolumeHandle,
pvName: pv.Name,
}
@ -671,7 +671,7 @@ func validateEncryptedFilesystem(f *framework.Framework, rbdImageSpec, pvName, a
cmd := fmt.Sprintf("lsattr -la %s | grep -E '%s/.\\s+Encrypted'", volumeMountPath, volumeMountPath)
_, _, err = execCommandInContainer(f, cmd, cephCSINamespace, "csi-rbdplugin", &opt)
if err != nil {
cmd = fmt.Sprintf("lsattr -lRa %s", volumeMountPath)
cmd = "lsattr -lRa " + volumeMountPath
stdOut, stdErr, listErr := execCommandInContainer(f, cmd, cephCSINamespace, "csi-rbdplugin", &opt)
if listErr == nil {
return fmt.Errorf("error checking file encrypted attribute of %q. listing filesystem+attrs: %s %s",
@ -697,7 +697,7 @@ func listRBDImages(f *framework.Framework, pool string) ([]string, error) {
var imgInfos []string
stdout, stdErr, err := execCommandInToolBoxPod(f,
fmt.Sprintf("rbd ls --format=json %s", rbdOptions(pool)), rookNamespace)
"rbd ls --format=json "+rbdOptions(pool), rookNamespace)
if err != nil {
return imgInfos, err
}
@ -744,7 +744,7 @@ type rbdDuImageList struct {
// getRbdDu runs 'rbd du' on the RBD image and returns a rbdDuImage struct with
// the result.
//
//nolint:deadcode,unused // required for reclaimspace e2e.
//nolint:unused // Unused code will be used in future.
func getRbdDu(f *framework.Framework, pvc *v1.PersistentVolumeClaim) (*rbdDuImage, error) {
rdil := rbdDuImageList{}
@ -778,7 +778,7 @@ func getRbdDu(f *framework.Framework, pvc *v1.PersistentVolumeClaim) (*rbdDuImag
// take up any space anymore. This can be used to verify that an empty, but
// allocated (with zerofill) extents have been released.
//
//nolint:deadcode,unused // required for reclaimspace e2e.
//nolint:unused // Unused code will be used in future.
func sparsifyBackingRBDImage(f *framework.Framework, pvc *v1.PersistentVolumeClaim) error {
imageData, err := getImageInfoFromPVC(pvc.Namespace, pvc.Name, f)
if err != nil {
@ -802,7 +802,7 @@ func deletePool(name string, cephFS bool, f *framework.Framework) error {
// --yes-i-really-mean-it
// ceph osd pool delete myfs-replicated myfs-replicated
// --yes-i-really-mean-it
cmds = append(cmds, fmt.Sprintf("ceph fs fail %s", name),
cmds = append(cmds, "ceph fs fail "+name,
fmt.Sprintf("ceph fs rm %s --yes-i-really-mean-it", name),
fmt.Sprintf("ceph osd pool delete %s-metadata %s-metadata --yes-i-really-really-mean-it", name, name),
fmt.Sprintf("ceph osd pool delete %s-replicated %s-replicated --yes-i-really-really-mean-it", name, name))
@ -850,7 +850,7 @@ func getPVCImageInfoInPool(f *framework.Framework, pvc *v1.PersistentVolumeClaim
}
stdOut, stdErr, err := execCommandInToolBoxPod(f,
fmt.Sprintf("rbd info %s", imageSpec(pool, imageData.imageName)), rookNamespace)
"rbd info "+imageSpec(pool, imageData.imageName), rookNamespace)
if err != nil {
return "", err
}
@ -1021,7 +1021,7 @@ func listRBDImagesInTrash(f *framework.Framework, poolName string) ([]trashInfo,
var trashInfos []trashInfo
stdout, stdErr, err := execCommandInToolBoxPod(f,
fmt.Sprintf("rbd trash ls --format=json %s", rbdOptions(poolName)), rookNamespace)
"rbd trash ls --format=json "+rbdOptions(poolName), rookNamespace)
if err != nil {
return trashInfos, err
}

View File

@ -22,7 +22,7 @@ import (
"strings"
"time"
. "github.com/onsi/gomega" //nolint:golint // e2e uses Expect() and other Gomega functions
. "github.com/onsi/gomega"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@ -179,7 +179,7 @@ func getDirSizeCheckCmd(dirPath string) string {
}
func getDeviceSizeCheckCmd(devPath string) string {
return fmt.Sprintf("blockdev --getsize64 %s", devPath)
return "blockdev --getsize64 " + devPath
}
func checkAppMntSize(f *framework.Framework, opt *metav1.ListOptions, size, cmd, ns string, t int) error {

View File

@ -23,7 +23,7 @@ import (
snapapi "github.com/kubernetes-csi/external-snapshotter/client/v7/apis/volumesnapshot/v1"
snapclient "github.com/kubernetes-csi/external-snapshotter/client/v7/clientset/versioned/typed/volumesnapshot/v1"
. "github.com/onsi/gomega" //nolint:golint // e2e uses Expect() and other Gomega functions
. "github.com/onsi/gomega"
v1 "k8s.io/api/core/v1"
apierrs "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/api/resource"

View File

@ -23,7 +23,7 @@ import (
"path/filepath"
"strings"
. "github.com/onsi/ginkgo/v2" //nolint:golint // e2e uses By() and other Ginkgo functions
. "github.com/onsi/ginkgo/v2"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@ -234,7 +234,7 @@ var _ = Describe("CephFS Upgrade Testing", func() {
}
// force an immediate write of all cached data to disk.
_, stdErr = execCommandInPodAndAllowFail(f, fmt.Sprintf("sync %s", filePath), app.Namespace, &opt)
_, stdErr = execCommandInPodAndAllowFail(f, "sync "+filePath, app.Namespace, &opt)
if stdErr != "" {
framework.Failf("failed to sync data to a disk %s", stdErr)
}

View File

@ -23,7 +23,7 @@ import (
"path/filepath"
"strings"
. "github.com/onsi/ginkgo/v2" //nolint:golint // e2e uses By() and other Ginkgo functions
. "github.com/onsi/ginkgo/v2"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@ -239,13 +239,13 @@ var _ = Describe("RBD Upgrade Testing", func() {
}
// force an immediate write of all cached data to disk.
_, stdErr = execCommandInPodAndAllowFail(f, fmt.Sprintf("sync %s", filePath), app.Namespace, &opt)
_, stdErr = execCommandInPodAndAllowFail(f, "sync "+filePath, app.Namespace, &opt)
if stdErr != "" {
framework.Failf("failed to sync data to a disk %s", stdErr)
}
opt = metav1.ListOptions{
LabelSelector: fmt.Sprintf("app=%s", appLabel),
LabelSelector: "app=" + appLabel,
}
framework.Logf("Calculating checksum of %s", filePath)
checkSum, err = calculateSHA512sum(f, app, filePath, &opt)

View File

@ -183,9 +183,9 @@ func validateOmapCount(f *framework.Framework, count int, driver, pool, mode str
{
volumeMode: volumesType,
driverType: rbdType,
radosLsCmd: fmt.Sprintf("rados ls %s", rbdOptions(pool)),
radosLsCmd: "rados ls " + rbdOptions(pool),
radosLsCmdFilter: fmt.Sprintf("rados ls %s | grep -v default | grep -c ^csi.volume.", rbdOptions(pool)),
radosLsKeysCmd: fmt.Sprintf("rados listomapkeys csi.volumes.default %s", rbdOptions(pool)),
radosLsKeysCmd: "rados listomapkeys csi.volumes.default " + rbdOptions(pool),
radosLsKeysCmdFilter: fmt.Sprintf("rados listomapkeys csi.volumes.default %s | wc -l", rbdOptions(pool)),
},
{
@ -201,9 +201,9 @@ func validateOmapCount(f *framework.Framework, count int, driver, pool, mode str
{
volumeMode: snapsType,
driverType: rbdType,
radosLsCmd: fmt.Sprintf("rados ls %s", rbdOptions(pool)),
radosLsCmd: "rados ls " + rbdOptions(pool),
radosLsCmdFilter: fmt.Sprintf("rados ls %s | grep -v default | grep -c ^csi.snap.", rbdOptions(pool)),
radosLsKeysCmd: fmt.Sprintf("rados listomapkeys csi.snaps.default %s", rbdOptions(pool)),
radosLsKeysCmd: "rados listomapkeys csi.snaps.default " + rbdOptions(pool),
radosLsKeysCmdFilter: fmt.Sprintf("rados listomapkeys csi.snaps.default %s | wc -l", rbdOptions(pool)),
},
}
@ -716,7 +716,7 @@ func checkDataPersist(pvcPath, appPath string, f *framework.Framework) error {
if err != nil {
return err
}
persistData, stdErr, err := execCommandInPod(f, fmt.Sprintf("cat %s", filePath), app.Namespace, &opt)
persistData, stdErr, err := execCommandInPod(f, "cat "+filePath, app.Namespace, &opt)
if err != nil {
return err
}
@ -793,7 +793,7 @@ func checkMountOptions(pvcPath, appPath string, f *framework.Framework, mountFla
LabelSelector: "app=validate-mount-opt",
}
cmd := fmt.Sprintf("mount |grep %s", app.Spec.Containers[0].VolumeMounts[0].MountPath)
cmd := "mount |grep " + app.Spec.Containers[0].VolumeMounts[0].MountPath
data, stdErr, err := execCommandInPod(f, cmd, app.Namespace, &opt)
if err != nil {
return err
@ -1545,7 +1545,7 @@ func validateController(
// If fetching the ServerVersion of the Kubernetes cluster fails, the calling
// test case is marked as `FAILED` and gets aborted.
//
//nolint:deadcode,unused // Unused code will be used in future.
//nolint:unused // Unused code will be used in future.
func k8sVersionGreaterEquals(c kubernetes.Interface, major, minor int) bool {
v, err := c.Discovery().ServerVersion()
if err != nil {
@ -1555,8 +1555,8 @@ func k8sVersionGreaterEquals(c kubernetes.Interface, major, minor int) bool {
// return value.
}
maj := fmt.Sprintf("%d", major)
min := fmt.Sprintf("%d", minor)
maj := strconv.Itoa(major)
min := strconv.Itoa(minor)
return (v.Major > maj) || (v.Major == maj && v.Minor >= min)
}

12
go.mod
View File

@ -27,9 +27,9 @@ require (
github.com/pkg/xattr v0.4.9
github.com/prometheus/client_golang v1.18.0
github.com/stretchr/testify v1.9.0
golang.org/x/crypto v0.21.0
golang.org/x/net v0.22.0
golang.org/x/sys v0.18.0
golang.org/x/crypto v0.22.0
golang.org/x/net v0.24.0
golang.org/x/sys v0.19.0
google.golang.org/grpc v1.62.1
google.golang.org/protobuf v1.33.0
//
@ -44,7 +44,7 @@ require (
k8s.io/mount-utils v0.29.3
k8s.io/pod-security-admission v0.29.3
k8s.io/utils v0.0.0-20230726121419-3b25d923346b
sigs.k8s.io/controller-runtime v0.17.2
sigs.k8s.io/controller-runtime v0.17.3
)
require (
@ -163,7 +163,7 @@ require (
golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 // indirect
golang.org/x/oauth2 v0.16.0 // indirect
golang.org/x/sync v0.6.0 // indirect
golang.org/x/term v0.18.0 // indirect
golang.org/x/term v0.19.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/time v0.3.0 // indirect
golang.org/x/tools v0.16.1 // indirect
@ -176,7 +176,7 @@ require (
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
k8s.io/apiextensions-apiserver v0.29.0 // indirect
k8s.io/apiextensions-apiserver v0.29.2 // indirect
k8s.io/apiserver v0.29.3 // indirect
k8s.io/component-base v0.29.3 // indirect
k8s.io/component-helpers v0.29.3 // indirect

20
go.sum
View File

@ -1761,8 +1761,8 @@ golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliY
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30=
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@ -1908,8 +1908,8 @@ golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc=
golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@ -2078,8 +2078,8 @@ golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
@ -2098,8 +2098,8 @@ golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8=
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
golang.org/x/term v0.19.0 h1:+ThwsDv+tYfnJFhF4L8jITxu1tdTWRTZpdsWgEgjL6Q=
golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@ -2702,8 +2702,8 @@ rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=
sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.28.0 h1:TgtAeesdhpm2SGwkQasmbeqDo8th5wOBA5h/AjTKA4I=
sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.28.0/go.mod h1:VHVDI/KrK4fjnV61bE2g3sA7tiETLn8sooImelsCx3Y=
sigs.k8s.io/controller-runtime v0.2.2/go.mod h1:9dyohw3ZtoXQuV1e766PHUn+cmrRCIcBh6XIMFNMZ+I=
sigs.k8s.io/controller-runtime v0.17.2 h1:FwHwD1CTUemg0pW2otk7/U5/i5m2ymzvOXdbeGOUvw0=
sigs.k8s.io/controller-runtime v0.17.2/go.mod h1:+MngTvIQQQhfXtwfdGw/UOQ/aIaqsYywfCINOtwMO/s=
sigs.k8s.io/controller-runtime v0.17.3 h1:65QmN7r3FWgTxDMz9fvGnO1kbf2nu+acg9p2R9oYYYk=
sigs.k8s.io/controller-runtime v0.17.3/go.mod h1:N0jpP5Lo7lMTF9aL56Z/B2oWBJjey6StQM0jRbKQXtY=
sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd h1:EDPBXCAspyGV4jQlpZSudPeMmr1bNJefnuqLsRAsHZo=
sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd/go.mod h1:B8JuhiUyNFVKdsE8h686QcCxMaH6HrOAZj4vswFpcB0=
sigs.k8s.io/structured-merge-diff/v4 v4.2.3/go.mod h1:qjx8mGObPmV2aSZepjQjbmb2ihdVs8cGKBraizNC69E=

View File

@ -183,13 +183,13 @@ func (cs *ControllerServer) checkContentSource(
req *csi.CreateVolumeRequest,
cr *util.Credentials,
) (*store.VolumeOptions, *store.VolumeIdentifier, *store.SnapshotIdentifier, error) {
if req.VolumeContentSource == nil {
if req.GetVolumeContentSource() == nil {
return nil, nil, nil, nil
}
volumeSource := req.VolumeContentSource
switch volumeSource.Type.(type) {
volumeSource := req.GetVolumeContentSource()
switch volumeSource.GetType().(type) {
case *csi.VolumeContentSource_Snapshot:
snapshotID := req.VolumeContentSource.GetSnapshot().GetSnapshotId()
snapshotID := req.GetVolumeContentSource().GetSnapshot().GetSnapshotId()
volOpt, _, sid, err := store.NewSnapshotOptionsFromID(ctx, snapshotID, cr,
req.GetSecrets(), cs.ClusterName, cs.SetMetadata)
if err != nil {
@ -203,9 +203,9 @@ func (cs *ControllerServer) checkContentSource(
return volOpt, nil, sid, nil
case *csi.VolumeContentSource_Volume:
// Find the volume using the provided VolumeID
volID := req.VolumeContentSource.GetVolume().GetVolumeId()
volID := req.GetVolumeContentSource().GetVolume().GetVolumeId()
parentVol, pvID, err := store.NewVolumeOptionsFromVolID(ctx,
volID, nil, req.Secrets, cs.ClusterName, cs.SetMetadata)
volID, nil, req.GetSecrets(), cs.ClusterName, cs.SetMetadata)
if err != nil {
if !errors.Is(err, cerrors.ErrVolumeNotFound) {
return nil, nil, nil, status.Error(codes.NotFound, err.Error())
@ -342,7 +342,7 @@ func (cs *ControllerServer) CreateVolume(
// As we are trying to create RWX volume from backing snapshot, we need to
// retrieve the snapshot details from the backing snapshot and create a
// subvolume clone from the snapshot.
if parentVol != nil && parentVol.BackingSnapshot && !store.IsVolumeCreateRO(req.VolumeCapabilities) {
if parentVol != nil && parentVol.BackingSnapshot && !store.IsVolumeCreateRO(req.GetVolumeCapabilities()) {
// unset pvID as we dont have real subvolume for the parent volumeID as its a backing snapshot
pvID = nil
parentVol, _, sID, err = store.NewSnapshotOptionsFromID(ctx, parentVol.BackingSnapshotID, cr,
@ -674,7 +674,7 @@ func (cs *ControllerServer) ValidateVolumeCapabilities(
req *csi.ValidateVolumeCapabilitiesRequest,
) (*csi.ValidateVolumeCapabilitiesResponse, error) {
// Cephfs doesn't support Block volume
for _, capability := range req.VolumeCapabilities {
for _, capability := range req.GetVolumeCapabilities() {
if capability.GetBlock() != nil {
return &csi.ValidateVolumeCapabilitiesResponse{Message: ""}, nil
}
@ -682,7 +682,7 @@ func (cs *ControllerServer) ValidateVolumeCapabilities(
return &csi.ValidateVolumeCapabilitiesResponse{
Confirmed: &csi.ValidateVolumeCapabilitiesResponse_Confirmed{
VolumeCapabilities: req.VolumeCapabilities,
VolumeCapabilities: req.GetVolumeCapabilities(),
},
}, nil
}
@ -970,10 +970,10 @@ func (cs *ControllerServer) validateSnapshotReq(ctx context.Context, req *csi.Cr
}
// Check sanity of request Snapshot Name, Source Volume Id
if req.Name == "" {
if req.GetName() == "" {
return status.Error(codes.NotFound, "snapshot Name cannot be empty")
}
if req.SourceVolumeId == "" {
if req.GetSourceVolumeId() == "" {
return status.Error(codes.NotFound, "source Volume ID cannot be empty")
}

View File

@ -17,13 +17,12 @@ limitations under the License.
package core
import (
"errors"
"testing"
cerrors "github.com/ceph/ceph-csi/internal/cephfs/errors"
fsa "github.com/ceph/go-ceph/cephfs/admin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCloneStateToError(t *testing.T) {
@ -36,6 +35,6 @@ func TestCloneStateToError(t *testing.T) {
errorState[cephFSCloneState{fsa.CloneFailed, "", ""}] = cerrors.ErrCloneFailed
for state, err := range errorState {
assert.True(t, errors.Is(state.ToError(), err))
require.ErrorIs(t, state.ToError(), err)
}
}

View File

@ -29,11 +29,11 @@ import (
// that interacts with CephFS filesystem API's.
type FileSystem interface {
// GetFscID returns the ID of the filesystem with the given name.
GetFscID(context.Context, string) (int64, error)
GetFscID(ctx context.Context, fsName string) (int64, error)
// GetMetadataPool returns the metadata pool name of the filesystem with the given name.
GetMetadataPool(context.Context, string) (string, error)
GetMetadataPool(ctx context.Context, fsName string) (string, error)
// GetFsName returns the name of the filesystem with the given ID.
GetFsName(context.Context, int64) (string, error)
GetFsName(ctx context.Context, fsID int64) (string, error)
}
// fileSystem is the implementation of FileSystem interface.

View File

@ -20,7 +20,6 @@ import (
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ceph/ceph-csi/internal/util"
@ -44,7 +43,7 @@ func TestSetupCSIAddonsServer(t *testing.T) {
// verify the socket file has been created
_, err = os.Stat(tmpDir + "/csi-addons.sock")
assert.NoError(t, err)
require.NoError(t, err)
// stop the gRPC server
drv.cas.Stop()

View File

@ -180,7 +180,7 @@ func (cs *ControllerServer) CreateVolumeGroupSnapshot(
for _, r := range *resp {
r.Snapshot.GroupSnapshotId = vgs.VolumeGroupSnapshotID
response.GroupSnapshot.Snapshots = append(response.GroupSnapshot.Snapshots, r.Snapshot)
response.GroupSnapshot.Snapshots = append(response.GroupSnapshot.Snapshots, r.GetSnapshot())
}
return response, nil
@ -293,7 +293,7 @@ func (cs *ControllerServer) releaseQuiesceAndGetVolumeGroupSnapshotResponse(
for _, r := range snapshotResponses {
r.Snapshot.GroupSnapshotId = vgs.VolumeGroupSnapshotID
response.GroupSnapshot.Snapshots = append(response.GroupSnapshot.Snapshots, r.Snapshot)
response.GroupSnapshot.Snapshots = append(response.GroupSnapshot.Snapshots, r.GetSnapshot())
}
return response, nil
@ -703,7 +703,7 @@ func (cs *ControllerServer) DeleteVolumeGroupSnapshot(ctx context.Context,
return nil, err
}
groupSnapshotID := req.GroupSnapshotId
groupSnapshotID := req.GetGroupSnapshotId()
// Existence and conflict checks
if acquired := cs.VolumeGroupLocks.TryAcquire(groupSnapshotID); !acquired {
log.ErrorLog(ctx, util.VolumeOperationAlreadyExistsFmt, groupSnapshotID)
@ -718,7 +718,7 @@ func (cs *ControllerServer) DeleteVolumeGroupSnapshot(ctx context.Context,
}
defer cr.DeleteCredentials()
vgo, vgsi, err := store.NewVolumeGroupOptionsFromID(ctx, req.GroupSnapshotId, cr)
vgo, vgsi, err := store.NewVolumeGroupOptionsFromID(ctx, req.GetGroupSnapshotId(), cr)
if err != nil {
log.ErrorLog(ctx, "failed to get volume group options: %v", err)
err = extractDeleteVolumeGroupError(err)

View File

@ -81,7 +81,7 @@ func (m *kernelMounter) mountKernel(
optionsStr := fmt.Sprintf("name=%s,secretfile=%s", cr.ID, cr.KeyFile)
mdsNamespace := ""
if volOptions.FsName != "" {
mdsNamespace = fmt.Sprintf("mds_namespace=%s", volOptions.FsName)
mdsNamespace = "mds_namespace=" + volOptions.FsName
}
optionsStr = util.MountOptionsAdd(optionsStr, mdsNamespace, volOptions.KernelMountOptions, netDev)

View File

@ -19,7 +19,7 @@ package mounter
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestFilesystemSupported(t *testing.T) {
@ -31,8 +31,8 @@ func TestFilesystemSupported(t *testing.T) {
// "proc" is always a supported filesystem, we detect supported
// filesystems by reading from it
assert.True(t, filesystemSupported("proc"))
require.True(t, filesystemSupported("proc"))
// "nonefs" is a made-up name, and does not exist
assert.False(t, filesystemSupported("nonefs"))
require.False(t, filesystemSupported("nonefs"))
}

View File

@ -110,8 +110,7 @@ func (ns *NodeServer) getVolumeOptions(
func validateSnapshotBackedVolCapability(volCap *csi.VolumeCapability) error {
// Snapshot-backed volumes may be used with read-only volume access modes only.
mode := volCap.AccessMode.Mode
mode := volCap.GetAccessMode().GetMode()
if mode != csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY &&
mode != csi.VolumeCapability_AccessMode_SINGLE_NODE_READER_ONLY {
return status.Error(codes.InvalidArgument,
@ -352,7 +351,6 @@ func (ns *NodeServer) mount(
true,
[]string{"bind", "_netdev"},
)
if err != nil {
log.ErrorLog(ctx,
"failed to bind mount snapshot root %s: %v", absoluteSnapshotRoot, err)
@ -813,9 +811,9 @@ func (ns *NodeServer) setMountOptions(
}
const readOnly = "ro"
if volCap.AccessMode.Mode == csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY ||
volCap.AccessMode.Mode == csi.VolumeCapability_AccessMode_SINGLE_NODE_READER_ONLY {
mode := volCap.GetAccessMode().GetMode()
if mode == csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY ||
mode == csi.VolumeCapability_AccessMode_SINGLE_NODE_READER_ONLY {
switch mnt.(type) {
case *mounter.FuseMounter:
if !csicommon.MountOptionContains(strings.Split(volOptions.FuseMountOptions, ","), readOnly) {

View File

@ -18,7 +18,6 @@ package store
import (
"context"
"fmt"
fsutil "github.com/ceph/ceph-csi/internal/cephfs/util"
"github.com/ceph/ceph-csi/internal/util/log"
@ -28,7 +27,7 @@ import (
)
func fmtBackingSnapshotReftrackerName(backingSnapID string) string {
return fmt.Sprintf("rt-backingsnapshot-%s", backingSnapID)
return "rt-backingsnapshot-" + backingSnapID
}
func AddSnapshotBackedVolumeRef(

View File

@ -168,7 +168,7 @@ func extractMounter(dest *string, options map[string]string) error {
func GetClusterInformation(options map[string]string) (*cephcsi.ClusterInfo, error) {
clusterID, ok := options["clusterID"]
if !ok {
err := fmt.Errorf("clusterID must be set")
err := errors.New("clusterID must be set")
return nil, err
}
@ -344,15 +344,15 @@ func NewVolumeOptions(
// IsShallowVolumeSupported returns true only for ReadOnly volume requests
// with datasource as snapshot.
func IsShallowVolumeSupported(req *csi.CreateVolumeRequest) bool {
isRO := IsVolumeCreateRO(req.VolumeCapabilities)
isRO := IsVolumeCreateRO(req.GetVolumeCapabilities())
return isRO && (req.GetVolumeContentSource() != nil && req.GetVolumeContentSource().GetSnapshot() != nil)
}
func IsVolumeCreateRO(caps []*csi.VolumeCapability) bool {
for _, cap := range caps {
if cap.AccessMode != nil {
switch cap.AccessMode.Mode { //nolint:exhaustive // only check what we want
if cap.GetAccessMode() != nil {
switch cap.GetAccessMode().GetMode() { //nolint:exhaustive // only check what we want
case csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY,
csi.VolumeCapability_AccessMode_SINGLE_NODE_READER_ONLY:
return true
@ -612,7 +612,7 @@ func NewVolumeOptionsFromMonitorList(
// check if there are mon values in secret and if so override option retrieved monitors from
// monitors in the secret
mon, err := util.GetMonValFromSecret(secrets)
if err == nil && len(mon) > 0 {
if err == nil && mon != "" {
opts.Monitors = mon
}

View File

@ -54,11 +54,11 @@ func (cs *ControllerServer) validateCreateVolumeRequest(req *csi.CreateVolumeReq
return err
}
if req.VolumeContentSource != nil {
volumeSource := req.VolumeContentSource
switch volumeSource.Type.(type) {
if req.GetVolumeContentSource() != nil {
volumeSource := req.GetVolumeContentSource()
switch volumeSource.GetType().(type) {
case *csi.VolumeContentSource_Snapshot:
snapshot := req.VolumeContentSource.GetSnapshot()
snapshot := req.GetVolumeContentSource().GetSnapshot()
// CSI spec requires returning NOT_FOUND when the volumeSource is missing/incorrect.
if snapshot == nil {
return status.Error(codes.NotFound, "volume Snapshot cannot be empty")
@ -68,7 +68,7 @@ func (cs *ControllerServer) validateCreateVolumeRequest(req *csi.CreateVolumeReq
}
case *csi.VolumeContentSource_Volume:
// CSI spec requires returning NOT_FOUND when the volumeSource is missing/incorrect.
vol := req.VolumeContentSource.GetVolume()
vol := req.GetVolumeContentSource().GetVolume()
if vol == nil {
return status.Error(codes.NotFound, "volume cannot be empty")
}

View File

@ -31,7 +31,7 @@ import (
// The New controllers which gets added, as to implement Add function to get
// started by the manager.
type Manager interface {
Add(manager.Manager, Config) error
Add(mgr manager.Manager, cfg Config) error
}
// Config holds the drivername and namespace name.

View File

@ -66,7 +66,7 @@ func (fcs *FenceControllerServer) FenceClusterNetwork(
ctx context.Context,
req *fence.FenceClusterNetworkRequest,
) (*fence.FenceClusterNetworkResponse, error) {
err := validateNetworkFenceReq(req.GetCidrs(), req.Parameters)
err := validateNetworkFenceReq(req.GetCidrs(), req.GetParameters())
if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
@ -77,7 +77,7 @@ func (fcs *FenceControllerServer) FenceClusterNetwork(
}
defer cr.DeleteCredentials()
nwFence, err := nf.NewNetworkFence(ctx, cr, req.Cidrs, req.GetParameters())
nwFence, err := nf.NewNetworkFence(ctx, cr, req.GetCidrs(), req.GetParameters())
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
@ -95,7 +95,7 @@ func (fcs *FenceControllerServer) UnfenceClusterNetwork(
ctx context.Context,
req *fence.UnfenceClusterNetworkRequest,
) (*fence.UnfenceClusterNetworkResponse, error) {
err := validateNetworkFenceReq(req.GetCidrs(), req.Parameters)
err := validateNetworkFenceReq(req.GetCidrs(), req.GetParameters())
if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
@ -106,7 +106,7 @@ func (fcs *FenceControllerServer) UnfenceClusterNetwork(
}
defer cr.DeleteCredentials()
nwFence, err := nf.NewNetworkFence(ctx, cr, req.Cidrs, req.GetParameters())
nwFence, err := nf.NewNetworkFence(ctx, cr, req.GetCidrs(), req.GetParameters())
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}

View File

@ -21,7 +21,7 @@ import (
"testing"
"github.com/csi-addons/spec/lib/go/fence"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestFenceClusterNetwork is a minimal test for the FenceClusterNetwork()
@ -39,7 +39,7 @@ func TestFenceClusterNetwork(t *testing.T) {
}
_, err := controller.FenceClusterNetwork(context.TODO(), req)
assert.Error(t, err)
require.Error(t, err)
}
// TestUnfenceClusterNetwork is a minimal test for the UnfenceClusterNetwork()
@ -55,5 +55,5 @@ func TestUnfenceClusterNetwork(t *testing.T) {
Cidrs: nil,
}
_, err := controller.UnfenceClusterNetwork(context.TODO(), req)
assert.Error(t, err)
require.Error(t, err)
}

View File

@ -348,7 +348,7 @@ type Cidrs []*fence.CIDR
func GetCIDR(cidrs Cidrs) ([]string, error) {
var cidrList []string
for _, cidr := range cidrs {
cidrList = append(cidrList, cidr.Cidr)
cidrList = append(cidrList, cidr.GetCidr())
}
if len(cidrList) < 1 {
return nil, errors.New("the CIDR cannot be empty")

View File

@ -19,7 +19,7 @@ package networkfence
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetIPRange(t *testing.T) {
@ -47,10 +47,10 @@ func TestGetIPRange(t *testing.T) {
t.Run(ts.cidr, func(t *testing.T) {
t.Parallel()
got, err := getIPRange(ts.cidr)
assert.NoError(t, err)
require.NoError(t, err)
// validate if number of IPs in the range is same as expected, if not, fail.
assert.ElementsMatch(t, ts.expectedIPs, got)
require.ElementsMatch(t, ts.expectedIPs, got)
})
}
}

View File

@ -62,7 +62,7 @@ func (fcs *FenceControllerServer) FenceClusterNetwork(
ctx context.Context,
req *fence.FenceClusterNetworkRequest,
) (*fence.FenceClusterNetworkResponse, error) {
err := validateNetworkFenceReq(req.GetCidrs(), req.Parameters)
err := validateNetworkFenceReq(req.GetCidrs(), req.GetParameters())
if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
@ -73,7 +73,7 @@ func (fcs *FenceControllerServer) FenceClusterNetwork(
}
defer cr.DeleteCredentials()
nwFence, err := nf.NewNetworkFence(ctx, cr, req.Cidrs, req.GetParameters())
nwFence, err := nf.NewNetworkFence(ctx, cr, req.GetCidrs(), req.GetParameters())
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
@ -91,7 +91,7 @@ func (fcs *FenceControllerServer) UnfenceClusterNetwork(
ctx context.Context,
req *fence.UnfenceClusterNetworkRequest,
) (*fence.UnfenceClusterNetworkResponse, error) {
err := validateNetworkFenceReq(req.GetCidrs(), req.Parameters)
err := validateNetworkFenceReq(req.GetCidrs(), req.GetParameters())
if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
@ -102,7 +102,7 @@ func (fcs *FenceControllerServer) UnfenceClusterNetwork(
}
defer cr.DeleteCredentials()
nwFence, err := nf.NewNetworkFence(ctx, cr, req.Cidrs, req.GetParameters())
nwFence, err := nf.NewNetworkFence(ctx, cr, req.GetCidrs(), req.GetParameters())
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}

View File

@ -17,9 +17,8 @@ import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/csi-addons/spec/lib/go/fence"
"github.com/stretchr/testify/require"
)
// TestFenceClusterNetwork is a minimal test for the FenceClusterNetwork()
@ -37,7 +36,7 @@ func TestFenceClusterNetwork(t *testing.T) {
}
_, err := controller.FenceClusterNetwork(context.TODO(), req)
assert.Error(t, err)
require.Error(t, err)
}
// TestUnfenceClusterNetwork is a minimal test for the UnfenceClusterNetwork()
@ -53,5 +52,5 @@ func TestUnfenceClusterNetwork(t *testing.T) {
Cidrs: nil,
}
_, err := controller.UnfenceClusterNetwork(context.TODO(), req)
assert.Error(t, err)
require.Error(t, err)
}

View File

@ -20,9 +20,8 @@ import (
"context"
"testing"
"github.com/stretchr/testify/assert"
rs "github.com/csi-addons/spec/lib/go/reclaimspace"
"github.com/stretchr/testify/require"
)
// TestControllerReclaimSpace is a minimal test for the
@ -39,7 +38,7 @@ func TestControllerReclaimSpace(t *testing.T) {
}
_, err := controller.ControllerReclaimSpace(context.TODO(), req)
assert.Error(t, err)
require.Error(t, err)
}
// TestNodeReclaimSpace is a minimal test for the NodeReclaimSpace() procedure.
@ -58,5 +57,5 @@ func TestNodeReclaimSpace(t *testing.T) {
}
_, err := node.NodeReclaimSpace(context.TODO(), req)
assert.Error(t, err)
require.Error(t, err)
}

View File

@ -709,7 +709,7 @@ func (rs *ReplicationServer) ResyncVolume(ctx context.Context,
return nil, status.Errorf(codes.Internal, "failed to parse image creation time: %s", sErr.Error())
}
log.DebugLog(ctx, "image %s, savedImageTime=%v, currentImageTime=%v", rbdVol, st, creationTime.AsTime())
if req.Force && st.Equal(creationTime.AsTime()) {
if req.GetForce() && st.Equal(creationTime.AsTime()) {
err = rbdVol.ResyncVol(localStatus)
if err != nil {
return nil, getGRPCError(err)
@ -738,7 +738,7 @@ func (rs *ReplicationServer) ResyncVolume(ctx context.Context,
// timestampToString converts the time.Time object to string.
func timestampToString(st *timestamppb.Timestamp) string {
return fmt.Sprintf("seconds:%d nanos:%d", st.Seconds, st.Nanos)
return fmt.Sprintf("seconds:%d nanos:%d", st.GetSeconds(), st.GetNanos())
}
// timestampFromString parses the timestamp string and returns the time.Time
@ -989,7 +989,7 @@ func checkVolumeResyncStatus(ctx context.Context, localStatus librbd.SiteMirrorI
if err != nil {
return fmt.Errorf("failed to get last sync info: %w", err)
}
if resp.LastSyncTime == nil {
if resp.GetLastSyncTime() == nil {
return errors.New("last sync time is nil")
}

View File

@ -30,7 +30,7 @@ import (
librbd "github.com/ceph/go-ceph/rbd"
"github.com/ceph/go-ceph/rbd/admin"
"github.com/csi-addons/spec/lib/go/replication"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
@ -511,19 +511,29 @@ func TestValidateLastSyncInfo(t *testing.T) {
tt.expectedErr, err)
}
if teststruct != nil {
if teststruct.LastSyncTime.GetSeconds() != tt.info.LastSyncTime.GetSeconds() {
t.Errorf("name: %v, getLastSyncInfo() %v, expected %v", tt.name, teststruct.LastSyncTime, tt.info.LastSyncTime)
if teststruct.GetLastSyncTime().GetSeconds() != tt.info.GetLastSyncTime().GetSeconds() {
t.Errorf("name: %v, getLastSyncInfo() %v, expected %v",
tt.name,
teststruct.GetLastSyncTime(),
tt.info.GetLastSyncTime())
}
if tt.info.LastSyncDuration == nil && teststruct.LastSyncDuration != nil {
t.Errorf("name: %v, getLastSyncInfo() %v, expected %v", tt.name, teststruct.LastSyncDuration,
tt.info.LastSyncDuration)
if tt.info.GetLastSyncDuration() == nil && teststruct.GetLastSyncDuration() != nil {
t.Errorf("name: %v, getLastSyncInfo() %v, expected %v",
tt.name,
teststruct.GetLastSyncDuration(),
tt.info.GetLastSyncDuration())
}
if teststruct.LastSyncDuration.GetSeconds() != tt.info.LastSyncDuration.GetSeconds() {
t.Errorf("name: %v, getLastSyncInfo() %v, expected %v", tt.name, teststruct.LastSyncDuration,
tt.info.LastSyncDuration)
if teststruct.GetLastSyncDuration().GetSeconds() != tt.info.GetLastSyncDuration().GetSeconds() {
t.Errorf("name: %v, getLastSyncInfo() %v, expected %v",
tt.name,
teststruct.GetLastSyncDuration(),
tt.info.GetLastSyncDuration())
}
if teststruct.LastSyncBytes != tt.info.LastSyncBytes {
t.Errorf("name: %v, getLastSyncInfo() %v, expected %v", tt.name, teststruct.LastSyncBytes, tt.info.LastSyncBytes)
if teststruct.GetLastSyncBytes() != tt.info.GetLastSyncBytes() {
t.Errorf("name: %v, getLastSyncInfo() %v, expected %v",
tt.name,
teststruct.GetLastSyncBytes(),
tt.info.GetLastSyncBytes())
}
}
})
@ -594,7 +604,7 @@ func TestGetGRPCError(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result := getGRPCError(tt.err)
assert.Equal(t, tt.expectedErr, result)
require.Equal(t, tt.expectedErr, result)
})
}
}

View File

@ -19,7 +19,6 @@ package server
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -39,7 +38,7 @@ func TestNewCSIAddonsServer(t *testing.T) {
cas, err := NewCSIAddonsServer("")
require.Error(t, err)
assert.Nil(t, cas)
require.Nil(t, cas)
})
t.Run("no UDS endpoint", func(t *testing.T) {
@ -47,6 +46,6 @@ func TestNewCSIAddonsServer(t *testing.T) {
cas, err := NewCSIAddonsServer("endpoint at /tmp/...")
require.Error(t, err)
assert.Nil(t, cas)
require.Nil(t, cas)
})
}

View File

@ -70,8 +70,6 @@ func NewCSIDriver(name, v, nodeID string) *CSIDriver {
// ValidateControllerServiceRequest validates the controller
// plugin capabilities.
//
//nolint:interfacer // c can be of type fmt.Stringer, but that does not make the API clearer
func (d *CSIDriver) ValidateControllerServiceRequest(c csi.ControllerServiceCapability_RPC_Type) error {
if c == csi.ControllerServiceCapability_RPC_UNKNOWN {
return nil
@ -133,8 +131,6 @@ func (d *CSIDriver) AddGroupControllerServiceCapabilities(cl []csi.GroupControll
// ValidateGroupControllerServiceRequest validates the group controller
// plugin capabilities.
//
//nolint:interfacer // c can be of type fmt.Stringer, but that does not make the API clearer
func (d *CSIDriver) ValidateGroupControllerServiceRequest(c csi.GroupControllerServiceCapability_RPC_Type) error {
if c == csi.GroupControllerServiceCapability_RPC_UNKNOWN {
return nil

View File

@ -86,7 +86,7 @@ func ConstructMountOptions(mountOptions []string, volCap *csi.VolumeCapability)
return false
}
for _, f := range m.MountFlags {
for _, f := range m.GetMountFlags() {
if !hasOption(mountOptions, f) {
mountOptions = append(mountOptions, f)
}

View File

@ -121,52 +121,52 @@ func getReqID(req interface{}) string {
reqID := ""
switch r := req.(type) {
case *csi.CreateVolumeRequest:
reqID = r.Name
reqID = r.GetName()
case *csi.DeleteVolumeRequest:
reqID = r.VolumeId
reqID = r.GetVolumeId()
case *csi.CreateSnapshotRequest:
reqID = r.Name
reqID = r.GetName()
case *csi.DeleteSnapshotRequest:
reqID = r.SnapshotId
reqID = r.GetSnapshotId()
case *csi.ControllerExpandVolumeRequest:
reqID = r.VolumeId
reqID = r.GetVolumeId()
case *csi.NodeStageVolumeRequest:
reqID = r.VolumeId
reqID = r.GetVolumeId()
case *csi.NodeUnstageVolumeRequest:
reqID = r.VolumeId
reqID = r.GetVolumeId()
case *csi.NodePublishVolumeRequest:
reqID = r.VolumeId
reqID = r.GetVolumeId()
case *csi.NodeUnpublishVolumeRequest:
reqID = r.VolumeId
reqID = r.GetVolumeId()
case *csi.NodeExpandVolumeRequest:
reqID = r.VolumeId
reqID = r.GetVolumeId()
case *csi.CreateVolumeGroupSnapshotRequest:
reqID = r.Name
reqID = r.GetName()
case *csi.DeleteVolumeGroupSnapshotRequest:
reqID = r.GroupSnapshotId
reqID = r.GetGroupSnapshotId()
case *csi.GetVolumeGroupSnapshotRequest:
reqID = r.GroupSnapshotId
reqID = r.GetGroupSnapshotId()
// Replication
case *replication.EnableVolumeReplicationRequest:
reqID = r.VolumeId
reqID = r.GetVolumeId()
case *replication.DisableVolumeReplicationRequest:
reqID = r.VolumeId
reqID = r.GetVolumeId()
case *replication.PromoteVolumeRequest:
reqID = r.VolumeId
reqID = r.GetVolumeId()
case *replication.DemoteVolumeRequest:
reqID = r.VolumeId
reqID = r.GetVolumeId()
case *replication.ResyncVolumeRequest:
reqID = r.VolumeId
reqID = r.GetVolumeId()
case *replication.GetVolumeReplicationInfoRequest:
reqID = r.VolumeId
reqID = r.GetVolumeId()
}
return reqID
@ -353,9 +353,9 @@ func IsFileRWO(caps []*csi.VolumeCapability) bool {
// to preserve backward compatibility we allow RWO filemode, ideally SINGLE_NODE_WRITER check is good enough,
// however more granular level check could help us in future, so keeping it here as an additional measure.
for _, cap := range caps {
if cap.AccessMode != nil {
if cap.GetAccessMode() != nil {
if cap.GetMount() != nil {
switch cap.AccessMode.Mode { //nolint:exhaustive // only check what we want
switch cap.GetAccessMode().GetMode() { //nolint:exhaustive // only check what we want
case csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER,
csi.VolumeCapability_AccessMode_SINGLE_NODE_MULTI_WRITER,
csi.VolumeCapability_AccessMode_SINGLE_NODE_SINGLE_WRITER:
@ -372,8 +372,8 @@ func IsFileRWO(caps []*csi.VolumeCapability) bool {
// or block mode.
func IsReaderOnly(caps []*csi.VolumeCapability) bool {
for _, cap := range caps {
if cap.AccessMode != nil {
switch cap.AccessMode.Mode { //nolint:exhaustive // only check what we want
if cap.GetAccessMode() != nil {
switch cap.GetAccessMode().GetMode() { //nolint:exhaustive // only check what we want
case csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY,
csi.VolumeCapability_AccessMode_SINGLE_NODE_READER_ONLY:
return true
@ -397,8 +397,8 @@ func IsBlockMultiWriter(caps []*csi.VolumeCapability) (bool, bool) {
var block bool
for _, cap := range caps {
if cap.AccessMode != nil {
switch cap.AccessMode.Mode { //nolint:exhaustive // only check what we want
if cap.GetAccessMode() != nil {
switch cap.GetAccessMode().GetMode() { //nolint:exhaustive // only check what we want
case csi.VolumeCapability_AccessMode_MULTI_NODE_MULTI_WRITER,
csi.VolumeCapability_AccessMode_SINGLE_NODE_MULTI_WRITER:
multiWriter = true

View File

@ -25,7 +25,6 @@ import (
"github.com/container-storage-interface/spec/lib/go/csi"
"github.com/csi-addons/spec/lib/go/replication"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
mount "k8s.io/mount-utils"
)
@ -127,12 +126,12 @@ func TestFilesystemNodeGetVolumeStats(t *testing.T) {
}
require.NoError(t, err)
assert.NotEqual(t, len(stats.Usage), 0)
for _, usage := range stats.Usage {
assert.NotEqual(t, usage.Available, -1)
assert.NotEqual(t, usage.Total, -1)
assert.NotEqual(t, usage.Used, -1)
assert.NotEqual(t, usage.Unit, 0)
require.NotEmpty(t, stats.GetUsage())
for _, usage := range stats.GetUsage() {
require.NotEqual(t, -1, usage.GetAvailable())
require.NotEqual(t, -1, usage.GetTotal())
require.NotEqual(t, -1, usage.GetUsed())
require.NotEqual(t, 0, usage.GetUnit())
}
// tests done, no need to retry again
@ -143,9 +142,9 @@ func TestFilesystemNodeGetVolumeStats(t *testing.T) {
func TestRequirePositive(t *testing.T) {
t.Parallel()
assert.Equal(t, requirePositive(0), int64(0))
assert.Equal(t, requirePositive(-1), int64(0))
assert.Equal(t, requirePositive(1), int64(1))
require.Equal(t, int64(0), requirePositive(0))
require.Equal(t, int64(0), requirePositive(-1))
require.Equal(t, int64(1), requirePositive(1))
}
func TestIsBlockMultiNode(t *testing.T) {
@ -204,8 +203,8 @@ func TestIsBlockMultiNode(t *testing.T) {
for _, test := range tests {
isBlock, isMultiNode := IsBlockMultiNode(test.caps)
assert.Equal(t, isBlock, test.isBlock, test.name)
assert.Equal(t, isMultiNode, test.isMultiNode, test.name)
require.Equal(t, isBlock, test.isBlock, test.name)
require.Equal(t, isMultiNode, test.isMultiNode, test.name)
}
}

View File

@ -561,7 +561,7 @@ func (conn *Connection) ReserveName(ctx context.Context,
imagePool string, imagePoolID int64,
reqName, namePrefix, parentName, kmsConf, volUUID, owner,
backingSnapshotID string,
encryptionType util.EncryptionType, //nolint:interfacer // prefer util.EncryptionType over fmt.Stringer
encryptionType util.EncryptionType,
) (string, string, error) {
// TODO: Take in-arg as ImageAttributes?
var (

View File

@ -19,11 +19,11 @@ package kms
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAWSMetadataKMSRegistered(t *testing.T) {
t.Parallel()
_, ok := kmsManager.providers[kmsTypeAWSMetadata]
assert.True(t, ok)
require.True(t, ok)
}

View File

@ -19,11 +19,11 @@ package kms
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAWSSTSMetadataKMSRegistered(t *testing.T) {
t.Parallel()
_, ok := kmsManager.providers[kmsTypeAWSSTSMetadata]
assert.True(t, ok)
require.True(t, ok)
}

View File

@ -19,11 +19,11 @@ package kms
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAzureKMSRegistered(t *testing.T) {
t.Parallel()
_, ok := kmsManager.providers[kmsTypeAzure]
assert.True(t, ok)
require.True(t, ok)
}

View File

@ -19,11 +19,11 @@ package kms
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestKeyProtectMetadataKMSRegistered(t *testing.T) {
t.Parallel()
_, ok := kmsManager.providers[kmsTypeKeyProtectMetadata]
assert.True(t, ok)
require.True(t, ok)
}

View File

@ -19,11 +19,11 @@ package kms
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestKMIPKMSRegistered(t *testing.T) {
t.Parallel()
_, ok := kmsManager.providers[kmsTypeKMIP]
assert.True(t, ok)
require.True(t, ok)
}

View File

@ -19,7 +19,7 @@ package kms
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func noinitKMS(args ProviderInitArgs) (EncryptionKMS, error) {
@ -47,9 +47,9 @@ func TestRegisterProvider(t *testing.T) {
for _, test := range tests {
provider := test.provider
if test.panics {
assert.Panics(t, func() { RegisterProvider(provider) })
require.Panics(t, func() { RegisterProvider(provider) })
} else {
assert.True(t, RegisterProvider(provider))
require.True(t, RegisterProvider(provider))
}
}
}

View File

@ -20,7 +20,7 @@ import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSetConfigInt(t *testing.T) {
@ -81,7 +81,7 @@ func TestSetConfigInt(t *testing.T) {
t.Errorf("setConfigInt() error = %v, wantErr %v", err, currentTT.err)
}
if err != nil {
assert.NotEqual(t, currentTT.value, currentTT.args.option)
require.NotEqual(t, currentTT.value, currentTT.args.option)
}
})
}

View File

@ -20,7 +20,6 @@ import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -32,24 +31,24 @@ func TestNewSecretsKMS(t *testing.T) {
kms, err := newSecretsKMS(ProviderInitArgs{
Secrets: secrets,
})
assert.Error(t, err)
assert.Nil(t, kms)
require.Error(t, err)
require.Nil(t, kms)
// set a passphrase and it should pass
secrets[encryptionPassphraseKey] = "plaintext encryption key"
kms, err = newSecretsKMS(ProviderInitArgs{
Secrets: secrets,
})
assert.NotNil(t, kms)
assert.NoError(t, err)
require.NotNil(t, kms)
require.NoError(t, err)
}
func TestGenerateNonce(t *testing.T) {
t.Parallel()
size := 64
nonce, err := generateNonce(size)
assert.Equal(t, size, len(nonce))
assert.NoError(t, err)
require.Len(t, nonce, size)
require.NoError(t, err)
}
func TestGenerateCipher(t *testing.T) {
@ -59,8 +58,8 @@ func TestGenerateCipher(t *testing.T) {
salt := "unique-id-for-the-volume"
aead, err := generateCipher(passphrase, salt)
assert.NoError(t, err)
assert.NotNil(t, aead)
require.NoError(t, err)
require.NotNil(t, aead)
}
func TestInitSecretsMetadataKMS(t *testing.T) {
@ -73,16 +72,16 @@ func TestInitSecretsMetadataKMS(t *testing.T) {
// passphrase it not set, init should fail
kms, err := initSecretsMetadataKMS(args)
assert.Error(t, err)
assert.Nil(t, kms)
require.Error(t, err)
require.Nil(t, kms)
// set a passphrase to get a working KMS
args.Secrets[encryptionPassphraseKey] = "my-passphrase-from-kubernetes"
kms, err = initSecretsMetadataKMS(args)
assert.NoError(t, err)
require.NoError(t, err)
require.NotNil(t, kms)
assert.Equal(t, DEKStoreMetadata, kms.RequiresDEKStore())
require.Equal(t, DEKStoreMetadata, kms.RequiresDEKStore())
}
func TestWorkflowSecretsMetadataKMS(t *testing.T) {
@ -98,7 +97,7 @@ func TestWorkflowSecretsMetadataKMS(t *testing.T) {
volumeID := "csi-vol-1b00f5f8-b1c1-11e9-8421-9243c1f659f0"
kms, err := initSecretsMetadataKMS(args)
assert.NoError(t, err)
require.NoError(t, err)
require.NotNil(t, kms)
// plainDEK is the (LUKS) passphrase for the volume
@ -107,25 +106,25 @@ func TestWorkflowSecretsMetadataKMS(t *testing.T) {
ctx := context.TODO()
encryptedDEK, err := kms.EncryptDEK(ctx, volumeID, plainDEK)
assert.NoError(t, err)
assert.NotEqual(t, "", encryptedDEK)
assert.NotEqual(t, plainDEK, encryptedDEK)
require.NoError(t, err)
require.NotEqual(t, "", encryptedDEK)
require.NotEqual(t, plainDEK, encryptedDEK)
// with an incorrect volumeID, decrypting should fail
decryptedDEK, err := kms.DecryptDEK(ctx, "incorrect-volumeID", encryptedDEK)
assert.Error(t, err)
assert.Equal(t, "", decryptedDEK)
assert.NotEqual(t, plainDEK, decryptedDEK)
require.Error(t, err)
require.Equal(t, "", decryptedDEK)
require.NotEqual(t, plainDEK, decryptedDEK)
// with the right volumeID, decrypting should return the plainDEK
decryptedDEK, err = kms.DecryptDEK(ctx, volumeID, encryptedDEK)
assert.NoError(t, err)
assert.NotEqual(t, "", decryptedDEK)
assert.Equal(t, plainDEK, decryptedDEK)
require.NoError(t, err)
require.NotEqual(t, "", decryptedDEK)
require.Equal(t, plainDEK, decryptedDEK)
}
func TestSecretsMetadataKMSRegistered(t *testing.T) {
t.Parallel()
_, ok := kmsManager.providers[kmsTypeSecretsMetadata]
assert.True(t, ok)
require.True(t, ok)
}

View File

@ -20,13 +20,13 @@ import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestVaultTenantSAKMSRegistered(t *testing.T) {
t.Parallel()
_, ok := kmsManager.providers[kmsTypeVaultTenantSA]
assert.True(t, ok)
require.True(t, ok)
}
func TestTenantSAParseConfig(t *testing.T) {

View File

@ -22,7 +22,6 @@ import (
"testing"
loss "github.com/libopenstorage/secrets"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -113,8 +112,8 @@ func TestDefaultVaultDestroyKeys(t *testing.T) {
require.NoError(t, err)
keyContext := vc.getDeleteKeyContext()
destroySecret, ok := keyContext[loss.DestroySecret]
assert.NotEqual(t, destroySecret, "")
assert.True(t, ok)
require.NotEqual(t, "", destroySecret)
require.True(t, ok)
// setting vaultDestroyKeys to !true should remove the loss.DestroySecret entry
config["vaultDestroyKeys"] = "false"
@ -122,11 +121,11 @@ func TestDefaultVaultDestroyKeys(t *testing.T) {
require.NoError(t, err)
keyContext = vc.getDeleteKeyContext()
_, ok = keyContext[loss.DestroySecret]
assert.False(t, ok)
require.False(t, ok)
}
func TestVaultKMSRegistered(t *testing.T) {
t.Parallel()
_, ok := kmsManager.providers[kmsTypeVault]
assert.True(t, ok)
require.True(t, ok)
}

View File

@ -25,7 +25,6 @@ import (
"github.com/hashicorp/vault/api"
loss "github.com/libopenstorage/secrets"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -205,18 +204,18 @@ func TestTransformConfig(t *testing.T) {
config, err := transformConfig(cm)
require.NoError(t, err)
assert.Equal(t, config["encryptionKMSType"], cm["KMS_PROVIDER"])
assert.Equal(t, config["vaultAddress"], cm["VAULT_ADDR"])
assert.Equal(t, config["vaultBackend"], cm["VAULT_BACKEND"])
assert.Equal(t, config["vaultBackendPath"], cm["VAULT_BACKEND_PATH"])
assert.Equal(t, config["vaultDestroyKeys"], cm["VAULT_DESTROY_KEYS"])
assert.Equal(t, config["vaultCAFromSecret"], cm["VAULT_CACERT"])
assert.Equal(t, config["vaultTLSServerName"], cm["VAULT_TLS_SERVER_NAME"])
assert.Equal(t, config["vaultClientCertFromSecret"], cm["VAULT_CLIENT_CERT"])
assert.Equal(t, config["vaultClientCertKeyFromSecret"], cm["VAULT_CLIENT_KEY"])
assert.Equal(t, config["vaultAuthNamespace"], cm["VAULT_AUTH_NAMESPACE"])
assert.Equal(t, config["vaultNamespace"], cm["VAULT_NAMESPACE"])
assert.Equal(t, config["vaultCAVerify"], "false")
require.Equal(t, cm["KMS_PROVIDER"], config["encryptionKMSType"])
require.Equal(t, cm["VAULT_ADDR"], config["vaultAddress"])
require.Equal(t, cm["VAULT_BACKEND"], config["vaultBackend"])
require.Equal(t, cm["VAULT_BACKEND_PATH"], config["vaultBackendPath"])
require.Equal(t, cm["VAULT_DESTROY_KEYS"], config["vaultDestroyKeys"])
require.Equal(t, cm["VAULT_CACERT"], config["vaultCAFromSecret"])
require.Equal(t, cm["VAULT_TLS_SERVER_NAME"], config["vaultTLSServerName"])
require.Equal(t, cm["VAULT_CLIENT_CERT"], config["vaultClientCertFromSecret"])
require.Equal(t, cm["VAULT_CLIENT_KEY"], config["vaultClientCertKeyFromSecret"])
require.Equal(t, cm["VAULT_AUTH_NAMESPACE"], config["vaultAuthNamespace"])
require.Equal(t, cm["VAULT_NAMESPACE"], config["vaultNamespace"])
require.Equal(t, "false", config["vaultCAVerify"])
}
func TestTransformConfigDefaults(t *testing.T) {
@ -226,15 +225,15 @@ func TestTransformConfigDefaults(t *testing.T) {
config, err := transformConfig(cm)
require.NoError(t, err)
assert.Equal(t, config["encryptionKMSType"], cm["KMS_PROVIDER"])
assert.Equal(t, config["vaultDestroyKeys"], vaultDefaultDestroyKeys)
assert.Equal(t, config["vaultCAVerify"], strconv.FormatBool(vaultDefaultCAVerify))
require.Equal(t, cm["KMS_PROVIDER"], config["encryptionKMSType"])
require.Equal(t, vaultDefaultDestroyKeys, config["vaultDestroyKeys"])
require.Equal(t, strconv.FormatBool(vaultDefaultCAVerify), config["vaultCAVerify"])
}
func TestVaultTokensKMSRegistered(t *testing.T) {
t.Parallel()
_, ok := kmsManager.providers[kmsTypeVaultTokens]
assert.True(t, ok)
require.True(t, ok)
}
func TestSetTenantAuthNamespace(t *testing.T) {
@ -259,7 +258,7 @@ func TestSetTenantAuthNamespace(t *testing.T) {
kms.setTenantAuthNamespace(config)
assert.Equal(tt, vaultNamespace, config["vaultAuthNamespace"])
require.Equal(tt, vaultNamespace, config["vaultAuthNamespace"])
})
t.Run("inherit vaultAuthNamespace", func(tt *testing.T) {
@ -283,7 +282,7 @@ func TestSetTenantAuthNamespace(t *testing.T) {
// when inheriting from the global config, the config of the
// tenant should not have vaultAuthNamespace configured
assert.Equal(tt, nil, config["vaultAuthNamespace"])
require.Nil(tt, config["vaultAuthNamespace"])
})
t.Run("unset vaultAuthNamespace", func(tt *testing.T) {
@ -306,7 +305,7 @@ func TestSetTenantAuthNamespace(t *testing.T) {
// global vaultAuthNamespace is not set, tenant
// vaultAuthNamespace will be configured as vaultNamespace by
// default
assert.Equal(tt, nil, config["vaultAuthNamespace"])
require.Nil(tt, config["vaultAuthNamespace"])
})
t.Run("no vaultNamespace", func(tt *testing.T) {
@ -326,6 +325,6 @@ func TestSetTenantAuthNamespace(t *testing.T) {
kms.setTenantAuthNamespace(config)
assert.Equal(tt, nil, config["vaultAuthNamespace"])
require.Nil(tt, config["vaultAuthNamespace"])
})
}

View File

@ -84,9 +84,9 @@ func (cs *Server) CreateVolume(
return nil, err
}
backend := res.Volume
backend := res.GetVolume()
log.DebugLog(ctx, "CephFS volume created: %s", backend.VolumeId)
log.DebugLog(ctx, "CephFS volume created: %s", backend.GetVolumeId())
secret := req.GetSecrets()
cr, err := util.NewAdminCredentials(secret)
@ -97,7 +97,7 @@ func (cs *Server) CreateVolume(
}
defer cr.DeleteCredentials()
nfsVolume, err := NewNFSVolume(ctx, backend.VolumeId)
nfsVolume, err := NewNFSVolume(ctx, backend.GetVolumeId())
if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}

View File

@ -127,12 +127,12 @@ func (nv *NFSVolume) CreateExport(backend *csi.Volume) error {
if !nv.connected {
return fmt.Errorf("can not created export for %q: %w", nv, ErrNotConnected)
}
fs := backend.VolumeContext["fsName"]
nfsCluster := backend.VolumeContext["nfsCluster"]
path := backend.VolumeContext["subvolumePath"]
secTypes := backend.VolumeContext["secTypes"]
clients := backend.VolumeContext["clients"]
vctx := backend.GetVolumeContext()
fs := vctx["fsName"]
nfsCluster := vctx["nfsCluster"]
path := vctx["subvolumePath"]
secTypes := vctx["secTypes"]
clients := vctx["clients"]
err := nv.setNFSCluster(nfsCluster)
if err != nil {

View File

@ -68,10 +68,10 @@ func (cs *ControllerServer) validateVolumeReq(ctx context.Context, req *csi.Crea
return err
}
// Check sanity of request Name, Volume Capabilities
if req.Name == "" {
if req.GetName() == "" {
return status.Error(codes.InvalidArgument, "volume Name cannot be empty")
}
if req.VolumeCapabilities == nil {
if req.GetVolumeCapabilities() == nil {
return status.Error(codes.InvalidArgument, "volume Capabilities cannot be empty")
}
options := req.GetParameters()
@ -105,7 +105,7 @@ func (cs *ControllerServer) validateVolumeReq(ctx context.Context, req *csi.Crea
return err
}
err = validateStriping(req.Parameters)
err = validateStriping(req.GetParameters())
if err != nil {
return status.Error(codes.InvalidArgument, err.Error())
}
@ -156,13 +156,13 @@ func (cs *ControllerServer) parseVolCreateRequest(
// below capability check indicates that we support both {SINGLE_NODE or MULTI_NODE} WRITERs and the `isMultiWriter`
// flag has been set accordingly.
isMultiWriter, isBlock := csicommon.IsBlockMultiWriter(req.VolumeCapabilities)
isMultiWriter, isBlock := csicommon.IsBlockMultiWriter(req.GetVolumeCapabilities())
// below return value has set, if it is RWO mode File PVC.
isRWOFile := csicommon.IsFileRWO(req.VolumeCapabilities)
isRWOFile := csicommon.IsFileRWO(req.GetVolumeCapabilities())
// below return value has set, if it is ReadOnly capability.
isROOnly := csicommon.IsReaderOnly(req.VolumeCapabilities)
isROOnly := csicommon.IsReaderOnly(req.GetVolumeCapabilities())
// We want to fail early if the user is trying to create a RWX on a non-block type device
if !isRWOFile && !isBlock && !isROOnly {
return nil, status.Error(
@ -782,13 +782,13 @@ func checkContentSource(
req *csi.CreateVolumeRequest,
cr *util.Credentials,
) (*rbdVolume, *rbdSnapshot, error) {
if req.VolumeContentSource == nil {
if req.GetVolumeContentSource() == nil {
return nil, nil, nil
}
volumeSource := req.VolumeContentSource
switch volumeSource.Type.(type) {
volumeSource := req.GetVolumeContentSource()
switch volumeSource.GetType().(type) {
case *csi.VolumeContentSource_Snapshot:
snapshot := req.VolumeContentSource.GetSnapshot()
snapshot := req.GetVolumeContentSource().GetSnapshot()
if snapshot == nil {
return nil, nil, status.Error(codes.NotFound, "volume Snapshot cannot be empty")
}
@ -808,7 +808,7 @@ func checkContentSource(
return nil, rbdSnap, nil
case *csi.VolumeContentSource_Volume:
vol := req.VolumeContentSource.GetVolume()
vol := req.GetVolumeContentSource().GetVolume()
if vol == nil {
return nil, nil, status.Error(codes.NotFound, "volume cannot be empty")
}
@ -1066,11 +1066,11 @@ func (cs *ControllerServer) ValidateVolumeCapabilities(
return nil, status.Error(codes.InvalidArgument, "empty volume ID in request")
}
if len(req.VolumeCapabilities) == 0 {
if len(req.GetVolumeCapabilities()) == 0 {
return nil, status.Error(codes.InvalidArgument, "empty volume capabilities in request")
}
for _, capability := range req.VolumeCapabilities {
for _, capability := range req.GetVolumeCapabilities() {
if capability.GetAccessMode().GetMode() != csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER {
return &csi.ValidateVolumeCapabilitiesResponse{Message: ""}, nil
}
@ -1078,7 +1078,7 @@ func (cs *ControllerServer) ValidateVolumeCapabilities(
return &csi.ValidateVolumeCapabilitiesResponse{
Confirmed: &csi.ValidateVolumeCapabilitiesResponse_Confirmed{
VolumeCapabilities: req.VolumeCapabilities,
VolumeCapabilities: req.GetVolumeCapabilities(),
},
}, nil
}
@ -1297,10 +1297,10 @@ func (cs *ControllerServer) validateSnapshotReq(ctx context.Context, req *csi.Cr
}
// Check sanity of request Snapshot Name, Source Volume Id
if req.Name == "" {
if req.GetName() == "" {
return status.Error(codes.InvalidArgument, "snapshot Name cannot be empty")
}
if req.SourceVolumeId == "" {
if req.GetSourceVolumeId() == "" {
return status.Error(codes.InvalidArgument, "source Volume ID cannot be empty")
}

View File

@ -20,7 +20,6 @@ import (
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ceph/ceph-csi/internal/util"
@ -44,7 +43,7 @@ func TestSetupCSIAddonsServer(t *testing.T) {
// verify the socket file has been created
_, err = os.Stat(tmpDir + "/csi-addons.sock")
assert.NoError(t, err)
require.NoError(t, err)
// stop the gRPC server
drv.cas.Stop()

View File

@ -312,7 +312,7 @@ func (ri *rbdImage) initKMS(ctx context.Context, volOptions, credentials map[str
case util.EncryptionTypeFile:
err = ri.configureFileEncryption(ctx, kmsID, credentials)
case util.EncryptionTypeInvalid:
return fmt.Errorf("invalid encryption type")
return errors.New("invalid encryption type")
case util.EncryptionTypeNone:
return nil
}

View File

@ -163,7 +163,7 @@ func (ns *NodeServer) populateRbdVol(
isBlock := req.GetVolumeCapability().GetBlock() != nil
disableInUseChecks := false
// MULTI_NODE_MULTI_WRITER is supported by default for Block access type volumes
if req.VolumeCapability.AccessMode.Mode == csi.VolumeCapability_AccessMode_MULTI_NODE_MULTI_WRITER {
if req.GetVolumeCapability().GetAccessMode().GetMode() == csi.VolumeCapability_AccessMode_MULTI_NODE_MULTI_WRITER {
if !isBlock {
log.WarningLog(
ctx,
@ -400,7 +400,7 @@ func (ns *NodeServer) stageTransaction(
var err error
// Allow image to be mounted on multiple nodes if it is ROX
if req.VolumeCapability.AccessMode.Mode == csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY {
if req.GetVolumeCapability().GetAccessMode().GetMode() == csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY {
log.ExtendedLog(ctx, "setting disableInUseChecks on rbd volume to: %v", req.GetVolumeId)
volOptions.DisableInUseChecks = true
volOptions.readOnly = true
@ -777,8 +777,9 @@ func (ns *NodeServer) mountVolumeToStagePath(
isBlock := req.GetVolumeCapability().GetBlock() != nil
rOnly := "ro"
if req.VolumeCapability.AccessMode.Mode == csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY ||
req.VolumeCapability.AccessMode.Mode == csi.VolumeCapability_AccessMode_SINGLE_NODE_READER_ONLY {
mode := req.GetVolumeCapability().GetAccessMode().GetMode()
if mode == csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY ||
mode == csi.VolumeCapability_AccessMode_SINGLE_NODE_READER_ONLY {
if !csicommon.MountOptionContains(opt, rOnly) {
opt = append(opt, rOnly)
}

View File

@ -27,7 +27,7 @@ import (
"github.com/ceph/ceph-csi/internal/util"
"github.com/container-storage-interface/spec/lib/go/csi"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetStagingPath(t *testing.T) {
@ -196,7 +196,7 @@ func TestNodeServer_appendReadAffinityMapOptions(t *testing.T) {
Mounter: currentTT.args.mounter,
}
rv.appendReadAffinityMapOptions(currentTT.args.readAffinityMapOptions)
assert.Equal(t, currentTT.want, rv.MapOptions)
require.Equal(t, currentTT.want, rv.MapOptions)
})
}
}
@ -310,10 +310,10 @@ func TestReadAffinity_GetReadAffinityMapOptions(t *testing.T) {
tmpConfPath, tc.clusterID, ns.CLIReadAffinityOptions, nodeLabels,
)
if err != nil {
assert.Fail(t, err.Error())
require.Fail(t, err.Error())
}
assert.Equal(t, tc.want, readAffinityMapOptions)
require.Equal(t, tc.want, readAffinityMapOptions)
})
}
}

View File

@ -229,7 +229,7 @@ func waitForPath(ctx context.Context, pool, namespace, image string, maxRetries
func SetRbdNbdToolFeatures() {
var stderr string
// check if the module is loaded or compiled in
_, err := os.Stat(fmt.Sprintf("/sys/module/%s", moduleNbd))
_, err := os.Stat("/sys/module/" + moduleNbd)
if os.IsNotExist(err) {
// try to load the module
_, stderr, err = util.ExecCommand(context.TODO(), "modprobe", moduleNbd)
@ -377,7 +377,7 @@ func appendNbdDeviceTypeAndOptions(cmdArgs []string, userOptions, cookie string)
}
if hasNBDCookieSupport {
cmdArgs = append(cmdArgs, fmt.Sprintf("--cookie=%s", cookie))
cmdArgs = append(cmdArgs, "--cookie="+cookie)
}
}
@ -409,7 +409,7 @@ func appendKRbdDeviceTypeAndOptions(cmdArgs []string, userOptions string) []stri
// provided for rbd integrated cli to rbd-nbd cli format specific.
func appendRbdNbdCliOptions(cmdArgs []string, userOptions, cookie string) []string {
if !strings.Contains(userOptions, useNbdNetlink) {
cmdArgs = append(cmdArgs, fmt.Sprintf("--%s", useNbdNetlink))
cmdArgs = append(cmdArgs, "--"+useNbdNetlink)
}
if !strings.Contains(userOptions, setNbdReattach) {
cmdArgs = append(cmdArgs, fmt.Sprintf("--%s=%d", setNbdReattach, defaultNbdReAttachTimeout))
@ -418,12 +418,12 @@ func appendRbdNbdCliOptions(cmdArgs []string, userOptions, cookie string) []stri
cmdArgs = append(cmdArgs, fmt.Sprintf("--%s=%d", setNbdIOTimeout, defaultNbdIOTimeout))
}
if hasNBDCookieSupport {
cmdArgs = append(cmdArgs, fmt.Sprintf("--cookie=%s", cookie))
cmdArgs = append(cmdArgs, "--cookie="+cookie)
}
if userOptions != "" {
options := strings.Split(userOptions, ",")
for _, opt := range options {
cmdArgs = append(cmdArgs, fmt.Sprintf("--%s", opt))
cmdArgs = append(cmdArgs, "--"+opt)
}
}
@ -566,7 +566,7 @@ func detachRBDImageOrDeviceSpec(
return err
}
if len(mapper) > 0 {
if mapper != "" {
// mapper found, so it is open Luks device
err = util.CloseEncryptedVolume(ctx, mapperFile)
if err != nil {

View File

@ -19,6 +19,7 @@ package rbd
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"path/filepath"
"sync"
@ -79,7 +80,7 @@ func getSecret(c *k8s.Clientset, ns, name string) (map[string]string, error) {
func formatStagingTargetPath(c *k8s.Clientset, pv *v1.PersistentVolume, stagingPath string) (string, error) {
// Kubernetes 1.24+ uses a hash of the volume-id in the path name
unique := sha256.Sum256([]byte(pv.Spec.CSI.VolumeHandle))
targetPath := filepath.Join(stagingPath, pv.Spec.CSI.Driver, fmt.Sprintf("%x", unique), "globalmount")
targetPath := filepath.Join(stagingPath, pv.Spec.CSI.Driver, hex.EncodeToString(unique[:]), "globalmount")
major, minor, err := kubeclient.GetServerVersion(c)
if err != nil {

View File

@ -294,7 +294,7 @@ func (rv *rbdVolume) Exists(ctx context.Context, parentVol *rbdVolume) (bool, er
// NOTE: Return volsize should be on-disk volsize, not request vol size, so
// save it for size checks before fetching image data
requestSize := rv.VolSize //nolint:ifshort // FIXME: rename and split function into helpers
requestSize := rv.VolSize
// Fetch on-disk image attributes and compare against request
err = rv.getImageInfo()
if err != nil {

View File

@ -1163,8 +1163,6 @@ func generateVolumeFromVolumeID(
// GenVolFromVolID generates a rbdVolume structure from the provided identifier, updating
// the structure with elements from on-disk image metadata as well.
//
//nolint:golint // TODO: returning unexported rbdVolume type, use an interface instead.
func GenVolFromVolID(
ctx context.Context,
volumeID string,
@ -1231,7 +1229,7 @@ func generateVolumeFromMapping(
vi.ClusterID)
// Add mapping clusterID to Identifier
nvi.ClusterID = mappedClusterID
poolID := fmt.Sprintf("%d", (vi.LocationID))
poolID := strconv.FormatInt(vi.LocationID, 10)
for _, pools := range cm.RBDpoolIDMappingInfo {
for key, val := range pools {
mappedPoolID := util.GetMappedID(key, val, poolID)
@ -1525,7 +1523,7 @@ func (rv *rbdVolume) setImageOptions(ctx context.Context, options *librbd.ImageO
logMsg := fmt.Sprintf("setting image options on %s", rv)
if rv.DataPool != "" {
logMsg += fmt.Sprintf(", data pool %s", rv.DataPool)
logMsg += ", data pool %s" + rv.DataPool
err = options.SetString(librbd.RbdImageOptionDataPool, rv.DataPool)
if err != nil {
return fmt.Errorf("failed to set data pool: %w", err)

View File

@ -24,7 +24,7 @@ import (
"testing"
librbd "github.com/ceph/go-ceph/rbd"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestHasSnapshotFeature(t *testing.T) {
@ -165,11 +165,11 @@ func TestValidateImageFeatures(t *testing.T) {
for _, test := range tests {
err := test.rbdVol.validateImageFeatures(test.imageFeatures)
if test.isErr {
assert.EqualError(t, err, test.errMsg)
require.EqualError(t, err, test.errMsg)
continue
}
assert.Nil(t, err)
require.NoError(t, err)
}
}

View File

@ -48,7 +48,7 @@ func ExecuteCommandWithNSEnter(ctx context.Context, netPath, program string, arg
return "", "", fmt.Errorf("failed to get stat for %s %w", netPath, err)
}
// nsenter --net=%s -- <program> <args>
args = append([]string{fmt.Sprintf("--net=%s", netPath), "--", program}, args...)
args = append([]string{"--net=" + netPath, "--", program}, args...)
sanitizedArgs := StripSecretInArgs(args)
cmd := exec.Command(nsenter, args...) // #nosec:G204, commands executing not vulnerable.
cmd.Stdout = &stdoutBuf

View File

@ -159,7 +159,7 @@ func TestGetClusterMappingInfo(t *testing.T) {
})
}
clusterMappingConfigFile = fmt.Sprintf("%s/mapping.json", mappingBasePath)
clusterMappingConfigFile = mappingBasePath + "/mapping.json"
err = os.WriteFile(clusterMappingConfigFile, mappingFileContent, 0o600)
if err != nil {
t.Errorf("failed to write mapping content error = %v", err)

View File

@ -19,7 +19,7 @@ package util
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_getCrushLocationMap(t *testing.T) {
@ -105,7 +105,7 @@ func Test_getCrushLocationMap(t *testing.T) {
currentTT := tt
t.Run(currentTT.name, func(t *testing.T) {
t.Parallel()
assert.Equal(t,
require.Equal(t,
currentTT.want,
getCrushLocationMap(currentTT.args.crushLocationLabels, currentTT.args.nodeLabels))
})

View File

@ -23,7 +23,6 @@ import (
"github.com/ceph/ceph-csi/internal/kms"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -35,8 +34,8 @@ func TestGenerateNewEncryptionPassphrase(t *testing.T) {
// b64Passphrase is URL-encoded, decode to verify the length of the
// passphrase
passphrase, err := base64.URLEncoding.DecodeString(b64Passphrase)
assert.NoError(t, err)
assert.Equal(t, defaultEncryptionPassphraseSize, len(passphrase))
require.NoError(t, err)
require.Len(t, passphrase, defaultEncryptionPassphraseSize)
}
func TestKMSWorkflow(t *testing.T) {
@ -47,52 +46,52 @@ func TestKMSWorkflow(t *testing.T) {
}
kmsProvider, err := kms.GetDefaultKMS(secrets)
assert.NoError(t, err)
require.NoError(t, err)
require.NotNil(t, kmsProvider)
ve, err := NewVolumeEncryption("", kmsProvider)
assert.NoError(t, err)
require.NoError(t, err)
require.NotNil(t, ve)
assert.Equal(t, kms.DefaultKMSType, ve.GetID())
require.Equal(t, kms.DefaultKMSType, ve.GetID())
volumeID := "volume-id"
ctx := context.TODO()
err = ve.StoreNewCryptoPassphrase(ctx, volumeID, defaultEncryptionPassphraseSize)
assert.NoError(t, err)
require.NoError(t, err)
passphrase, err := ve.GetCryptoPassphrase(ctx, volumeID)
assert.NoError(t, err)
assert.Equal(t, secrets["encryptionPassphrase"], passphrase)
require.NoError(t, err)
require.Equal(t, secrets["encryptionPassphrase"], passphrase)
}
func TestEncryptionType(t *testing.T) {
t.Parallel()
assert.EqualValues(t, EncryptionTypeInvalid, ParseEncryptionType("wat?"))
assert.EqualValues(t, EncryptionTypeInvalid, ParseEncryptionType("both"))
assert.EqualValues(t, EncryptionTypeInvalid, ParseEncryptionType("file,block"))
assert.EqualValues(t, EncryptionTypeInvalid, ParseEncryptionType("block,file"))
assert.EqualValues(t, EncryptionTypeBlock, ParseEncryptionType("block"))
assert.EqualValues(t, EncryptionTypeFile, ParseEncryptionType("file"))
assert.EqualValues(t, EncryptionTypeNone, ParseEncryptionType(""))
require.EqualValues(t, EncryptionTypeInvalid, ParseEncryptionType("wat?"))
require.EqualValues(t, EncryptionTypeInvalid, ParseEncryptionType("both"))
require.EqualValues(t, EncryptionTypeInvalid, ParseEncryptionType("file,block"))
require.EqualValues(t, EncryptionTypeInvalid, ParseEncryptionType("block,file"))
require.EqualValues(t, EncryptionTypeBlock, ParseEncryptionType("block"))
require.EqualValues(t, EncryptionTypeFile, ParseEncryptionType("file"))
require.EqualValues(t, EncryptionTypeNone, ParseEncryptionType(""))
for _, s := range []string{"file", "block", ""} {
assert.EqualValues(t, s, ParseEncryptionType(s).String())
require.EqualValues(t, s, ParseEncryptionType(s).String())
}
}
func TestFetchEncryptionType(t *testing.T) {
t.Parallel()
volOpts := map[string]string{}
assert.EqualValues(t, EncryptionTypeBlock, FetchEncryptionType(volOpts, EncryptionTypeBlock))
assert.EqualValues(t, EncryptionTypeFile, FetchEncryptionType(volOpts, EncryptionTypeFile))
assert.EqualValues(t, EncryptionTypeNone, FetchEncryptionType(volOpts, EncryptionTypeNone))
require.EqualValues(t, EncryptionTypeBlock, FetchEncryptionType(volOpts, EncryptionTypeBlock))
require.EqualValues(t, EncryptionTypeFile, FetchEncryptionType(volOpts, EncryptionTypeFile))
require.EqualValues(t, EncryptionTypeNone, FetchEncryptionType(volOpts, EncryptionTypeNone))
volOpts["encryptionType"] = ""
assert.EqualValues(t, EncryptionTypeInvalid, FetchEncryptionType(volOpts, EncryptionTypeNone))
require.EqualValues(t, EncryptionTypeInvalid, FetchEncryptionType(volOpts, EncryptionTypeNone))
volOpts["encryptionType"] = "block"
assert.EqualValues(t, EncryptionTypeBlock, FetchEncryptionType(volOpts, EncryptionTypeNone))
require.EqualValues(t, EncryptionTypeBlock, FetchEncryptionType(volOpts, EncryptionTypeNone))
volOpts["encryptionType"] = "file"
assert.EqualValues(t, EncryptionTypeFile, FetchEncryptionType(volOpts, EncryptionTypeNone))
require.EqualValues(t, EncryptionTypeFile, FetchEncryptionType(volOpts, EncryptionTypeNone))
volOpts["encryptionType"] = "INVALID"
assert.EqualValues(t, EncryptionTypeInvalid, FetchEncryptionType(volOpts, EncryptionTypeNone))
require.EqualValues(t, EncryptionTypeInvalid, FetchEncryptionType(volOpts, EncryptionTypeNone))
}

View File

@ -463,5 +463,5 @@ func Unlock(
return initializeAndUnlock(ctx, fscryptContext, encryptedPath, protectorName, keyFn)
}
return fmt.Errorf("unsupported")
return errors.New("unsupported")
}

View File

@ -20,7 +20,7 @@ import (
kmsapi "github.com/ceph/ceph-csi/internal/kms"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetPassphraseFromKMS(t *testing.T) {
@ -31,7 +31,7 @@ func TestGetPassphraseFromKMS(t *testing.T) {
continue
}
kms := kmsapi.GetKMSTestDummy(provider.UniqueID)
assert.NotNil(t, kms)
require.NotNil(t, kms)
volEnc, err := NewVolumeEncryption(provider.UniqueID, kms)
if errors.Is(err, ErrDEKStoreNeeded) {
@ -40,14 +40,14 @@ func TestGetPassphraseFromKMS(t *testing.T) {
continue // currently unsupported by fscrypt integration
}
}
assert.NotNil(t, volEnc)
require.NotNil(t, volEnc)
if kms.RequiresDEKStore() == kmsapi.DEKStoreIntegrated {
continue
}
secret, err := kms.GetSecret(context.TODO(), "")
assert.NoError(t, err, provider.UniqueID)
assert.NotEmpty(t, secret, provider.UniqueID)
require.NoError(t, err, provider.UniqueID)
require.NotEmpty(t, secret, provider.UniqueID)
}
}

View File

@ -70,7 +70,7 @@ func getCgroupPidsFile() (string, error) {
}
}
if slice == "" {
return "", fmt.Errorf("could not find a cgroup for 'pids'")
return "", errors.New("could not find a cgroup for 'pids'")
}
return pidsMax, nil
@ -112,7 +112,7 @@ func GetPIDLimit() (int, error) {
func SetPIDLimit(limit int) error {
limitStr := "max"
if limit != -1 {
limitStr = fmt.Sprintf("%d", limit)
limitStr = strconv.Itoa(limit)
}
pidsMax, err := getCgroupPidsFile()

View File

@ -19,7 +19,7 @@ package util
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestReadAffinity_ConstructReadAffinityMapOption(t *testing.T) {
@ -62,7 +62,7 @@ func TestReadAffinity_ConstructReadAffinityMapOption(t *testing.T) {
currentTT := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
assert.Contains(t, currentTT.wantAny, ConstructReadAffinityMapOption(currentTT.crushLocationmap))
require.Contains(t, currentTT.wantAny, ConstructReadAffinityMapOption(currentTT.crushLocationmap))
})
}
}

View File

@ -70,7 +70,6 @@ func TryRADOSAborted(opErr error) error {
return opErr
}
//nolint:errorlint // Can't use errors.As() because rados.radosError is private.
errnoErr, ok := radosOpErr.OpError.(interface{ ErrorCode() int })
if !ok {
return opErr

View File

@ -22,7 +22,7 @@ import (
"github.com/ceph/ceph-csi/internal/util/reftracker/radoswrapper"
"github.com/ceph/ceph-csi/internal/util/reftracker/reftype"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const rtName = "hello-rt"
@ -36,8 +36,8 @@ func TestRTAdd(t *testing.T) {
ioctx := radoswrapper.NewFakeIOContext(radoswrapper.NewFakeRados())
created, err := Add(ioctx, "", nil)
assert.Error(ts, err)
assert.False(ts, created)
require.Error(ts, err)
require.False(ts, created)
})
// Verify input validation for nil and empty refs.
@ -51,8 +51,8 @@ func TestRTAdd(t *testing.T) {
}
for _, ref := range refs {
created, err := Add(ioctx, rtName, ref)
assert.Error(ts, err)
assert.False(ts, created)
require.Error(ts, err)
require.False(ts, created)
}
})
@ -66,8 +66,8 @@ func TestRTAdd(t *testing.T) {
"ref2": {},
"ref3": {},
})
assert.NoError(ts, err)
assert.True(ts, created)
require.NoError(ts, err)
require.True(ts, created)
})
// Add refs where each Add() has some of the refs overlapping
@ -80,8 +80,8 @@ func TestRTAdd(t *testing.T) {
"ref1": {},
"ref2": {},
})
assert.NoError(ts, err)
assert.True(ts, created)
require.NoError(ts, err)
require.True(ts, created)
refsTable := []map[string]struct{}{
{"ref2": {}, "ref3": {}},
@ -90,8 +90,8 @@ func TestRTAdd(t *testing.T) {
}
for _, refs := range refsTable {
created, err = Add(ioctx, rtName, refs)
assert.NoError(ts, err)
assert.False(ts, created)
require.NoError(ts, err)
require.False(ts, created)
}
})
}
@ -110,8 +110,8 @@ func TestRTRemove(t *testing.T) {
}
for _, ref := range refs {
created, err := Remove(ioctx, rtName, ref)
assert.Error(ts, err)
assert.False(ts, created)
require.Error(ts, err)
require.False(ts, created)
}
})
@ -124,8 +124,8 @@ func TestRTRemove(t *testing.T) {
deleted, err := Remove(ioctx, "xxx", map[string]reftype.RefType{
"ref1": reftype.Normal,
})
assert.NoError(ts, err)
assert.True(ts, deleted)
require.NoError(ts, err)
require.True(ts, deleted)
})
// Removing only non-existent refs should not result in reftracker object
@ -140,16 +140,16 @@ func TestRTRemove(t *testing.T) {
"ref2": {},
"ref3": {},
})
assert.NoError(ts, err)
assert.True(ts, created)
require.NoError(ts, err)
require.True(ts, created)
deleted, err := Remove(ioctx, rtName, map[string]reftype.RefType{
"refX": reftype.Normal,
"refY": reftype.Normal,
"refZ": reftype.Normal,
})
assert.NoError(ts, err)
assert.False(ts, deleted)
require.NoError(ts, err)
require.False(ts, deleted)
})
// Removing all refs plus some surplus should result in reftracker object
@ -162,8 +162,8 @@ func TestRTRemove(t *testing.T) {
created, err := Add(ioctx, rtName, map[string]struct{}{
"ref": {},
})
assert.NoError(ts, err)
assert.True(ts, created)
require.NoError(ts, err)
require.True(ts, created)
deleted, err := Remove(ioctx, rtName, map[string]reftype.RefType{
"refX": reftype.Normal,
@ -171,8 +171,8 @@ func TestRTRemove(t *testing.T) {
"ref": reftype.Normal,
"refZ": reftype.Normal,
})
assert.NoError(ts, err)
assert.True(ts, deleted)
require.NoError(ts, err)
require.True(ts, deleted)
})
// Bulk removal of all refs should result in reftracker object deletion.
@ -189,12 +189,12 @@ func TestRTRemove(t *testing.T) {
}
created, err := Add(ioctx, rtName, refsToAdd)
assert.NoError(ts, err)
assert.True(ts, created)
require.NoError(ts, err)
require.True(ts, created)
deleted, err := Remove(ioctx, rtName, refsToRemove)
assert.NoError(ts, err)
assert.True(ts, deleted)
require.NoError(ts, err)
require.True(ts, deleted)
})
// Removal of all refs one-by-one should result in reftracker object deletion
@ -209,23 +209,23 @@ func TestRTRemove(t *testing.T) {
"ref2": {},
"ref3": {},
})
assert.NoError(ts, err)
assert.True(ts, created)
require.NoError(ts, err)
require.True(ts, created)
for _, k := range []string{"ref3", "ref2"} {
deleted, errRemove := Remove(ioctx, rtName, map[string]reftype.RefType{
k: reftype.Normal,
})
assert.NoError(ts, errRemove)
assert.False(ts, deleted)
require.NoError(ts, errRemove)
require.False(ts, deleted)
}
// Remove the last reference. It should remove the whole reftracker object too.
deleted, err := Remove(ioctx, rtName, map[string]reftype.RefType{
"ref1": reftype.Normal,
})
assert.NoError(ts, err)
assert.True(ts, deleted)
require.NoError(ts, err)
require.True(ts, deleted)
})
// Cycle through reftracker object twice.
@ -246,12 +246,12 @@ func TestRTRemove(t *testing.T) {
for i := 0; i < 2; i++ {
created, err := Add(ioctx, rtName, refsToAdd)
assert.NoError(ts, err)
assert.True(ts, created)
require.NoError(ts, err)
require.True(ts, created)
deleted, err := Remove(ioctx, rtName, refsToRemove)
assert.NoError(ts, err)
assert.True(ts, deleted)
require.NoError(ts, err)
require.True(ts, deleted)
}
})
@ -265,8 +265,8 @@ func TestRTRemove(t *testing.T) {
"ref1": {},
"ref2": {},
})
assert.True(ts, created)
assert.NoError(ts, err)
require.True(ts, created)
require.NoError(ts, err)
refsTable := []map[string]struct{}{
{"ref2": {}, "ref3": {}},
{"ref3": {}, "ref4": {}},
@ -274,8 +274,8 @@ func TestRTRemove(t *testing.T) {
}
for _, refs := range refsTable {
created, err = Add(ioctx, rtName, refs)
assert.False(ts, created)
assert.NoError(ts, err)
require.False(ts, created)
require.NoError(ts, err)
}
deleted, err := Remove(ioctx, rtName, map[string]reftype.RefType{
@ -285,8 +285,8 @@ func TestRTRemove(t *testing.T) {
"ref4": reftype.Normal,
"ref5": reftype.Normal,
})
assert.NoError(ts, err)
assert.True(ts, deleted)
require.NoError(ts, err)
require.True(ts, deleted)
})
}
@ -307,12 +307,12 @@ func TestRTMask(t *testing.T) {
}
created, err := Add(ioctx, rtName, refsToAdd)
assert.NoError(ts, err)
assert.True(ts, created)
require.NoError(ts, err)
require.True(ts, created)
deleted, err := Remove(ioctx, rtName, refsToRemove)
assert.NoError(ts, err)
assert.True(ts, deleted)
require.NoError(ts, err)
require.True(ts, deleted)
})
// Masking all refs one-by-one should result in reftracker object deletion in
@ -327,15 +327,15 @@ func TestRTMask(t *testing.T) {
"ref2": {},
"ref3": {},
})
assert.NoError(ts, err)
assert.True(ts, created)
require.NoError(ts, err)
require.True(ts, created)
for _, k := range []string{"ref3", "ref2"} {
deleted, errRemove := Remove(ioctx, rtName, map[string]reftype.RefType{
k: reftype.Mask,
})
assert.NoError(ts, errRemove)
assert.False(ts, deleted)
require.NoError(ts, errRemove)
require.False(ts, deleted)
}
// Remove the last reference. It should delete the whole reftracker object
@ -343,8 +343,8 @@ func TestRTMask(t *testing.T) {
deleted, err := Remove(ioctx, rtName, map[string]reftype.RefType{
"ref1": reftype.Mask,
})
assert.NoError(ts, err)
assert.True(ts, deleted)
require.NoError(ts, err)
require.True(ts, deleted)
})
// Bulk removing two (out of 3) refs and then masking the ref that's left
@ -359,21 +359,21 @@ func TestRTMask(t *testing.T) {
"ref2": {},
"ref3": {},
})
assert.NoError(ts, err)
assert.True(ts, created)
require.NoError(ts, err)
require.True(ts, created)
deleted, err := Remove(ioctx, rtName, map[string]reftype.RefType{
"ref1": reftype.Normal,
"ref2": reftype.Normal,
})
assert.NoError(ts, err)
assert.False(ts, deleted)
require.NoError(ts, err)
require.False(ts, deleted)
deleted, err = Remove(ioctx, rtName, map[string]reftype.RefType{
"ref3": reftype.Mask,
})
assert.NoError(ts, err)
assert.True(ts, deleted)
require.NoError(ts, err)
require.True(ts, deleted)
})
// Bulk masking two (out of 3) refs and then removing the ref that's left
@ -388,21 +388,21 @@ func TestRTMask(t *testing.T) {
"ref2": {},
"ref3": {},
})
assert.NoError(ts, err)
assert.True(ts, created)
require.NoError(ts, err)
require.True(ts, created)
deleted, err := Remove(ioctx, rtName, map[string]reftype.RefType{
"ref1": reftype.Mask,
"ref2": reftype.Mask,
})
assert.NoError(ts, err)
assert.False(ts, deleted)
require.NoError(ts, err)
require.False(ts, deleted)
deleted, err = Remove(ioctx, rtName, map[string]reftype.RefType{
"ref3": reftype.Normal,
})
assert.NoError(ts, err)
assert.True(ts, deleted)
require.NoError(ts, err)
require.True(ts, deleted)
})
// Verify that masking refs hides them from future Add()s.
@ -416,28 +416,28 @@ func TestRTMask(t *testing.T) {
"ref2": {},
"ref3": {},
})
assert.NoError(ts, err)
assert.True(ts, created)
require.NoError(ts, err)
require.True(ts, created)
deleted, err := Remove(ioctx, rtName, map[string]reftype.RefType{
"ref1": reftype.Mask,
"ref2": reftype.Mask,
})
assert.NoError(ts, err)
assert.False(ts, deleted)
require.NoError(ts, err)
require.False(ts, deleted)
created, err = Add(ioctx, rtName, map[string]struct{}{
"ref1": {},
"ref2": {},
})
assert.NoError(ts, err)
assert.False(ts, created)
require.NoError(ts, err)
require.False(ts, created)
deleted, err = Remove(ioctx, rtName, map[string]reftype.RefType{
"ref3": reftype.Normal,
})
assert.NoError(ts, err)
assert.True(ts, deleted)
require.NoError(ts, err)
require.True(ts, deleted)
})
// Verify that masked refs may be removed with reftype.Normal and re-added.
@ -451,41 +451,41 @@ func TestRTMask(t *testing.T) {
"ref2": {},
"ref3": {},
})
assert.NoError(ts, err)
assert.True(ts, created)
require.NoError(ts, err)
require.True(ts, created)
deleted, err := Remove(ioctx, rtName, map[string]reftype.RefType{
"ref1": reftype.Mask,
"ref2": reftype.Mask,
})
assert.NoError(ts, err)
assert.False(ts, deleted)
require.NoError(ts, err)
require.False(ts, deleted)
deleted, err = Remove(ioctx, rtName, map[string]reftype.RefType{
"ref1": reftype.Normal,
"ref2": reftype.Normal,
})
assert.NoError(ts, err)
assert.False(ts, deleted)
require.NoError(ts, err)
require.False(ts, deleted)
created, err = Add(ioctx, rtName, map[string]struct{}{
"ref1": {},
"ref2": {},
})
assert.NoError(ts, err)
assert.False(ts, created)
require.NoError(ts, err)
require.False(ts, created)
deleted, err = Remove(ioctx, rtName, map[string]reftype.RefType{
"ref3": reftype.Normal,
})
assert.NoError(ts, err)
assert.False(ts, deleted)
require.NoError(ts, err)
require.False(ts, deleted)
deleted, err = Remove(ioctx, rtName, map[string]reftype.RefType{
"ref1": reftype.Normal,
"ref2": reftype.Normal,
})
assert.NoError(ts, err)
assert.True(ts, deleted)
require.NoError(ts, err)
require.True(ts, deleted)
})
}

View File

@ -19,7 +19,7 @@ package reftype
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRefTypeBytes(t *testing.T) {
@ -41,7 +41,7 @@ func TestRefTypeBytes(t *testing.T) {
for i := range expectedBytes {
bs := ToBytes(refTypes[i])
assert.Equal(ts, expectedBytes[i], bs)
require.Equal(ts, expectedBytes[i], bs)
}
})
@ -50,14 +50,14 @@ func TestRefTypeBytes(t *testing.T) {
for i := range refTypes {
refType, err := FromBytes(expectedBytes[i])
assert.NoError(ts, err)
assert.Equal(ts, refTypes[i], refType)
require.NoError(ts, err)
require.Equal(ts, refTypes[i], refType)
}
_, err := FromBytes(refTypeInvalidBytes)
assert.Error(ts, err)
require.Error(ts, err)
_, err = FromBytes(refTypeWrongSizeBytes)
assert.Error(ts, err)
require.Error(ts, err)
})
}

View File

@ -19,7 +19,7 @@ package v1
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestV1RefCountBytes(t *testing.T) {
@ -35,17 +35,17 @@ func TestV1RefCountBytes(t *testing.T) {
ts.Parallel()
bs := refCountValue.toBytes()
assert.Equal(ts, refCountBytes, bs)
require.Equal(ts, refCountBytes, bs)
})
t.Run("FromBytes", func(ts *testing.T) {
ts.Parallel()
rc, err := refCountFromBytes(refCountBytes)
assert.NoError(ts, err)
assert.Equal(ts, refCountValue, rc)
require.NoError(ts, err)
require.Equal(ts, refCountValue, rc)
_, err = refCountFromBytes(wrongSizeRefCountBytes)
assert.Error(ts, err)
require.Error(ts, err)
})
}

View File

@ -205,7 +205,7 @@ func Remove(
if rcToSubtract > readRes.total {
// BUG: this should never happen!
return false, fmt.Errorf("refcount underflow, reftracker object corrupted")
return false, goerrors.New("refcount underflow, reftracker object corrupted")
}
newRC := readRes.total - rcToSubtract

View File

@ -17,14 +17,13 @@ limitations under the License.
package v1
import (
goerrors "errors"
"testing"
"github.com/ceph/ceph-csi/internal/util/reftracker/errors"
"github.com/ceph/ceph-csi/internal/util/reftracker/radoswrapper"
"github.com/ceph/ceph-csi/internal/util/reftracker/reftype"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestV1Read(t *testing.T) {
@ -73,17 +72,17 @@ func TestV1Read(t *testing.T) {
)
err := Add(validObj, rtName, gen, refsToAdd)
assert.NoError(t, err)
require.NoError(t, err)
for i := range invalidObjs {
err = Add(invalidObjs[i], rtName, gen, refsToAdd)
assert.Error(t, err)
require.Error(t, err)
}
// Check for correct error type for wrong gen num.
err = Add(invalidObjs[1], rtName, gen, refsToAdd)
assert.Error(t, err)
assert.True(t, goerrors.Is(err, errors.ErrObjectOutOfDate))
require.Error(t, err)
require.ErrorIs(t, err, errors.ErrObjectOutOfDate)
}
func TestV1Init(t *testing.T) {
@ -106,10 +105,10 @@ func TestV1Init(t *testing.T) {
)
err := Init(emptyRados, rtName, refsToInit)
assert.NoError(t, err)
require.NoError(t, err)
err = Init(alreadyExists, rtName, refsToInit)
assert.Error(t, err)
require.Error(t, err)
}
func TestV1Add(t *testing.T) {
@ -224,19 +223,19 @@ func TestV1Add(t *testing.T) {
ioctx.Rados.Objs[rtName] = shouldSucceed[i].before
err := Add(ioctx, rtName, 0, shouldSucceed[i].refsToAdd)
assert.NoError(t, err)
assert.Equal(t, shouldSucceed[i].after, ioctx.Rados.Objs[rtName])
require.NoError(t, err)
require.Equal(t, shouldSucceed[i].after, ioctx.Rados.Objs[rtName])
}
for i := range shouldFail {
err := Add(shouldFail[i], rtName, 0, map[string]struct{}{"ref1": {}})
assert.Error(t, err)
require.Error(t, err)
}
// Check for correct error type for wrong gen num.
err := Add(shouldFail[1], rtName, 0, map[string]struct{}{"ref1": {}})
assert.Error(t, err)
assert.True(t, goerrors.Is(err, errors.ErrObjectOutOfDate))
require.Error(t, err)
require.ErrorIs(t, err, errors.ErrObjectOutOfDate)
}
func TestV1Remove(t *testing.T) {
@ -412,12 +411,12 @@ func TestV1Remove(t *testing.T) {
ioctx.Rados.Objs[rtName] = shouldSucceed[i].before
deleted, err := Remove(ioctx, rtName, 0, shouldSucceed[i].refsToRemove)
assert.NoError(t, err)
assert.Equal(t, shouldSucceed[i].deleted, deleted)
assert.Equal(t, shouldSucceed[i].after, ioctx.Rados.Objs[rtName])
require.NoError(t, err)
require.Equal(t, shouldSucceed[i].deleted, deleted)
require.Equal(t, shouldSucceed[i].after, ioctx.Rados.Objs[rtName])
}
_, err := Remove(badGen, rtName, 0, map[string]reftype.RefType{"ref": reftype.Normal})
assert.Error(t, err)
assert.True(t, goerrors.Is(err, errors.ErrObjectOutOfDate))
require.Error(t, err)
require.ErrorIs(t, err, errors.ErrObjectOutOfDate)
}

View File

@ -21,7 +21,7 @@ import (
"github.com/ceph/ceph-csi/internal/util/reftracker/radoswrapper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var (
@ -38,18 +38,18 @@ func TestVersionBytes(t *testing.T) {
ts.Parallel()
bs := ToBytes(v1Value)
assert.Equal(ts, v1Bytes, bs)
require.Equal(ts, v1Bytes, bs)
})
t.Run("FromBytes", func(ts *testing.T) {
ts.Parallel()
ver, err := FromBytes(v1Bytes)
assert.NoError(ts, err)
assert.Equal(ts, v1Value, ver)
require.NoError(ts, err)
require.Equal(ts, v1Value, ver)
_, err = FromBytes(wrongSizeVersionBytes)
assert.Error(ts, err)
require.Error(ts, err)
})
}
@ -101,11 +101,11 @@ func TestVersionRead(t *testing.T) {
)
ver, err := Read(validObj, rtName)
assert.NoError(t, err)
assert.Equal(t, v1Value, ver)
require.NoError(t, err)
require.Equal(t, v1Value, ver)
for i := range invalidObjs {
_, err = Read(invalidObjs[i], rtName)
assert.Error(t, err)
require.Error(t, err)
}
}

View File

@ -212,7 +212,7 @@ func parseKernelRelease(release string) (int, int, int, int, error) {
extraversion := 0
if n > minVersions {
n, err = fmt.Sscanf(extra, ".%d%s", &sublevel, &extra)
if err != nil && n == 0 && len(extra) > 0 && extra[0] != '-' && extra[0] == '.' {
if err != nil && n == 0 && extra != "" && extra[0] != '-' && extra[0] == '.' {
return 0, 0, 0, 0, fmt.Errorf("failed to parse subversion from %s: %w", release, err)
}

View File

@ -87,7 +87,7 @@ func ValidateNodeUnpublishVolumeRequest(req *csi.NodeUnpublishVolumeRequest) err
// volume is from source as empty ReadOnlyMany is not supported.
func CheckReadOnlyManyIsSupported(req *csi.CreateVolumeRequest) error {
for _, capability := range req.GetVolumeCapabilities() {
if m := capability.GetAccessMode().Mode; m == csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY ||
if m := capability.GetAccessMode().GetMode(); m == csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY ||
m == csi.VolumeCapability_AccessMode_SINGLE_NODE_READER_ONLY {
if req.GetVolumeContentSource() == nil {
return status.Error(codes.InvalidArgument, "readOnly accessMode is supported only with content source")

View File

@ -1,5 +1,5 @@
---
# https://github.com/golangci/golangci-lint/blob/master/.golangci.example.yml
# https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml
# This file contains all available configuration options
# with their default values.
@ -12,7 +12,7 @@ run:
concurrency: 4
# timeout for analysis, e.g. 30s, 5m, default is 1m
deadline: 10m
timeout: 20m
# exit code when at least one issue was found, default is 1
issues-exit-code: 1

View File

@ -33,6 +33,9 @@
#define CONSTBASE R16
#define BLOCKS R17
// for VPERMXOR
#define MASK R18
DATA consts<>+0x00(SB)/8, $0x3320646e61707865
DATA consts<>+0x08(SB)/8, $0x6b20657479622d32
DATA consts<>+0x10(SB)/8, $0x0000000000000001
@ -53,7 +56,11 @@ DATA consts<>+0x80(SB)/8, $0x6b2065746b206574
DATA consts<>+0x88(SB)/8, $0x6b2065746b206574
DATA consts<>+0x90(SB)/8, $0x0000000100000000
DATA consts<>+0x98(SB)/8, $0x0000000300000002
GLOBL consts<>(SB), RODATA, $0xa0
DATA consts<>+0xa0(SB)/8, $0x5566774411223300
DATA consts<>+0xa8(SB)/8, $0xddeeffcc99aabb88
DATA consts<>+0xb0(SB)/8, $0x6677445522330011
DATA consts<>+0xb8(SB)/8, $0xeeffccddaabb8899
GLOBL consts<>(SB), RODATA, $0xc0
//func chaCha20_ctr32_vsx(out, inp *byte, len int, key *[8]uint32, counter *uint32)
TEXT ·chaCha20_ctr32_vsx(SB),NOSPLIT,$64-40
@ -70,6 +77,9 @@ TEXT ·chaCha20_ctr32_vsx(SB),NOSPLIT,$64-40
MOVD $48, R10
MOVD $64, R11
SRD $6, LEN, BLOCKS
// for VPERMXOR
MOVD $consts<>+0xa0(SB), MASK
MOVD $16, R20
// V16
LXVW4X (CONSTBASE)(R0), VS48
ADD $80,CONSTBASE
@ -87,6 +97,10 @@ TEXT ·chaCha20_ctr32_vsx(SB),NOSPLIT,$64-40
// V28
LXVW4X (CONSTBASE)(R11), VS60
// Load mask constants for VPERMXOR
LXVW4X (MASK)(R0), V20
LXVW4X (MASK)(R20), V21
// splat slot from V19 -> V26
VSPLTW $0, V19, V26
@ -97,7 +111,7 @@ TEXT ·chaCha20_ctr32_vsx(SB),NOSPLIT,$64-40
MOVD $10, R14
MOVD R14, CTR
PCALIGN $16
loop_outer_vsx:
// V0, V1, V2, V3
LXVW4X (R0)(CONSTBASE), VS32
@ -128,22 +142,17 @@ loop_outer_vsx:
VSPLTISW $12, V28
VSPLTISW $8, V29
VSPLTISW $7, V30
PCALIGN $16
loop_vsx:
VADDUWM V0, V4, V0
VADDUWM V1, V5, V1
VADDUWM V2, V6, V2
VADDUWM V3, V7, V3
VXOR V12, V0, V12
VXOR V13, V1, V13
VXOR V14, V2, V14
VXOR V15, V3, V15
VRLW V12, V27, V12
VRLW V13, V27, V13
VRLW V14, V27, V14
VRLW V15, V27, V15
VPERMXOR V12, V0, V21, V12
VPERMXOR V13, V1, V21, V13
VPERMXOR V14, V2, V21, V14
VPERMXOR V15, V3, V21, V15
VADDUWM V8, V12, V8
VADDUWM V9, V13, V9
@ -165,15 +174,10 @@ loop_vsx:
VADDUWM V2, V6, V2
VADDUWM V3, V7, V3
VXOR V12, V0, V12
VXOR V13, V1, V13
VXOR V14, V2, V14
VXOR V15, V3, V15
VRLW V12, V29, V12
VRLW V13, V29, V13
VRLW V14, V29, V14
VRLW V15, V29, V15
VPERMXOR V12, V0, V20, V12
VPERMXOR V13, V1, V20, V13
VPERMXOR V14, V2, V20, V14
VPERMXOR V15, V3, V20, V15
VADDUWM V8, V12, V8
VADDUWM V9, V13, V9
@ -195,15 +199,10 @@ loop_vsx:
VADDUWM V2, V7, V2
VADDUWM V3, V4, V3
VXOR V15, V0, V15
VXOR V12, V1, V12
VXOR V13, V2, V13
VXOR V14, V3, V14
VRLW V15, V27, V15
VRLW V12, V27, V12
VRLW V13, V27, V13
VRLW V14, V27, V14
VPERMXOR V15, V0, V21, V15
VPERMXOR V12, V1, V21, V12
VPERMXOR V13, V2, V21, V13
VPERMXOR V14, V3, V21, V14
VADDUWM V10, V15, V10
VADDUWM V11, V12, V11
@ -225,15 +224,10 @@ loop_vsx:
VADDUWM V2, V7, V2
VADDUWM V3, V4, V3
VXOR V15, V0, V15
VXOR V12, V1, V12
VXOR V13, V2, V13
VXOR V14, V3, V14
VRLW V15, V29, V15
VRLW V12, V29, V12
VRLW V13, V29, V13
VRLW V14, V29, V14
VPERMXOR V15, V0, V20, V15
VPERMXOR V12, V1, V20, V12
VPERMXOR V13, V2, V20, V13
VPERMXOR V14, V3, V20, V14
VADDUWM V10, V15, V10
VADDUWM V11, V12, V11
@ -249,48 +243,48 @@ loop_vsx:
VRLW V6, V30, V6
VRLW V7, V30, V7
VRLW V4, V30, V4
BC 16, LT, loop_vsx
BDNZ loop_vsx
VADDUWM V12, V26, V12
WORD $0x13600F8C // VMRGEW V0, V1, V27
WORD $0x13821F8C // VMRGEW V2, V3, V28
VMRGEW V0, V1, V27
VMRGEW V2, V3, V28
WORD $0x10000E8C // VMRGOW V0, V1, V0
WORD $0x10421E8C // VMRGOW V2, V3, V2
VMRGOW V0, V1, V0
VMRGOW V2, V3, V2
WORD $0x13A42F8C // VMRGEW V4, V5, V29
WORD $0x13C63F8C // VMRGEW V6, V7, V30
VMRGEW V4, V5, V29
VMRGEW V6, V7, V30
XXPERMDI VS32, VS34, $0, VS33
XXPERMDI VS32, VS34, $3, VS35
XXPERMDI VS59, VS60, $0, VS32
XXPERMDI VS59, VS60, $3, VS34
WORD $0x10842E8C // VMRGOW V4, V5, V4
WORD $0x10C63E8C // VMRGOW V6, V7, V6
VMRGOW V4, V5, V4
VMRGOW V6, V7, V6
WORD $0x13684F8C // VMRGEW V8, V9, V27
WORD $0x138A5F8C // VMRGEW V10, V11, V28
VMRGEW V8, V9, V27
VMRGEW V10, V11, V28
XXPERMDI VS36, VS38, $0, VS37
XXPERMDI VS36, VS38, $3, VS39
XXPERMDI VS61, VS62, $0, VS36
XXPERMDI VS61, VS62, $3, VS38
WORD $0x11084E8C // VMRGOW V8, V9, V8
WORD $0x114A5E8C // VMRGOW V10, V11, V10
VMRGOW V8, V9, V8
VMRGOW V10, V11, V10
WORD $0x13AC6F8C // VMRGEW V12, V13, V29
WORD $0x13CE7F8C // VMRGEW V14, V15, V30
VMRGEW V12, V13, V29
VMRGEW V14, V15, V30
XXPERMDI VS40, VS42, $0, VS41
XXPERMDI VS40, VS42, $3, VS43
XXPERMDI VS59, VS60, $0, VS40
XXPERMDI VS59, VS60, $3, VS42
WORD $0x118C6E8C // VMRGOW V12, V13, V12
WORD $0x11CE7E8C // VMRGOW V14, V15, V14
VMRGOW V12, V13, V12
VMRGOW V14, V15, V14
VSPLTISW $4, V27
VADDUWM V26, V27, V26
@ -431,7 +425,7 @@ tail_vsx:
ADD $-1, R11, R12
ADD $-1, INP
ADD $-1, OUT
PCALIGN $16
looptail_vsx:
// Copying the result to OUT
// in bytes.
@ -439,7 +433,7 @@ looptail_vsx:
MOVBZU 1(INP), TMP
XOR KEY, TMP, KEY
MOVBU KEY, 1(OUT)
BC 16, LT, looptail_vsx
BDNZ looptail_vsx
// Clear the stack values
STXVW4X VS48, (R11)(R0)

View File

@ -426,6 +426,35 @@ func (l ServerAuthError) Error() string {
return "[" + strings.Join(errs, ", ") + "]"
}
// ServerAuthCallbacks defines server-side authentication callbacks.
type ServerAuthCallbacks struct {
// PasswordCallback behaves like [ServerConfig.PasswordCallback].
PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error)
// PublicKeyCallback behaves like [ServerConfig.PublicKeyCallback].
PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)
// KeyboardInteractiveCallback behaves like [ServerConfig.KeyboardInteractiveCallback].
KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error)
// GSSAPIWithMICConfig behaves like [ServerConfig.GSSAPIWithMICConfig].
GSSAPIWithMICConfig *GSSAPIWithMICConfig
}
// PartialSuccessError can be returned by any of the [ServerConfig]
// authentication callbacks to indicate to the client that authentication has
// partially succeeded, but further steps are required.
type PartialSuccessError struct {
// Next defines the authentication callbacks to apply to further steps. The
// available methods communicated to the client are based on the non-nil
// ServerAuthCallbacks fields.
Next ServerAuthCallbacks
}
func (p *PartialSuccessError) Error() string {
return "ssh: authenticated with partial success"
}
// ErrNoAuth is the error value returned if no
// authentication method has been passed yet. This happens as a normal
// part of the authentication loop, since the client first tries
@ -439,8 +468,18 @@ func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, err
var perms *Permissions
authFailures := 0
noneAuthCount := 0
var authErrs []error
var displayedBanner bool
partialSuccessReturned := false
// Set the initial authentication callbacks from the config. They can be
// changed if a PartialSuccessError is returned.
authConfig := ServerAuthCallbacks{
PasswordCallback: config.PasswordCallback,
PublicKeyCallback: config.PublicKeyCallback,
KeyboardInteractiveCallback: config.KeyboardInteractiveCallback,
GSSAPIWithMICConfig: config.GSSAPIWithMICConfig,
}
userAuthLoop:
for {
@ -471,6 +510,11 @@ userAuthLoop:
return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service)
}
if s.user != userAuthReq.User && partialSuccessReturned {
return nil, fmt.Errorf("ssh: client changed the user after a partial success authentication, previous user %q, current user %q",
s.user, userAuthReq.User)
}
s.user = userAuthReq.User
if !displayedBanner && config.BannerCallback != nil {
@ -491,20 +535,18 @@ userAuthLoop:
switch userAuthReq.Method {
case "none":
if config.NoClientAuth {
noneAuthCount++
// We don't allow none authentication after a partial success
// response.
if config.NoClientAuth && !partialSuccessReturned {
if config.NoClientAuthCallback != nil {
perms, authErr = config.NoClientAuthCallback(s)
} else {
authErr = nil
}
}
// allow initial attempt of 'none' without penalty
if authFailures == 0 {
authFailures--
}
case "password":
if config.PasswordCallback == nil {
if authConfig.PasswordCallback == nil {
authErr = errors.New("ssh: password auth not configured")
break
}
@ -518,17 +560,17 @@ userAuthLoop:
return nil, parseError(msgUserAuthRequest)
}
perms, authErr = config.PasswordCallback(s, password)
perms, authErr = authConfig.PasswordCallback(s, password)
case "keyboard-interactive":
if config.KeyboardInteractiveCallback == nil {
if authConfig.KeyboardInteractiveCallback == nil {
authErr = errors.New("ssh: keyboard-interactive auth not configured")
break
}
prompter := &sshClientKeyboardInteractive{s}
perms, authErr = config.KeyboardInteractiveCallback(s, prompter.Challenge)
perms, authErr = authConfig.KeyboardInteractiveCallback(s, prompter.Challenge)
case "publickey":
if config.PublicKeyCallback == nil {
if authConfig.PublicKeyCallback == nil {
authErr = errors.New("ssh: publickey auth not configured")
break
}
@ -562,11 +604,18 @@ userAuthLoop:
if !ok {
candidate.user = s.user
candidate.pubKeyData = pubKeyData
candidate.perms, candidate.result = config.PublicKeyCallback(s, pubKey)
if candidate.result == nil && candidate.perms != nil && candidate.perms.CriticalOptions != nil && candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" {
candidate.result = checkSourceAddress(
candidate.perms, candidate.result = authConfig.PublicKeyCallback(s, pubKey)
_, isPartialSuccessError := candidate.result.(*PartialSuccessError)
if (candidate.result == nil || isPartialSuccessError) &&
candidate.perms != nil &&
candidate.perms.CriticalOptions != nil &&
candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" {
if err := checkSourceAddress(
s.RemoteAddr(),
candidate.perms.CriticalOptions[sourceAddressCriticalOption])
candidate.perms.CriticalOptions[sourceAddressCriticalOption]); err != nil {
candidate.result = err
}
}
cache.add(candidate)
}
@ -578,8 +627,8 @@ userAuthLoop:
if len(payload) > 0 {
return nil, parseError(msgUserAuthRequest)
}
if candidate.result == nil {
_, isPartialSuccessError := candidate.result.(*PartialSuccessError)
if candidate.result == nil || isPartialSuccessError {
okMsg := userAuthPubKeyOkMsg{
Algo: algo,
PubKey: pubKeyData,
@ -629,11 +678,11 @@ userAuthLoop:
perms = candidate.perms
}
case "gssapi-with-mic":
if config.GSSAPIWithMICConfig == nil {
if authConfig.GSSAPIWithMICConfig == nil {
authErr = errors.New("ssh: gssapi-with-mic auth not configured")
break
}
gssapiConfig := config.GSSAPIWithMICConfig
gssapiConfig := authConfig.GSSAPIWithMICConfig
userAuthRequestGSSAPI, err := parseGSSAPIPayload(userAuthReq.Payload)
if err != nil {
return nil, parseError(msgUserAuthRequest)
@ -689,49 +738,70 @@ userAuthLoop:
break userAuthLoop
}
authFailures++
if config.MaxAuthTries > 0 && authFailures >= config.MaxAuthTries {
// If we have hit the max attempts, don't bother sending the
// final SSH_MSG_USERAUTH_FAILURE message, since there are
// no more authentication methods which can be attempted,
// and this message may cause the client to re-attempt
// authentication while we send the disconnect message.
// Continue, and trigger the disconnect at the start of
// the loop.
//
// The SSH specification is somewhat confusing about this,
// RFC 4252 Section 5.1 requires each authentication failure
// be responded to with a respective SSH_MSG_USERAUTH_FAILURE
// message, but Section 4 says the server should disconnect
// after some number of attempts, but it isn't explicit which
// message should take precedence (i.e. should there be a failure
// message than a disconnect message, or if we are going to
// disconnect, should we only send that message.)
//
// Either way, OpenSSH disconnects immediately after the last
// failed authnetication attempt, and given they are typically
// considered the golden implementation it seems reasonable
// to match that behavior.
continue
var failureMsg userAuthFailureMsg
if partialSuccess, ok := authErr.(*PartialSuccessError); ok {
// After a partial success error we don't allow changing the user
// name and execute the NoClientAuthCallback.
partialSuccessReturned = true
// In case a partial success is returned, the server may send
// a new set of authentication methods.
authConfig = partialSuccess.Next
// Reset pubkey cache, as the new PublicKeyCallback might
// accept a different set of public keys.
cache = pubKeyCache{}
// Send back a partial success message to the user.
failureMsg.PartialSuccess = true
} else {
// Allow initial attempt of 'none' without penalty.
if authFailures > 0 || userAuthReq.Method != "none" || noneAuthCount != 1 {
authFailures++
}
if config.MaxAuthTries > 0 && authFailures >= config.MaxAuthTries {
// If we have hit the max attempts, don't bother sending the
// final SSH_MSG_USERAUTH_FAILURE message, since there are
// no more authentication methods which can be attempted,
// and this message may cause the client to re-attempt
// authentication while we send the disconnect message.
// Continue, and trigger the disconnect at the start of
// the loop.
//
// The SSH specification is somewhat confusing about this,
// RFC 4252 Section 5.1 requires each authentication failure
// be responded to with a respective SSH_MSG_USERAUTH_FAILURE
// message, but Section 4 says the server should disconnect
// after some number of attempts, but it isn't explicit which
// message should take precedence (i.e. should there be a failure
// message than a disconnect message, or if we are going to
// disconnect, should we only send that message.)
//
// Either way, OpenSSH disconnects immediately after the last
// failed authentication attempt, and given they are typically
// considered the golden implementation it seems reasonable
// to match that behavior.
continue
}
}
var failureMsg userAuthFailureMsg
if config.PasswordCallback != nil {
if authConfig.PasswordCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "password")
}
if config.PublicKeyCallback != nil {
if authConfig.PublicKeyCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "publickey")
}
if config.KeyboardInteractiveCallback != nil {
if authConfig.KeyboardInteractiveCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive")
}
if config.GSSAPIWithMICConfig != nil && config.GSSAPIWithMICConfig.Server != nil &&
config.GSSAPIWithMICConfig.AllowLogin != nil {
if authConfig.GSSAPIWithMICConfig != nil && authConfig.GSSAPIWithMICConfig.Server != nil &&
authConfig.GSSAPIWithMICConfig.AllowLogin != nil {
failureMsg.Methods = append(failureMsg.Methods, "gssapi-with-mic")
}
if len(failureMsg.Methods) == 0 {
return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
return nil, errors.New("ssh: no authentication methods available")
}
if err := s.transport.writePacket(Marshal(&failureMsg)); err != nil {

View File

@ -1564,6 +1564,7 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) {
if size > remainSize {
hdec.SetEmitEnabled(false)
mh.Truncated = true
remainSize = 0
return
}
remainSize -= size
@ -1576,6 +1577,36 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) {
var hc headersOrContinuation = hf
for {
frag := hc.HeaderBlockFragment()
// Avoid parsing large amounts of headers that we will then discard.
// If the sender exceeds the max header list size by too much,
// skip parsing the fragment and close the connection.
//
// "Too much" is either any CONTINUATION frame after we've already
// exceeded the max header list size (in which case remainSize is 0),
// or a frame whose encoded size is more than twice the remaining
// header list bytes we're willing to accept.
if int64(len(frag)) > int64(2*remainSize) {
if VerboseLogs {
log.Printf("http2: header list too large")
}
// It would be nice to send a RST_STREAM before sending the GOAWAY,
// but the structure of the server's frame writer makes this difficult.
return nil, ConnectionError(ErrCodeProtocol)
}
// Also close the connection after any CONTINUATION frame following an
// invalid header, since we stop tracking the size of the headers after
// an invalid one.
if invalid != nil {
if VerboseLogs {
log.Printf("http2: invalid header: %v", invalid)
}
// It would be nice to send a RST_STREAM before sending the GOAWAY,
// but the structure of the server's frame writer makes this difficult.
return nil, ConnectionError(ErrCodeProtocol)
}
if _, err := hdec.Write(frag); err != nil {
return nil, ConnectionError(ErrCodeCompression)
}

View File

@ -77,7 +77,10 @@ func (p *pipe) Read(d []byte) (n int, err error) {
}
}
var errClosedPipeWrite = errors.New("write on closed buffer")
var (
errClosedPipeWrite = errors.New("write on closed buffer")
errUninitializedPipeWrite = errors.New("write on uninitialized buffer")
)
// Write copies bytes from p into the buffer and wakes a reader.
// It is an error to write more data than the buffer can hold.
@ -91,6 +94,12 @@ func (p *pipe) Write(d []byte) (n int, err error) {
if p.err != nil || p.breakErr != nil {
return 0, errClosedPipeWrite
}
// pipe.setBuffer is never invoked, leaving the buffer uninitialized.
// We shouldn't try to write to an uninitialized pipe,
// but returning an error is better than panicking.
if p.b == nil {
return 0, errUninitializedPipeWrite
}
return p.b.Write(d)
}

View File

@ -124,6 +124,7 @@ type Server struct {
// IdleTimeout specifies how long until idle clients should be
// closed with a GOAWAY frame. PING frames are not considered
// activity for the purposes of IdleTimeout.
// If zero or negative, there is no timeout.
IdleTimeout time.Duration
// MaxUploadBufferPerConnection is the size of the initial flow
@ -434,7 +435,7 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
// passes the connection off to us with the deadline already set.
// Write deadlines are set per stream in serverConn.newStream.
// Disarm the net.Conn write deadline here.
if sc.hs.WriteTimeout != 0 {
if sc.hs.WriteTimeout > 0 {
sc.conn.SetWriteDeadline(time.Time{})
}
@ -924,7 +925,7 @@ func (sc *serverConn) serve() {
sc.setConnState(http.StateActive)
sc.setConnState(http.StateIdle)
if sc.srv.IdleTimeout != 0 {
if sc.srv.IdleTimeout > 0 {
sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer)
defer sc.idleTimer.Stop()
}
@ -1637,7 +1638,7 @@ func (sc *serverConn) closeStream(st *stream, err error) {
delete(sc.streams, st.id)
if len(sc.streams) == 0 {
sc.setConnState(http.StateIdle)
if sc.srv.IdleTimeout != 0 {
if sc.srv.IdleTimeout > 0 {
sc.idleTimer.Reset(sc.srv.IdleTimeout)
}
if h1ServerKeepAlivesDisabled(sc.hs) {
@ -2017,7 +2018,7 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
// similar to how the http1 server works. Here it's
// technically more like the http1 Server's ReadHeaderTimeout
// (in Go 1.8), though. That's a more sane option anyway.
if sc.hs.ReadTimeout != 0 {
if sc.hs.ReadTimeout > 0 {
sc.conn.SetReadDeadline(time.Time{})
st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout)
}
@ -2038,7 +2039,7 @@ func (sc *serverConn) upgradeRequest(req *http.Request) {
// Disable any read deadline set by the net/http package
// prior to the upgrade.
if sc.hs.ReadTimeout != 0 {
if sc.hs.ReadTimeout > 0 {
sc.conn.SetReadDeadline(time.Time{})
}
@ -2116,7 +2117,7 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream
st.flow.conn = &sc.flow // link to conn-level counter
st.flow.add(sc.initialStreamSendWindowSize)
st.inflow.init(sc.srv.initialStreamRecvWindowSize())
if sc.hs.WriteTimeout != 0 {
if sc.hs.WriteTimeout > 0 {
st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout)
}

331
vendor/golang.org/x/net/http2/testsync.go generated vendored Normal file
View File

@ -0,0 +1,331 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http2
import (
"context"
"sync"
"time"
)
// testSyncHooks coordinates goroutines in tests.
//
// For example, a call to ClientConn.RoundTrip involves several goroutines, including:
// - the goroutine running RoundTrip;
// - the clientStream.doRequest goroutine, which writes the request; and
// - the clientStream.readLoop goroutine, which reads the response.
//
// Using testSyncHooks, a test can start a RoundTrip and identify when all these goroutines
// are blocked waiting for some condition such as reading the Request.Body or waiting for
// flow control to become available.
//
// The testSyncHooks also manage timers and synthetic time in tests.
// This permits us to, for example, start a request and cause it to time out waiting for
// response headers without resorting to time.Sleep calls.
type testSyncHooks struct {
// active/inactive act as a mutex and condition variable.
//
// - neither chan contains a value: testSyncHooks is locked.
// - active contains a value: unlocked, and at least one goroutine is not blocked
// - inactive contains a value: unlocked, and all goroutines are blocked
active chan struct{}
inactive chan struct{}
// goroutine counts
total int // total goroutines
condwait map[*sync.Cond]int // blocked in sync.Cond.Wait
blocked []*testBlockedGoroutine // otherwise blocked
// fake time
now time.Time
timers []*fakeTimer
// Transport testing: Report various events.
newclientconn func(*ClientConn)
newstream func(*clientStream)
}
// testBlockedGoroutine is a blocked goroutine.
type testBlockedGoroutine struct {
f func() bool // blocked until f returns true
ch chan struct{} // closed when unblocked
}
func newTestSyncHooks() *testSyncHooks {
h := &testSyncHooks{
active: make(chan struct{}, 1),
inactive: make(chan struct{}, 1),
condwait: map[*sync.Cond]int{},
}
h.inactive <- struct{}{}
h.now = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)
return h
}
// lock acquires the testSyncHooks mutex.
func (h *testSyncHooks) lock() {
select {
case <-h.active:
case <-h.inactive:
}
}
// waitInactive waits for all goroutines to become inactive.
func (h *testSyncHooks) waitInactive() {
for {
<-h.inactive
if !h.unlock() {
break
}
}
}
// unlock releases the testSyncHooks mutex.
// It reports whether any goroutines are active.
func (h *testSyncHooks) unlock() (active bool) {
// Look for a blocked goroutine which can be unblocked.
blocked := h.blocked[:0]
unblocked := false
for _, b := range h.blocked {
if !unblocked && b.f() {
unblocked = true
close(b.ch)
} else {
blocked = append(blocked, b)
}
}
h.blocked = blocked
// Count goroutines blocked on condition variables.
condwait := 0
for _, count := range h.condwait {
condwait += count
}
if h.total > condwait+len(blocked) {
h.active <- struct{}{}
return true
} else {
h.inactive <- struct{}{}
return false
}
}
// goRun starts a new goroutine.
func (h *testSyncHooks) goRun(f func()) {
h.lock()
h.total++
h.unlock()
go func() {
defer func() {
h.lock()
h.total--
h.unlock()
}()
f()
}()
}
// blockUntil indicates that a goroutine is blocked waiting for some condition to become true.
// It waits until f returns true before proceeding.
//
// Example usage:
//
// h.blockUntil(func() bool {
// // Is the context done yet?
// select {
// case <-ctx.Done():
// default:
// return false
// }
// return true
// })
// // Wait for the context to become done.
// <-ctx.Done()
//
// The function f passed to blockUntil must be non-blocking and idempotent.
func (h *testSyncHooks) blockUntil(f func() bool) {
if f() {
return
}
ch := make(chan struct{})
h.lock()
h.blocked = append(h.blocked, &testBlockedGoroutine{
f: f,
ch: ch,
})
h.unlock()
<-ch
}
// broadcast is sync.Cond.Broadcast.
func (h *testSyncHooks) condBroadcast(cond *sync.Cond) {
h.lock()
delete(h.condwait, cond)
h.unlock()
cond.Broadcast()
}
// broadcast is sync.Cond.Wait.
func (h *testSyncHooks) condWait(cond *sync.Cond) {
h.lock()
h.condwait[cond]++
h.unlock()
}
// newTimer creates a new fake timer.
func (h *testSyncHooks) newTimer(d time.Duration) timer {
h.lock()
defer h.unlock()
t := &fakeTimer{
hooks: h,
when: h.now.Add(d),
c: make(chan time.Time),
}
h.timers = append(h.timers, t)
return t
}
// afterFunc creates a new fake AfterFunc timer.
func (h *testSyncHooks) afterFunc(d time.Duration, f func()) timer {
h.lock()
defer h.unlock()
t := &fakeTimer{
hooks: h,
when: h.now.Add(d),
f: f,
}
h.timers = append(h.timers, t)
return t
}
func (h *testSyncHooks) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithCancel(ctx)
t := h.afterFunc(d, cancel)
return ctx, func() {
t.Stop()
cancel()
}
}
func (h *testSyncHooks) timeUntilEvent() time.Duration {
h.lock()
defer h.unlock()
var next time.Time
for _, t := range h.timers {
if next.IsZero() || t.when.Before(next) {
next = t.when
}
}
if d := next.Sub(h.now); d > 0 {
return d
}
return 0
}
// advance advances time and causes synthetic timers to fire.
func (h *testSyncHooks) advance(d time.Duration) {
h.lock()
defer h.unlock()
h.now = h.now.Add(d)
timers := h.timers[:0]
for _, t := range h.timers {
t := t // remove after go.mod depends on go1.22
t.mu.Lock()
switch {
case t.when.After(h.now):
timers = append(timers, t)
case t.when.IsZero():
// stopped timer
default:
t.when = time.Time{}
if t.c != nil {
close(t.c)
}
if t.f != nil {
h.total++
go func() {
defer func() {
h.lock()
h.total--
h.unlock()
}()
t.f()
}()
}
}
t.mu.Unlock()
}
h.timers = timers
}
// A timer wraps a time.Timer, or a synthetic equivalent in tests.
// Unlike time.Timer, timer is single-use: The timer channel is closed when the timer expires.
type timer interface {
C() <-chan time.Time
Stop() bool
Reset(d time.Duration) bool
}
// timeTimer implements timer using real time.
type timeTimer struct {
t *time.Timer
c chan time.Time
}
// newTimeTimer creates a new timer using real time.
func newTimeTimer(d time.Duration) timer {
ch := make(chan time.Time)
t := time.AfterFunc(d, func() {
close(ch)
})
return &timeTimer{t, ch}
}
// newTimeAfterFunc creates an AfterFunc timer using real time.
func newTimeAfterFunc(d time.Duration, f func()) timer {
return &timeTimer{
t: time.AfterFunc(d, f),
}
}
func (t timeTimer) C() <-chan time.Time { return t.c }
func (t timeTimer) Stop() bool { return t.t.Stop() }
func (t timeTimer) Reset(d time.Duration) bool { return t.t.Reset(d) }
// fakeTimer implements timer using fake time.
type fakeTimer struct {
hooks *testSyncHooks
mu sync.Mutex
when time.Time // when the timer will fire
c chan time.Time // closed when the timer fires; mutually exclusive with f
f func() // called when the timer fires; mutually exclusive with c
}
func (t *fakeTimer) C() <-chan time.Time { return t.c }
func (t *fakeTimer) Stop() bool {
t.mu.Lock()
defer t.mu.Unlock()
stopped := t.when.IsZero()
t.when = time.Time{}
return stopped
}
func (t *fakeTimer) Reset(d time.Duration) bool {
if t.c != nil || t.f == nil {
panic("fakeTimer only supports Reset on AfterFunc timers")
}
t.mu.Lock()
defer t.mu.Unlock()
t.hooks.lock()
defer t.hooks.unlock()
active := !t.when.IsZero()
t.when = t.hooks.now.Add(d)
if !active {
t.hooks.timers = append(t.hooks.timers, t)
}
return active
}

View File

@ -147,6 +147,12 @@ type Transport struct {
// waiting for their turn.
StrictMaxConcurrentStreams bool
// IdleConnTimeout is the maximum amount of time an idle
// (keep-alive) connection will remain idle before closing
// itself.
// Zero means no limit.
IdleConnTimeout time.Duration
// ReadIdleTimeout is the timeout after which a health check using ping
// frame will be carried out if no frame is received on the connection.
// Note that a ping response will is considered a received frame, so if
@ -178,6 +184,8 @@ type Transport struct {
connPoolOnce sync.Once
connPoolOrDef ClientConnPool // non-nil version of ConnPool
syncHooks *testSyncHooks
}
func (t *Transport) maxHeaderListSize() uint32 {
@ -302,7 +310,7 @@ type ClientConn struct {
readerErr error // set before readerDone is closed
idleTimeout time.Duration // or 0 for never
idleTimer *time.Timer
idleTimer timer
mu sync.Mutex // guards following
cond *sync.Cond // hold mu; broadcast on flow/closed changes
@ -344,6 +352,60 @@ type ClientConn struct {
werr error // first write error that has occurred
hbuf bytes.Buffer // HPACK encoder writes into this
henc *hpack.Encoder
syncHooks *testSyncHooks // can be nil
}
// Hook points used for testing.
// Outside of tests, cc.syncHooks is nil and these all have minimal implementations.
// Inside tests, see the testSyncHooks function docs.
// goRun starts a new goroutine.
func (cc *ClientConn) goRun(f func()) {
if cc.syncHooks != nil {
cc.syncHooks.goRun(f)
return
}
go f()
}
// condBroadcast is cc.cond.Broadcast.
func (cc *ClientConn) condBroadcast() {
if cc.syncHooks != nil {
cc.syncHooks.condBroadcast(cc.cond)
}
cc.cond.Broadcast()
}
// condWait is cc.cond.Wait.
func (cc *ClientConn) condWait() {
if cc.syncHooks != nil {
cc.syncHooks.condWait(cc.cond)
}
cc.cond.Wait()
}
// newTimer creates a new time.Timer, or a synthetic timer in tests.
func (cc *ClientConn) newTimer(d time.Duration) timer {
if cc.syncHooks != nil {
return cc.syncHooks.newTimer(d)
}
return newTimeTimer(d)
}
// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests.
func (cc *ClientConn) afterFunc(d time.Duration, f func()) timer {
if cc.syncHooks != nil {
return cc.syncHooks.afterFunc(d, f)
}
return newTimeAfterFunc(d, f)
}
func (cc *ClientConn) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
if cc.syncHooks != nil {
return cc.syncHooks.contextWithTimeout(ctx, d)
}
return context.WithTimeout(ctx, d)
}
// clientStream is the state for a single HTTP/2 stream. One of these
@ -425,7 +487,7 @@ func (cs *clientStream) abortStreamLocked(err error) {
// TODO(dneil): Clean up tests where cs.cc.cond is nil.
if cs.cc.cond != nil {
// Wake up writeRequestBody if it is waiting on flow control.
cs.cc.cond.Broadcast()
cs.cc.condBroadcast()
}
}
@ -435,7 +497,7 @@ func (cs *clientStream) abortRequestBodyWrite() {
defer cc.mu.Unlock()
if cs.reqBody != nil && cs.reqBodyClosed == nil {
cs.closeReqBodyLocked()
cc.cond.Broadcast()
cc.condBroadcast()
}
}
@ -445,10 +507,10 @@ func (cs *clientStream) closeReqBodyLocked() {
}
cs.reqBodyClosed = make(chan struct{})
reqBodyClosed := cs.reqBodyClosed
go func() {
cs.cc.goRun(func() {
cs.reqBody.Close()
close(reqBodyClosed)
}()
})
}
type stickyErrWriter struct {
@ -537,15 +599,6 @@ func authorityAddr(scheme string, authority string) (addr string) {
return net.JoinHostPort(host, port)
}
var retryBackoffHook func(time.Duration) *time.Timer
func backoffNewTimer(d time.Duration) *time.Timer {
if retryBackoffHook != nil {
return retryBackoffHook(d)
}
return time.NewTimer(d)
}
// RoundTripOpt is like RoundTrip, but takes options.
func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) {
@ -573,13 +626,27 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res
backoff := float64(uint(1) << (uint(retry) - 1))
backoff += backoff * (0.1 * mathrand.Float64())
d := time.Second * time.Duration(backoff)
timer := backoffNewTimer(d)
var tm timer
if t.syncHooks != nil {
tm = t.syncHooks.newTimer(d)
t.syncHooks.blockUntil(func() bool {
select {
case <-tm.C():
case <-req.Context().Done():
default:
return false
}
return true
})
} else {
tm = newTimeTimer(d)
}
select {
case <-timer.C:
case <-tm.C():
t.vlogf("RoundTrip retrying after failure: %v", roundTripErr)
continue
case <-req.Context().Done():
timer.Stop()
tm.Stop()
err = req.Context().Err()
}
}
@ -658,6 +725,9 @@ func canRetryError(err error) bool {
}
func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*ClientConn, error) {
if t.syncHooks != nil {
return t.newClientConn(nil, singleUse, t.syncHooks)
}
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
@ -666,7 +736,7 @@ func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse b
if err != nil {
return nil, err
}
return t.newClientConn(tconn, singleUse)
return t.newClientConn(tconn, singleUse, nil)
}
func (t *Transport) newTLSConfig(host string) *tls.Config {
@ -732,10 +802,10 @@ func (t *Transport) maxEncoderHeaderTableSize() uint32 {
}
func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) {
return t.newClientConn(c, t.disableKeepAlives())
return t.newClientConn(c, t.disableKeepAlives(), nil)
}
func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, error) {
func (t *Transport) newClientConn(c net.Conn, singleUse bool, hooks *testSyncHooks) (*ClientConn, error) {
cc := &ClientConn{
t: t,
tconn: c,
@ -750,10 +820,15 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
wantSettingsAck: true,
pings: make(map[[8]byte]chan struct{}),
reqHeaderMu: make(chan struct{}, 1),
syncHooks: hooks,
}
if hooks != nil {
hooks.newclientconn(cc)
c = cc.tconn
}
if d := t.idleConnTimeout(); d != 0 {
cc.idleTimeout = d
cc.idleTimer = time.AfterFunc(d, cc.onIdleTimeout)
cc.idleTimer = cc.afterFunc(d, cc.onIdleTimeout)
}
if VerboseLogs {
t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr())
@ -818,7 +893,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
return nil, cc.werr
}
go cc.readLoop()
cc.goRun(cc.readLoop)
return cc, nil
}
@ -826,7 +901,7 @@ func (cc *ClientConn) healthCheck() {
pingTimeout := cc.t.pingTimeout()
// We don't need to periodically ping in the health check, because the readLoop of ClientConn will
// trigger the healthCheck again if there is no frame received.
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
ctx, cancel := cc.contextWithTimeout(context.Background(), pingTimeout)
defer cancel()
cc.vlogf("http2: Transport sending health check")
err := cc.Ping(ctx)
@ -1056,7 +1131,7 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error {
// Wait for all in-flight streams to complete or connection to close
done := make(chan struct{})
cancelled := false // guarded by cc.mu
go func() {
cc.goRun(func() {
cc.mu.Lock()
defer cc.mu.Unlock()
for {
@ -1068,9 +1143,9 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error {
if cancelled {
break
}
cc.cond.Wait()
cc.condWait()
}
}()
})
shutdownEnterWaitStateHook()
select {
case <-done:
@ -1080,7 +1155,7 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error {
cc.mu.Lock()
// Free the goroutine above
cancelled = true
cc.cond.Broadcast()
cc.condBroadcast()
cc.mu.Unlock()
return ctx.Err()
}
@ -1118,7 +1193,7 @@ func (cc *ClientConn) closeForError(err error) {
for _, cs := range cc.streams {
cs.abortStreamLocked(err)
}
cc.cond.Broadcast()
cc.condBroadcast()
cc.mu.Unlock()
cc.closeConn()
}
@ -1215,6 +1290,10 @@ func (cc *ClientConn) decrStreamReservationsLocked() {
}
func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
return cc.roundTrip(req, nil)
}
func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream)) (*http.Response, error) {
ctx := req.Context()
cs := &clientStream{
cc: cc,
@ -1229,9 +1308,23 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
respHeaderRecv: make(chan struct{}),
donec: make(chan struct{}),
}
go cs.doRequest(req)
cc.goRun(func() {
cs.doRequest(req)
})
waitDone := func() error {
if cc.syncHooks != nil {
cc.syncHooks.blockUntil(func() bool {
select {
case <-cs.donec:
case <-ctx.Done():
case <-cs.reqCancel:
default:
return false
}
return true
})
}
select {
case <-cs.donec:
return nil
@ -1292,7 +1385,24 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
return err
}
if streamf != nil {
streamf(cs)
}
for {
if cc.syncHooks != nil {
cc.syncHooks.blockUntil(func() bool {
select {
case <-cs.respHeaderRecv:
case <-cs.abort:
case <-ctx.Done():
case <-cs.reqCancel:
default:
return false
}
return true
})
}
select {
case <-cs.respHeaderRecv:
return handleResponseHeaders()
@ -1348,6 +1458,21 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
if cc.reqHeaderMu == nil {
panic("RoundTrip on uninitialized ClientConn") // for tests
}
var newStreamHook func(*clientStream)
if cc.syncHooks != nil {
newStreamHook = cc.syncHooks.newstream
cc.syncHooks.blockUntil(func() bool {
select {
case cc.reqHeaderMu <- struct{}{}:
<-cc.reqHeaderMu
case <-cs.reqCancel:
case <-ctx.Done():
default:
return false
}
return true
})
}
select {
case cc.reqHeaderMu <- struct{}{}:
case <-cs.reqCancel:
@ -1372,6 +1497,10 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
}
cc.mu.Unlock()
if newStreamHook != nil {
newStreamHook(cs)
}
// TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere?
if !cc.t.disableCompression() &&
req.Header.Get("Accept-Encoding") == "" &&
@ -1452,15 +1581,30 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
var respHeaderTimer <-chan time.Time
var respHeaderRecv chan struct{}
if d := cc.responseHeaderTimeout(); d != 0 {
timer := time.NewTimer(d)
timer := cc.newTimer(d)
defer timer.Stop()
respHeaderTimer = timer.C
respHeaderTimer = timer.C()
respHeaderRecv = cs.respHeaderRecv
}
// Wait until the peer half-closes its end of the stream,
// or until the request is aborted (via context, error, or otherwise),
// whichever comes first.
for {
if cc.syncHooks != nil {
cc.syncHooks.blockUntil(func() bool {
select {
case <-cs.peerClosed:
case <-respHeaderTimer:
case <-respHeaderRecv:
case <-cs.abort:
case <-ctx.Done():
case <-cs.reqCancel:
default:
return false
}
return true
})
}
select {
case <-cs.peerClosed:
return nil
@ -1609,7 +1753,7 @@ func (cc *ClientConn) awaitOpenSlotForStreamLocked(cs *clientStream) error {
return nil
}
cc.pendingRequests++
cc.cond.Wait()
cc.condWait()
cc.pendingRequests--
select {
case <-cs.abort:
@ -1871,10 +2015,26 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error)
cs.flow.take(take)
return take, nil
}
cc.cond.Wait()
cc.condWait()
}
}
func validateHeaders(hdrs http.Header) string {
for k, vv := range hdrs {
if !httpguts.ValidHeaderFieldName(k) {
return fmt.Sprintf("name %q", k)
}
for _, v := range vv {
if !httpguts.ValidHeaderFieldValue(v) {
// Don't include the value in the error,
// because it may be sensitive.
return fmt.Sprintf("value for header %q", k)
}
}
}
return ""
}
var errNilRequestURL = errors.New("http2: Request.URI is nil")
// requires cc.wmu be held.
@ -1912,19 +2072,14 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail
}
}
// Check for any invalid headers and return an error before we
// Check for any invalid headers+trailers and return an error before we
// potentially pollute our hpack state. (We want to be able to
// continue to reuse the hpack encoder for future requests)
for k, vv := range req.Header {
if !httpguts.ValidHeaderFieldName(k) {
return nil, fmt.Errorf("invalid HTTP header name %q", k)
}
for _, v := range vv {
if !httpguts.ValidHeaderFieldValue(v) {
// Don't include the value in the error, because it may be sensitive.
return nil, fmt.Errorf("invalid HTTP header value for header %q", k)
}
}
if err := validateHeaders(req.Header); err != "" {
return nil, fmt.Errorf("invalid HTTP header %s", err)
}
if err := validateHeaders(req.Trailer); err != "" {
return nil, fmt.Errorf("invalid HTTP trailer %s", err)
}
enumerateHeaders := func(f func(name, value string)) {
@ -2143,7 +2298,7 @@ func (cc *ClientConn) forgetStreamID(id uint32) {
}
// Wake up writeRequestBody via clientStream.awaitFlowControl and
// wake up RoundTrip if there is a pending request.
cc.cond.Broadcast()
cc.condBroadcast()
closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() || cc.goAway != nil
if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 {
@ -2231,7 +2386,7 @@ func (rl *clientConnReadLoop) cleanup() {
cs.abortStreamLocked(err)
}
}
cc.cond.Broadcast()
cc.condBroadcast()
cc.mu.Unlock()
}
@ -2266,10 +2421,9 @@ func (rl *clientConnReadLoop) run() error {
cc := rl.cc
gotSettings := false
readIdleTimeout := cc.t.ReadIdleTimeout
var t *time.Timer
var t timer
if readIdleTimeout != 0 {
t = time.AfterFunc(readIdleTimeout, cc.healthCheck)
defer t.Stop()
t = cc.afterFunc(readIdleTimeout, cc.healthCheck)
}
for {
f, err := cc.fr.ReadFrame()
@ -2684,7 +2838,7 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error {
})
return nil
}
if !cs.firstByte {
if !cs.pastHeaders {
cc.logf("protocol error: received DATA before a HEADERS frame")
rl.endStreamError(cs, StreamError{
StreamID: f.StreamID,
@ -2867,7 +3021,7 @@ func (rl *clientConnReadLoop) processSettingsNoWrite(f *SettingsFrame) error {
for _, cs := range cc.streams {
cs.flow.add(delta)
}
cc.cond.Broadcast()
cc.condBroadcast()
cc.initialWindowSize = s.Val
case SettingHeaderTableSize:
@ -2922,7 +3076,7 @@ func (rl *clientConnReadLoop) processWindowUpdate(f *WindowUpdateFrame) error {
return ConnectionError(ErrCodeFlowControl)
}
cc.cond.Broadcast()
cc.condBroadcast()
return nil
}
@ -2964,24 +3118,38 @@ func (cc *ClientConn) Ping(ctx context.Context) error {
}
cc.mu.Unlock()
}
errc := make(chan error, 1)
go func() {
var pingError error
errc := make(chan struct{})
cc.goRun(func() {
cc.wmu.Lock()
defer cc.wmu.Unlock()
if err := cc.fr.WritePing(false, p); err != nil {
errc <- err
if pingError = cc.fr.WritePing(false, p); pingError != nil {
close(errc)
return
}
if err := cc.bw.Flush(); err != nil {
errc <- err
if pingError = cc.bw.Flush(); pingError != nil {
close(errc)
return
}
}()
})
if cc.syncHooks != nil {
cc.syncHooks.blockUntil(func() bool {
select {
case <-c:
case <-errc:
case <-ctx.Done():
case <-cc.readerDone:
default:
return false
}
return true
})
}
select {
case <-c:
return nil
case err := <-errc:
return err
case <-errc:
return pingError
case <-ctx.Done():
return ctx.Err()
case <-cc.readerDone:
@ -3150,9 +3318,17 @@ func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, err
}
func (t *Transport) idleConnTimeout() time.Duration {
// to keep things backwards compatible, we use non-zero values of
// IdleConnTimeout, followed by using the IdleConnTimeout on the underlying
// http1 transport, followed by 0
if t.IdleConnTimeout != 0 {
return t.IdleConnTimeout
}
if t.t1 != nil {
return t.t1.IdleConnTimeout
}
return 0
}

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || openbsd || solaris
//go:build aix || darwin || dragonfly || freebsd || openbsd || solaris || zos
package unix

View File

@ -1520,6 +1520,14 @@ func (m *mmapper) Munmap(data []byte) (err error) {
return nil
}
func Mmap(fd int, offset int64, length int, prot int, flags int) (data []byte, err error) {
return mapper.Mmap(fd, offset, length, prot, flags)
}
func Munmap(b []byte) (err error) {
return mapper.Munmap(b)
}
func Read(fd int, p []byte) (n int, err error) {
n, err = read(fd, p)
if raceenabled {

View File

@ -165,6 +165,7 @@ func NewCallbackCDecl(fn interface{}) uintptr {
//sys CreateFile(name *uint16, access uint32, mode uint32, sa *SecurityAttributes, createmode uint32, attrs uint32, templatefile Handle) (handle Handle, err error) [failretval==InvalidHandle] = CreateFileW
//sys CreateNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *SecurityAttributes) (handle Handle, err error) [failretval==InvalidHandle] = CreateNamedPipeW
//sys ConnectNamedPipe(pipe Handle, overlapped *Overlapped) (err error)
//sys DisconnectNamedPipe(pipe Handle) (err error)
//sys GetNamedPipeInfo(pipe Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error)
//sys GetNamedPipeHandleState(pipe Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW
//sys SetNamedPipeHandleState(pipe Handle, state *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32) (err error) = SetNamedPipeHandleState
@ -348,8 +349,19 @@ func NewCallbackCDecl(fn interface{}) uintptr {
//sys SetProcessPriorityBoost(process Handle, disable bool) (err error) = kernel32.SetProcessPriorityBoost
//sys GetProcessWorkingSetSizeEx(hProcess Handle, lpMinimumWorkingSetSize *uintptr, lpMaximumWorkingSetSize *uintptr, flags *uint32)
//sys SetProcessWorkingSetSizeEx(hProcess Handle, dwMinimumWorkingSetSize uintptr, dwMaximumWorkingSetSize uintptr, flags uint32) (err error)
//sys ClearCommBreak(handle Handle) (err error)
//sys ClearCommError(handle Handle, lpErrors *uint32, lpStat *ComStat) (err error)
//sys EscapeCommFunction(handle Handle, dwFunc uint32) (err error)
//sys GetCommState(handle Handle, lpDCB *DCB) (err error)
//sys GetCommModemStatus(handle Handle, lpModemStat *uint32) (err error)
//sys GetCommTimeouts(handle Handle, timeouts *CommTimeouts) (err error)
//sys PurgeComm(handle Handle, dwFlags uint32) (err error)
//sys SetCommBreak(handle Handle) (err error)
//sys SetCommMask(handle Handle, dwEvtMask uint32) (err error)
//sys SetCommState(handle Handle, lpDCB *DCB) (err error)
//sys SetCommTimeouts(handle Handle, timeouts *CommTimeouts) (err error)
//sys SetupComm(handle Handle, dwInQueue uint32, dwOutQueue uint32) (err error)
//sys WaitCommEvent(handle Handle, lpEvtMask *uint32, lpOverlapped *Overlapped) (err error)
//sys GetActiveProcessorCount(groupNumber uint16) (ret uint32)
//sys GetMaximumProcessorCount(groupNumber uint16) (ret uint32)
//sys EnumWindows(enumFunc uintptr, param unsafe.Pointer) (err error) = user32.EnumWindows
@ -1834,3 +1846,73 @@ func ResizePseudoConsole(pconsole Handle, size Coord) error {
// accept arguments that can be casted to uintptr, and Coord can't.
return resizePseudoConsole(pconsole, *((*uint32)(unsafe.Pointer(&size))))
}
// DCB constants. See https://learn.microsoft.com/en-us/windows/win32/api/winbase/ns-winbase-dcb.
const (
CBR_110 = 110
CBR_300 = 300
CBR_600 = 600
CBR_1200 = 1200
CBR_2400 = 2400
CBR_4800 = 4800
CBR_9600 = 9600
CBR_14400 = 14400
CBR_19200 = 19200
CBR_38400 = 38400
CBR_57600 = 57600
CBR_115200 = 115200
CBR_128000 = 128000
CBR_256000 = 256000
DTR_CONTROL_DISABLE = 0x00000000
DTR_CONTROL_ENABLE = 0x00000010
DTR_CONTROL_HANDSHAKE = 0x00000020
RTS_CONTROL_DISABLE = 0x00000000
RTS_CONTROL_ENABLE = 0x00001000
RTS_CONTROL_HANDSHAKE = 0x00002000
RTS_CONTROL_TOGGLE = 0x00003000
NOPARITY = 0
ODDPARITY = 1
EVENPARITY = 2
MARKPARITY = 3
SPACEPARITY = 4
ONESTOPBIT = 0
ONE5STOPBITS = 1
TWOSTOPBITS = 2
)
// EscapeCommFunction constants. See https://learn.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-escapecommfunction.
const (
SETXOFF = 1
SETXON = 2
SETRTS = 3
CLRRTS = 4
SETDTR = 5
CLRDTR = 6
SETBREAK = 8
CLRBREAK = 9
)
// PurgeComm constants. See https://learn.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-purgecomm.
const (
PURGE_TXABORT = 0x0001
PURGE_RXABORT = 0x0002
PURGE_TXCLEAR = 0x0004
PURGE_RXCLEAR = 0x0008
)
// SetCommMask constants. See https://learn.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-setcommmask.
const (
EV_RXCHAR = 0x0001
EV_RXFLAG = 0x0002
EV_TXEMPTY = 0x0004
EV_CTS = 0x0008
EV_DSR = 0x0010
EV_RLSD = 0x0020
EV_BREAK = 0x0040
EV_ERR = 0x0080
EV_RING = 0x0100
)

View File

@ -3380,3 +3380,27 @@ type BLOB struct {
Size uint32
BlobData *byte
}
type ComStat struct {
Flags uint32
CBInQue uint32
CBOutQue uint32
}
type DCB struct {
DCBlength uint32
BaudRate uint32
Flags uint32
wReserved uint16
XonLim uint16
XoffLim uint16
ByteSize uint8
Parity uint8
StopBits uint8
XonChar byte
XoffChar byte
ErrorChar byte
EofChar byte
EvtChar byte
wReserved1 uint16
}

View File

@ -188,6 +188,8 @@ var (
procAssignProcessToJobObject = modkernel32.NewProc("AssignProcessToJobObject")
procCancelIo = modkernel32.NewProc("CancelIo")
procCancelIoEx = modkernel32.NewProc("CancelIoEx")
procClearCommBreak = modkernel32.NewProc("ClearCommBreak")
procClearCommError = modkernel32.NewProc("ClearCommError")
procCloseHandle = modkernel32.NewProc("CloseHandle")
procClosePseudoConsole = modkernel32.NewProc("ClosePseudoConsole")
procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe")
@ -212,7 +214,9 @@ var (
procDeleteProcThreadAttributeList = modkernel32.NewProc("DeleteProcThreadAttributeList")
procDeleteVolumeMountPointW = modkernel32.NewProc("DeleteVolumeMountPointW")
procDeviceIoControl = modkernel32.NewProc("DeviceIoControl")
procDisconnectNamedPipe = modkernel32.NewProc("DisconnectNamedPipe")
procDuplicateHandle = modkernel32.NewProc("DuplicateHandle")
procEscapeCommFunction = modkernel32.NewProc("EscapeCommFunction")
procExitProcess = modkernel32.NewProc("ExitProcess")
procExpandEnvironmentStringsW = modkernel32.NewProc("ExpandEnvironmentStringsW")
procFindClose = modkernel32.NewProc("FindClose")
@ -236,6 +240,8 @@ var (
procGenerateConsoleCtrlEvent = modkernel32.NewProc("GenerateConsoleCtrlEvent")
procGetACP = modkernel32.NewProc("GetACP")
procGetActiveProcessorCount = modkernel32.NewProc("GetActiveProcessorCount")
procGetCommModemStatus = modkernel32.NewProc("GetCommModemStatus")
procGetCommState = modkernel32.NewProc("GetCommState")
procGetCommTimeouts = modkernel32.NewProc("GetCommTimeouts")
procGetCommandLineW = modkernel32.NewProc("GetCommandLineW")
procGetComputerNameExW = modkernel32.NewProc("GetComputerNameExW")
@ -322,6 +328,7 @@ var (
procProcess32NextW = modkernel32.NewProc("Process32NextW")
procProcessIdToSessionId = modkernel32.NewProc("ProcessIdToSessionId")
procPulseEvent = modkernel32.NewProc("PulseEvent")
procPurgeComm = modkernel32.NewProc("PurgeComm")
procQueryDosDeviceW = modkernel32.NewProc("QueryDosDeviceW")
procQueryFullProcessImageNameW = modkernel32.NewProc("QueryFullProcessImageNameW")
procQueryInformationJobObject = modkernel32.NewProc("QueryInformationJobObject")
@ -335,6 +342,9 @@ var (
procResetEvent = modkernel32.NewProc("ResetEvent")
procResizePseudoConsole = modkernel32.NewProc("ResizePseudoConsole")
procResumeThread = modkernel32.NewProc("ResumeThread")
procSetCommBreak = modkernel32.NewProc("SetCommBreak")
procSetCommMask = modkernel32.NewProc("SetCommMask")
procSetCommState = modkernel32.NewProc("SetCommState")
procSetCommTimeouts = modkernel32.NewProc("SetCommTimeouts")
procSetConsoleCursorPosition = modkernel32.NewProc("SetConsoleCursorPosition")
procSetConsoleMode = modkernel32.NewProc("SetConsoleMode")
@ -342,7 +352,6 @@ var (
procSetDefaultDllDirectories = modkernel32.NewProc("SetDefaultDllDirectories")
procSetDllDirectoryW = modkernel32.NewProc("SetDllDirectoryW")
procSetEndOfFile = modkernel32.NewProc("SetEndOfFile")
procSetFileValidData = modkernel32.NewProc("SetFileValidData")
procSetEnvironmentVariableW = modkernel32.NewProc("SetEnvironmentVariableW")
procSetErrorMode = modkernel32.NewProc("SetErrorMode")
procSetEvent = modkernel32.NewProc("SetEvent")
@ -351,6 +360,7 @@ var (
procSetFileInformationByHandle = modkernel32.NewProc("SetFileInformationByHandle")
procSetFilePointer = modkernel32.NewProc("SetFilePointer")
procSetFileTime = modkernel32.NewProc("SetFileTime")
procSetFileValidData = modkernel32.NewProc("SetFileValidData")
procSetHandleInformation = modkernel32.NewProc("SetHandleInformation")
procSetInformationJobObject = modkernel32.NewProc("SetInformationJobObject")
procSetNamedPipeHandleState = modkernel32.NewProc("SetNamedPipeHandleState")
@ -361,6 +371,7 @@ var (
procSetStdHandle = modkernel32.NewProc("SetStdHandle")
procSetVolumeLabelW = modkernel32.NewProc("SetVolumeLabelW")
procSetVolumeMountPointW = modkernel32.NewProc("SetVolumeMountPointW")
procSetupComm = modkernel32.NewProc("SetupComm")
procSizeofResource = modkernel32.NewProc("SizeofResource")
procSleepEx = modkernel32.NewProc("SleepEx")
procTerminateJobObject = modkernel32.NewProc("TerminateJobObject")
@ -379,6 +390,7 @@ var (
procVirtualQueryEx = modkernel32.NewProc("VirtualQueryEx")
procVirtualUnlock = modkernel32.NewProc("VirtualUnlock")
procWTSGetActiveConsoleSessionId = modkernel32.NewProc("WTSGetActiveConsoleSessionId")
procWaitCommEvent = modkernel32.NewProc("WaitCommEvent")
procWaitForMultipleObjects = modkernel32.NewProc("WaitForMultipleObjects")
procWaitForSingleObject = modkernel32.NewProc("WaitForSingleObject")
procWriteConsoleW = modkernel32.NewProc("WriteConsoleW")
@ -1641,6 +1653,22 @@ func CancelIoEx(s Handle, o *Overlapped) (err error) {
return
}
func ClearCommBreak(handle Handle) (err error) {
r1, _, e1 := syscall.Syscall(procClearCommBreak.Addr(), 1, uintptr(handle), 0, 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func ClearCommError(handle Handle, lpErrors *uint32, lpStat *ComStat) (err error) {
r1, _, e1 := syscall.Syscall(procClearCommError.Addr(), 3, uintptr(handle), uintptr(unsafe.Pointer(lpErrors)), uintptr(unsafe.Pointer(lpStat)))
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func CloseHandle(handle Handle) (err error) {
r1, _, e1 := syscall.Syscall(procCloseHandle.Addr(), 1, uintptr(handle), 0, 0)
if r1 == 0 {
@ -1845,6 +1873,14 @@ func DeviceIoControl(handle Handle, ioControlCode uint32, inBuffer *byte, inBuff
return
}
func DisconnectNamedPipe(pipe Handle) (err error) {
r1, _, e1 := syscall.Syscall(procDisconnectNamedPipe.Addr(), 1, uintptr(pipe), 0, 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func DuplicateHandle(hSourceProcessHandle Handle, hSourceHandle Handle, hTargetProcessHandle Handle, lpTargetHandle *Handle, dwDesiredAccess uint32, bInheritHandle bool, dwOptions uint32) (err error) {
var _p0 uint32
if bInheritHandle {
@ -1857,6 +1893,14 @@ func DuplicateHandle(hSourceProcessHandle Handle, hSourceHandle Handle, hTargetP
return
}
func EscapeCommFunction(handle Handle, dwFunc uint32) (err error) {
r1, _, e1 := syscall.Syscall(procEscapeCommFunction.Addr(), 2, uintptr(handle), uintptr(dwFunc), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func ExitProcess(exitcode uint32) {
syscall.Syscall(procExitProcess.Addr(), 1, uintptr(exitcode), 0, 0)
return
@ -2058,6 +2102,22 @@ func GetActiveProcessorCount(groupNumber uint16) (ret uint32) {
return
}
func GetCommModemStatus(handle Handle, lpModemStat *uint32) (err error) {
r1, _, e1 := syscall.Syscall(procGetCommModemStatus.Addr(), 2, uintptr(handle), uintptr(unsafe.Pointer(lpModemStat)), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func GetCommState(handle Handle, lpDCB *DCB) (err error) {
r1, _, e1 := syscall.Syscall(procGetCommState.Addr(), 2, uintptr(handle), uintptr(unsafe.Pointer(lpDCB)), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func GetCommTimeouts(handle Handle, timeouts *CommTimeouts) (err error) {
r1, _, e1 := syscall.Syscall(procGetCommTimeouts.Addr(), 2, uintptr(handle), uintptr(unsafe.Pointer(timeouts)), 0)
if r1 == 0 {
@ -2810,6 +2870,14 @@ func PulseEvent(event Handle) (err error) {
return
}
func PurgeComm(handle Handle, dwFlags uint32) (err error) {
r1, _, e1 := syscall.Syscall(procPurgeComm.Addr(), 2, uintptr(handle), uintptr(dwFlags), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func QueryDosDevice(deviceName *uint16, targetPath *uint16, max uint32) (n uint32, err error) {
r0, _, e1 := syscall.Syscall(procQueryDosDeviceW.Addr(), 3, uintptr(unsafe.Pointer(deviceName)), uintptr(unsafe.Pointer(targetPath)), uintptr(max))
n = uint32(r0)
@ -2924,6 +2992,30 @@ func ResumeThread(thread Handle) (ret uint32, err error) {
return
}
func SetCommBreak(handle Handle) (err error) {
r1, _, e1 := syscall.Syscall(procSetCommBreak.Addr(), 1, uintptr(handle), 0, 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func SetCommMask(handle Handle, dwEvtMask uint32) (err error) {
r1, _, e1 := syscall.Syscall(procSetCommMask.Addr(), 2, uintptr(handle), uintptr(dwEvtMask), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func SetCommState(handle Handle, lpDCB *DCB) (err error) {
r1, _, e1 := syscall.Syscall(procSetCommState.Addr(), 2, uintptr(handle), uintptr(unsafe.Pointer(lpDCB)), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func SetCommTimeouts(handle Handle, timeouts *CommTimeouts) (err error) {
r1, _, e1 := syscall.Syscall(procSetCommTimeouts.Addr(), 2, uintptr(handle), uintptr(unsafe.Pointer(timeouts)), 0)
if r1 == 0 {
@ -2989,14 +3081,6 @@ func SetEndOfFile(handle Handle) (err error) {
return
}
func SetFileValidData(handle Handle, validDataLength int64) (err error) {
r1, _, e1 := syscall.Syscall(procSetFileValidData.Addr(), 2, uintptr(handle), uintptr(validDataLength), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func SetEnvironmentVariable(name *uint16, value *uint16) (err error) {
r1, _, e1 := syscall.Syscall(procSetEnvironmentVariableW.Addr(), 2, uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(value)), 0)
if r1 == 0 {
@ -3060,6 +3144,14 @@ func SetFileTime(handle Handle, ctime *Filetime, atime *Filetime, wtime *Filetim
return
}
func SetFileValidData(handle Handle, validDataLength int64) (err error) {
r1, _, e1 := syscall.Syscall(procSetFileValidData.Addr(), 2, uintptr(handle), uintptr(validDataLength), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func SetHandleInformation(handle Handle, mask uint32, flags uint32) (err error) {
r1, _, e1 := syscall.Syscall(procSetHandleInformation.Addr(), 3, uintptr(handle), uintptr(mask), uintptr(flags))
if r1 == 0 {
@ -3145,6 +3237,14 @@ func SetVolumeMountPoint(volumeMountPoint *uint16, volumeName *uint16) (err erro
return
}
func SetupComm(handle Handle, dwInQueue uint32, dwOutQueue uint32) (err error) {
r1, _, e1 := syscall.Syscall(procSetupComm.Addr(), 3, uintptr(handle), uintptr(dwInQueue), uintptr(dwOutQueue))
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func SizeofResource(module Handle, resInfo Handle) (size uint32, err error) {
r0, _, e1 := syscall.Syscall(procSizeofResource.Addr(), 2, uintptr(module), uintptr(resInfo), 0)
size = uint32(r0)
@ -3291,6 +3391,14 @@ func WTSGetActiveConsoleSessionId() (sessionID uint32) {
return
}
func WaitCommEvent(handle Handle, lpEvtMask *uint32, lpOverlapped *Overlapped) (err error) {
r1, _, e1 := syscall.Syscall(procWaitCommEvent.Addr(), 3, uintptr(handle), uintptr(unsafe.Pointer(lpEvtMask)), uintptr(unsafe.Pointer(lpOverlapped)))
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func waitForMultipleObjects(count uint32, handles uintptr, waitAll bool, waitMilliseconds uint32) (event uint32, err error) {
var _p0 uint32
if waitAll {

12
vendor/modules.txt vendored
View File

@ -711,7 +711,7 @@ go.uber.org/zap/internal/pool
go.uber.org/zap/internal/stacktrace
go.uber.org/zap/zapcore
go.uber.org/zap/zapgrpc
# golang.org/x/crypto v0.21.0
# golang.org/x/crypto v0.22.0
## explicit; go 1.18
golang.org/x/crypto/argon2
golang.org/x/crypto/blake2b
@ -737,7 +737,7 @@ golang.org/x/crypto/ssh/internal/bcrypt_pbkdf
golang.org/x/exp/constraints
golang.org/x/exp/maps
golang.org/x/exp/slices
# golang.org/x/net v0.22.0
# golang.org/x/net v0.24.0
## explicit; go 1.18
golang.org/x/net/context
golang.org/x/net/html
@ -759,14 +759,14 @@ golang.org/x/oauth2/internal
# golang.org/x/sync v0.6.0
## explicit; go 1.18
golang.org/x/sync/singleflight
# golang.org/x/sys v0.18.0
# golang.org/x/sys v0.19.0
## explicit; go 1.18
golang.org/x/sys/cpu
golang.org/x/sys/plan9
golang.org/x/sys/unix
golang.org/x/sys/windows
golang.org/x/sys/windows/registry
# golang.org/x/term v0.18.0
# golang.org/x/term v0.19.0
## explicit; go 1.18
golang.org/x/term
# golang.org/x/text v0.14.0
@ -998,7 +998,7 @@ k8s.io/api/scheduling/v1beta1
k8s.io/api/storage/v1
k8s.io/api/storage/v1alpha1
k8s.io/api/storage/v1beta1
# k8s.io/apiextensions-apiserver v0.29.0 => k8s.io/apiextensions-apiserver v0.29.3
# k8s.io/apiextensions-apiserver v0.29.2 => k8s.io/apiextensions-apiserver v0.29.3
## explicit; go 1.21
k8s.io/apiextensions-apiserver/pkg/apis/apiextensions
k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1
@ -1682,7 +1682,7 @@ sigs.k8s.io/apiserver-network-proxy/konnectivity-client/pkg/client
sigs.k8s.io/apiserver-network-proxy/konnectivity-client/pkg/client/metrics
sigs.k8s.io/apiserver-network-proxy/konnectivity-client/pkg/common/metrics
sigs.k8s.io/apiserver-network-proxy/konnectivity-client/proto/client
# sigs.k8s.io/controller-runtime v0.17.2
# sigs.k8s.io/controller-runtime v0.17.3
## explicit; go 1.21
sigs.k8s.io/controller-runtime/pkg/cache
sigs.k8s.io/controller-runtime/pkg/cache/internal

View File

@ -20,6 +20,7 @@ import (
"context"
"fmt"
"net/http"
"sort"
"time"
"golang.org/x/exp/maps"
@ -421,7 +422,12 @@ func defaultOpts(config *rest.Config, opts Options) (Options, error) {
for namespace, cfg := range opts.DefaultNamespaces {
cfg = defaultConfig(cfg, optionDefaultsToConfig(&opts))
if namespace == metav1.NamespaceAll {
cfg.FieldSelector = fields.AndSelectors(appendIfNotNil(namespaceAllSelector(maps.Keys(opts.DefaultNamespaces)), cfg.FieldSelector)...)
cfg.FieldSelector = fields.AndSelectors(
appendIfNotNil(
namespaceAllSelector(maps.Keys(opts.DefaultNamespaces)),
cfg.FieldSelector,
)...,
)
}
opts.DefaultNamespaces[namespace] = cfg
}
@ -435,7 +441,12 @@ func defaultOpts(config *rest.Config, opts Options) (Options, error) {
return opts, fmt.Errorf("type %T is not namespaced, but its ByObject.Namespaces setting is not nil", obj)
}
// Default the namespace-level configs first, because they need to use the undefaulted type-level config.
if isNamespaced && byObject.Namespaces == nil {
byObject.Namespaces = maps.Clone(opts.DefaultNamespaces)
}
// Default the namespace-level configs first, because they need to use the undefaulted type-level config
// to be able to potentially fall through to settings from DefaultNamespaces.
for namespace, config := range byObject.Namespaces {
// 1. Default from the undefaulted type-level config
config = defaultConfig(config, byObjectToConfig(byObject))
@ -461,14 +472,14 @@ func defaultOpts(config *rest.Config, opts Options) (Options, error) {
byObject.Namespaces[namespace] = config
}
defaultedConfig := defaultConfig(byObjectToConfig(byObject), optionDefaultsToConfig(&opts))
byObject.Label = defaultedConfig.LabelSelector
byObject.Field = defaultedConfig.FieldSelector
byObject.Transform = defaultedConfig.Transform
byObject.UnsafeDisableDeepCopy = defaultedConfig.UnsafeDisableDeepCopy
if isNamespaced && byObject.Namespaces == nil {
byObject.Namespaces = opts.DefaultNamespaces
// Only default ByObject iself if it isn't namespaced or has no namespaces configured, as only
// then any of this will be honored.
if !isNamespaced || len(byObject.Namespaces) == 0 {
defaultedConfig := defaultConfig(byObjectToConfig(byObject), optionDefaultsToConfig(&opts))
byObject.Label = defaultedConfig.LabelSelector
byObject.Field = defaultedConfig.FieldSelector
byObject.Transform = defaultedConfig.Transform
byObject.UnsafeDisableDeepCopy = defaultedConfig.UnsafeDisableDeepCopy
}
opts.ByObject[obj] = byObject
@ -498,20 +509,21 @@ func defaultConfig(toDefault, defaultFrom Config) Config {
return toDefault
}
func namespaceAllSelector(namespaces []string) fields.Selector {
func namespaceAllSelector(namespaces []string) []fields.Selector {
selectors := make([]fields.Selector, 0, len(namespaces)-1)
sort.Strings(namespaces)
for _, namespace := range namespaces {
if namespace != metav1.NamespaceAll {
selectors = append(selectors, fields.OneTermNotEqualSelector("metadata.namespace", namespace))
}
}
return fields.AndSelectors(selectors...)
return selectors
}
func appendIfNotNil[T comparable](a, b T) []T {
func appendIfNotNil[T comparable](a []T, b T) []T {
if b != *new(T) {
return []T{a, b}
return append(a, b)
}
return []T{a}
return a
}

Some files were not shown because too many files have changed in this diff Show More