Merge pull request #272 from red-hat-storage/sync_us--devel

Syncing latest changes from upstream devel for ceph-csi
This commit is contained in:
openshift-merge-bot[bot] 2024-03-13 08:09:24 +00:00 committed by GitHub
commit d1cc48391b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
51 changed files with 770 additions and 496 deletions

View File

@ -166,6 +166,10 @@ cephcsi: check-env
e2e.test: check-env e2e.test: check-env
go test $(GO_TAGS) -mod=vendor -c ./e2e go test $(GO_TAGS) -mod=vendor -c ./e2e
.PHONY: rbd-group-snapshot
rbd-group-snapshot:
go build -o _output/rbd-group-snapshot ./tools/rbd-group-snapshot
# #
# Update the generated deploy/ files when the template changed. This requires # Update the generated deploy/ files when the template changed. This requires
# running 'go mod vendor' so update the API files under the vendor/ directory. # running 'go mod vendor' so update the API files under the vendor/ directory.

View File

@ -4,13 +4,13 @@ go 1.18
require ( require (
github.com/google/go-github v17.0.0+incompatible github.com/google/go-github v17.0.0+incompatible
golang.org/x/oauth2 v0.17.0 golang.org/x/oauth2 v0.18.0
) )
require ( require (
github.com/golang/protobuf v1.5.3 // indirect github.com/golang/protobuf v1.5.3 // indirect
github.com/google/go-querystring v1.1.0 // indirect github.com/google/go-querystring v1.1.0 // indirect
golang.org/x/net v0.21.0 // indirect golang.org/x/net v0.22.0 // indirect
google.golang.org/appengine v1.6.7 // indirect google.golang.org/appengine v1.6.7 // indirect
google.golang.org/protobuf v1.31.0 // indirect google.golang.org/protobuf v1.31.0 // indirect
) )

View File

@ -11,10 +11,10 @@ github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/oauth2 v0.17.0 h1:6m3ZPmLEFdVxKKWnKq4VqZ60gutO35zm+zrAHVmHyDQ= golang.org/x/oauth2 v0.18.0 h1:09qnuIAgzdx1XplqJvW6CQqMCtGZykZWcXzPMPUusvI=
golang.org/x/oauth2 v0.17.0/go.mod h1:OzPDGQiuQMguemayvdylqddI7qcD9lnSDb+1FiwQ5HA= golang.org/x/oauth2 v0.18.0/go.mod h1:Wf7knwG0MPoWIMMBgFlEaSUDaKskp0dCfrlJRJXbBi8=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=

View File

@ -7,10 +7,10 @@ github.com/google/go-github/github
# github.com/google/go-querystring v1.1.0 # github.com/google/go-querystring v1.1.0
## explicit; go 1.10 ## explicit; go 1.10
github.com/google/go-querystring/query github.com/google/go-querystring/query
# golang.org/x/net v0.21.0 # golang.org/x/net v0.22.0
## explicit; go 1.18 ## explicit; go 1.18
golang.org/x/net/context golang.org/x/net/context
# golang.org/x/oauth2 v0.17.0 # golang.org/x/oauth2 v0.18.0
## explicit; go 1.18 ## explicit; go 1.18
golang.org/x/oauth2 golang.org/x/oauth2
golang.org/x/oauth2/internal golang.org/x/oauth2/internal

View File

@ -7,7 +7,7 @@ toolchain go1.21.5
require ( require (
github.com/ghodss/yaml v1.0.0 github.com/ghodss/yaml v1.0.0
github.com/openshift/api v0.0.0-20240115183315-0793e918179d github.com/openshift/api v0.0.0-20240115183315-0793e918179d
github.com/stretchr/testify v1.8.4 github.com/stretchr/testify v1.9.0
k8s.io/api v0.29.2 k8s.io/api v0.29.2
) )

View File

@ -36,8 +36,8 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=

View File

@ -28,6 +28,8 @@ var (
uint32Type = reflect.TypeOf(uint32(1)) uint32Type = reflect.TypeOf(uint32(1))
uint64Type = reflect.TypeOf(uint64(1)) uint64Type = reflect.TypeOf(uint64(1))
uintptrType = reflect.TypeOf(uintptr(1))
float32Type = reflect.TypeOf(float32(1)) float32Type = reflect.TypeOf(float32(1))
float64Type = reflect.TypeOf(float64(1)) float64Type = reflect.TypeOf(float64(1))
@ -308,11 +310,11 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
case reflect.Struct: case reflect.Struct:
{ {
// All structs enter here. We're not interested in most types. // All structs enter here. We're not interested in most types.
if !canConvert(obj1Value, timeType) { if !obj1Value.CanConvert(timeType) {
break break
} }
// time.Time can compared! // time.Time can be compared!
timeObj1, ok := obj1.(time.Time) timeObj1, ok := obj1.(time.Time)
if !ok { if !ok {
timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time) timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time)
@ -328,7 +330,7 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
case reflect.Slice: case reflect.Slice:
{ {
// We only care about the []byte type. // We only care about the []byte type.
if !canConvert(obj1Value, bytesType) { if !obj1Value.CanConvert(bytesType) {
break break
} }
@ -345,6 +347,26 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
return CompareType(bytes.Compare(bytesObj1, bytesObj2)), true return CompareType(bytes.Compare(bytesObj1, bytesObj2)), true
} }
case reflect.Uintptr:
{
uintptrObj1, ok := obj1.(uintptr)
if !ok {
uintptrObj1 = obj1Value.Convert(uintptrType).Interface().(uintptr)
}
uintptrObj2, ok := obj2.(uintptr)
if !ok {
uintptrObj2 = obj2Value.Convert(uintptrType).Interface().(uintptr)
}
if uintptrObj1 > uintptrObj2 {
return compareGreater, true
}
if uintptrObj1 == uintptrObj2 {
return compareEqual, true
}
if uintptrObj1 < uintptrObj2 {
return compareLess, true
}
}
} }
return compareEqual, false return compareEqual, false

View File

@ -1,16 +0,0 @@
//go:build go1.17
// +build go1.17
// TODO: once support for Go 1.16 is dropped, this file can be
// merged/removed with assertion_compare_go1.17_test.go and
// assertion_compare_legacy.go
package assert
import "reflect"
// Wrapper around reflect.Value.CanConvert, for compatibility
// reasons.
func canConvert(value reflect.Value, to reflect.Type) bool {
return value.CanConvert(to)
}

View File

@ -1,16 +0,0 @@
//go:build !go1.17
// +build !go1.17
// TODO: once support for Go 1.16 is dropped, this file can be
// merged/removed with assertion_compare_go1.17_test.go and
// assertion_compare_can_convert.go
package assert
import "reflect"
// Older versions of Go does not have the reflect.Value.CanConvert
// method.
func canConvert(value reflect.Value, to reflect.Type) bool {
return false
}

View File

@ -1,7 +1,4 @@
/* // Code generated with github.com/stretchr/testify/_codegen; DO NOT EDIT.
* CODE GENERATED AUTOMATICALLY WITH github.com/stretchr/testify/_codegen
* THIS FILE MUST NOT BE EDITED BY HAND
*/
package assert package assert
@ -107,7 +104,7 @@ func EqualExportedValuesf(t TestingT, expected interface{}, actual interface{},
return EqualExportedValues(t, expected, actual, append([]interface{}{msg}, args...)...) return EqualExportedValues(t, expected, actual, append([]interface{}{msg}, args...)...)
} }
// EqualValuesf asserts that two objects are equal or convertable to the same types // EqualValuesf asserts that two objects are equal or convertible to the same types
// and equal. // and equal.
// //
// assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted") // assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted")
@ -616,6 +613,16 @@ func NotErrorIsf(t TestingT, err error, target error, msg string, args ...interf
return NotErrorIs(t, err, target, append([]interface{}{msg}, args...)...) return NotErrorIs(t, err, target, append([]interface{}{msg}, args...)...)
} }
// NotImplementsf asserts that an object does not implement the specified interface.
//
// assert.NotImplementsf(t, (*MyInterface)(nil), new(MyObject), "error message %s", "formatted")
func NotImplementsf(t TestingT, interfaceObject interface{}, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return NotImplements(t, interfaceObject, object, append([]interface{}{msg}, args...)...)
}
// NotNilf asserts that the specified object is not nil. // NotNilf asserts that the specified object is not nil.
// //
// assert.NotNilf(t, err, "error message %s", "formatted") // assert.NotNilf(t, err, "error message %s", "formatted")
@ -660,10 +667,12 @@ func NotSamef(t TestingT, expected interface{}, actual interface{}, msg string,
return NotSame(t, expected, actual, append([]interface{}{msg}, args...)...) return NotSame(t, expected, actual, append([]interface{}{msg}, args...)...)
} }
// NotSubsetf asserts that the specified list(array, slice...) contains not all // NotSubsetf asserts that the specified list(array, slice...) or map does NOT
// elements given in the specified subset(array, slice...). // contain all elements given in the specified subset list(array, slice...) or
// map.
// //
// assert.NotSubsetf(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]", "error message %s", "formatted") // assert.NotSubsetf(t, [1, 3, 4], [1, 2], "error message %s", "formatted")
// assert.NotSubsetf(t, {"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted")
func NotSubsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool { func NotSubsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
@ -747,10 +756,11 @@ func Samef(t TestingT, expected interface{}, actual interface{}, msg string, arg
return Same(t, expected, actual, append([]interface{}{msg}, args...)...) return Same(t, expected, actual, append([]interface{}{msg}, args...)...)
} }
// Subsetf asserts that the specified list(array, slice...) contains all // Subsetf asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset(array, slice...). // elements given in the specified subset list(array, slice...) or map.
// //
// assert.Subsetf(t, [1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]", "error message %s", "formatted") // assert.Subsetf(t, [1, 2, 3], [1, 2], "error message %s", "formatted")
// assert.Subsetf(t, {"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted")
func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool { func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()

View File

@ -1,7 +1,4 @@
/* // Code generated with github.com/stretchr/testify/_codegen; DO NOT EDIT.
* CODE GENERATED AUTOMATICALLY WITH github.com/stretchr/testify/_codegen
* THIS FILE MUST NOT BE EDITED BY HAND
*/
package assert package assert
@ -189,7 +186,7 @@ func (a *Assertions) EqualExportedValuesf(expected interface{}, actual interface
return EqualExportedValuesf(a.t, expected, actual, msg, args...) return EqualExportedValuesf(a.t, expected, actual, msg, args...)
} }
// EqualValues asserts that two objects are equal or convertable to the same types // EqualValues asserts that two objects are equal or convertible to the same types
// and equal. // and equal.
// //
// a.EqualValues(uint32(123), int32(123)) // a.EqualValues(uint32(123), int32(123))
@ -200,7 +197,7 @@ func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAn
return EqualValues(a.t, expected, actual, msgAndArgs...) return EqualValues(a.t, expected, actual, msgAndArgs...)
} }
// EqualValuesf asserts that two objects are equal or convertable to the same types // EqualValuesf asserts that two objects are equal or convertible to the same types
// and equal. // and equal.
// //
// a.EqualValuesf(uint32(123), int32(123), "error message %s", "formatted") // a.EqualValuesf(uint32(123), int32(123), "error message %s", "formatted")
@ -1221,6 +1218,26 @@ func (a *Assertions) NotErrorIsf(err error, target error, msg string, args ...in
return NotErrorIsf(a.t, err, target, msg, args...) return NotErrorIsf(a.t, err, target, msg, args...)
} }
// NotImplements asserts that an object does not implement the specified interface.
//
// a.NotImplements((*MyInterface)(nil), new(MyObject))
func (a *Assertions) NotImplements(interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotImplements(a.t, interfaceObject, object, msgAndArgs...)
}
// NotImplementsf asserts that an object does not implement the specified interface.
//
// a.NotImplementsf((*MyInterface)(nil), new(MyObject), "error message %s", "formatted")
func (a *Assertions) NotImplementsf(interfaceObject interface{}, object interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotImplementsf(a.t, interfaceObject, object, msg, args...)
}
// NotNil asserts that the specified object is not nil. // NotNil asserts that the specified object is not nil.
// //
// a.NotNil(err) // a.NotNil(err)
@ -1309,10 +1326,12 @@ func (a *Assertions) NotSamef(expected interface{}, actual interface{}, msg stri
return NotSamef(a.t, expected, actual, msg, args...) return NotSamef(a.t, expected, actual, msg, args...)
} }
// NotSubset asserts that the specified list(array, slice...) contains not all // NotSubset asserts that the specified list(array, slice...) or map does NOT
// elements given in the specified subset(array, slice...). // contain all elements given in the specified subset list(array, slice...) or
// map.
// //
// a.NotSubset([1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]") // a.NotSubset([1, 3, 4], [1, 2])
// a.NotSubset({"x": 1, "y": 2}, {"z": 3})
func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -1320,10 +1339,12 @@ func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs
return NotSubset(a.t, list, subset, msgAndArgs...) return NotSubset(a.t, list, subset, msgAndArgs...)
} }
// NotSubsetf asserts that the specified list(array, slice...) contains not all // NotSubsetf asserts that the specified list(array, slice...) or map does NOT
// elements given in the specified subset(array, slice...). // contain all elements given in the specified subset list(array, slice...) or
// map.
// //
// a.NotSubsetf([1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]", "error message %s", "formatted") // a.NotSubsetf([1, 3, 4], [1, 2], "error message %s", "formatted")
// a.NotSubsetf({"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted")
func (a *Assertions) NotSubsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool { func (a *Assertions) NotSubsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -1483,10 +1504,11 @@ func (a *Assertions) Samef(expected interface{}, actual interface{}, msg string,
return Samef(a.t, expected, actual, msg, args...) return Samef(a.t, expected, actual, msg, args...)
} }
// Subset asserts that the specified list(array, slice...) contains all // Subset asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset(array, slice...). // elements given in the specified subset list(array, slice...) or map.
// //
// a.Subset([1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]") // a.Subset([1, 2, 3], [1, 2])
// a.Subset({"x": 1, "y": 2}, {"x": 1})
func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -1494,10 +1516,11 @@ func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...
return Subset(a.t, list, subset, msgAndArgs...) return Subset(a.t, list, subset, msgAndArgs...)
} }
// Subsetf asserts that the specified list(array, slice...) contains all // Subsetf asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset(array, slice...). // elements given in the specified subset list(array, slice...) or map.
// //
// a.Subsetf([1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]", "error message %s", "formatted") // a.Subsetf([1, 2, 3], [1, 2], "error message %s", "formatted")
// a.Subsetf({"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted")
func (a *Assertions) Subsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool { func (a *Assertions) Subsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()

View File

@ -19,7 +19,7 @@ import (
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/pmezard/go-difflib/difflib" "github.com/pmezard/go-difflib/difflib"
yaml "gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
//go:generate sh -c "cd ../_codegen && go build && cd - && ../_codegen/_codegen -output-package=assert -template=assertion_format.go.tmpl" //go:generate sh -c "cd ../_codegen && go build && cd - && ../_codegen/_codegen -output-package=assert -template=assertion_format.go.tmpl"
@ -110,7 +110,12 @@ func copyExportedFields(expected interface{}) interface{} {
return result.Interface() return result.Interface()
case reflect.Array, reflect.Slice: case reflect.Array, reflect.Slice:
result := reflect.MakeSlice(expectedType, expectedValue.Len(), expectedValue.Len()) var result reflect.Value
if expectedKind == reflect.Array {
result = reflect.New(reflect.ArrayOf(expectedValue.Len(), expectedType.Elem())).Elem()
} else {
result = reflect.MakeSlice(expectedType, expectedValue.Len(), expectedValue.Len())
}
for i := 0; i < expectedValue.Len(); i++ { for i := 0; i < expectedValue.Len(); i++ {
index := expectedValue.Index(i) index := expectedValue.Index(i)
if isNil(index) { if isNil(index) {
@ -140,6 +145,8 @@ func copyExportedFields(expected interface{}) interface{} {
// structures. // structures.
// //
// This function does no assertion of any kind. // This function does no assertion of any kind.
//
// Deprecated: Use [EqualExportedValues] instead.
func ObjectsExportedFieldsAreEqual(expected, actual interface{}) bool { func ObjectsExportedFieldsAreEqual(expected, actual interface{}) bool {
expectedCleaned := copyExportedFields(expected) expectedCleaned := copyExportedFields(expected)
actualCleaned := copyExportedFields(actual) actualCleaned := copyExportedFields(actual)
@ -153,17 +160,40 @@ func ObjectsAreEqualValues(expected, actual interface{}) bool {
return true return true
} }
actualType := reflect.TypeOf(actual) expectedValue := reflect.ValueOf(expected)
if actualType == nil { actualValue := reflect.ValueOf(actual)
if !expectedValue.IsValid() || !actualValue.IsValid() {
return false return false
} }
expectedValue := reflect.ValueOf(expected)
if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) { expectedType := expectedValue.Type()
// Attempt comparison after type conversion actualType := actualValue.Type()
return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual) if !expectedType.ConvertibleTo(actualType) {
return false
} }
return false if !isNumericType(expectedType) || !isNumericType(actualType) {
// Attempt comparison after type conversion
return reflect.DeepEqual(
expectedValue.Convert(actualType).Interface(), actual,
)
}
// If BOTH values are numeric, there are chances of false positives due
// to overflow or underflow. So, we need to make sure to always convert
// the smaller type to a larger type before comparing.
if expectedType.Size() >= actualType.Size() {
return actualValue.Convert(expectedType).Interface() == expected
}
return expectedValue.Convert(actualType).Interface() == actual
}
// isNumericType returns true if the type is one of:
// int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64,
// float32, float64, complex64, complex128
func isNumericType(t reflect.Type) bool {
return t.Kind() >= reflect.Int && t.Kind() <= reflect.Complex128
} }
/* CallerInfo is necessary because the assert functions use the testing object /* CallerInfo is necessary because the assert functions use the testing object
@ -266,7 +296,7 @@ func messageFromMsgAndArgs(msgAndArgs ...interface{}) string {
// Aligns the provided message so that all lines after the first line start at the same location as the first line. // Aligns the provided message so that all lines after the first line start at the same location as the first line.
// Assumes that the first line starts at the correct location (after carriage return, tab, label, spacer and tab). // Assumes that the first line starts at the correct location (after carriage return, tab, label, spacer and tab).
// The longestLabelLen parameter specifies the length of the longest label in the output (required becaues this is the // The longestLabelLen parameter specifies the length of the longest label in the output (required because this is the
// basis on which the alignment occurs). // basis on which the alignment occurs).
func indentMessageLines(message string, longestLabelLen int) string { func indentMessageLines(message string, longestLabelLen int) string {
outBuf := new(bytes.Buffer) outBuf := new(bytes.Buffer)
@ -382,6 +412,25 @@ func Implements(t TestingT, interfaceObject interface{}, object interface{}, msg
return true return true
} }
// NotImplements asserts that an object does not implement the specified interface.
//
// assert.NotImplements(t, (*MyInterface)(nil), new(MyObject))
func NotImplements(t TestingT, interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
interfaceType := reflect.TypeOf(interfaceObject).Elem()
if object == nil {
return Fail(t, fmt.Sprintf("Cannot check if nil does not implement %v", interfaceType), msgAndArgs...)
}
if reflect.TypeOf(object).Implements(interfaceType) {
return Fail(t, fmt.Sprintf("%T implements %v", object, interfaceType), msgAndArgs...)
}
return true
}
// IsType asserts that the specified objects are of the same type. // IsType asserts that the specified objects are of the same type.
func IsType(t TestingT, expectedType interface{}, object interface{}, msgAndArgs ...interface{}) bool { func IsType(t TestingT, expectedType interface{}, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
@ -496,7 +545,7 @@ func samePointers(first, second interface{}) bool {
// representations appropriate to be presented to the user. // representations appropriate to be presented to the user.
// //
// If the values are not of like type, the returned strings will be prefixed // If the values are not of like type, the returned strings will be prefixed
// with the type name, and the value will be enclosed in parenthesis similar // with the type name, and the value will be enclosed in parentheses similar
// to a type conversion in the Go grammar. // to a type conversion in the Go grammar.
func formatUnequalValues(expected, actual interface{}) (e string, a string) { func formatUnequalValues(expected, actual interface{}) (e string, a string) {
if reflect.TypeOf(expected) != reflect.TypeOf(actual) { if reflect.TypeOf(expected) != reflect.TypeOf(actual) {
@ -523,7 +572,7 @@ func truncatingFormat(data interface{}) string {
return value return value
} }
// EqualValues asserts that two objects are equal or convertable to the same types // EqualValues asserts that two objects are equal or convertible to the same types
// and equal. // and equal.
// //
// assert.EqualValues(t, uint32(123), int32(123)) // assert.EqualValues(t, uint32(123), int32(123))
@ -566,12 +615,19 @@ func EqualExportedValues(t TestingT, expected, actual interface{}, msgAndArgs ..
return Fail(t, fmt.Sprintf("Types expected to match exactly\n\t%v != %v", aType, bType), msgAndArgs...) return Fail(t, fmt.Sprintf("Types expected to match exactly\n\t%v != %v", aType, bType), msgAndArgs...)
} }
if aType.Kind() == reflect.Ptr {
aType = aType.Elem()
}
if bType.Kind() == reflect.Ptr {
bType = bType.Elem()
}
if aType.Kind() != reflect.Struct { if aType.Kind() != reflect.Struct {
return Fail(t, fmt.Sprintf("Types expected to both be struct \n\t%v != %v", aType.Kind(), reflect.Struct), msgAndArgs...) return Fail(t, fmt.Sprintf("Types expected to both be struct or pointer to struct \n\t%v != %v", aType.Kind(), reflect.Struct), msgAndArgs...)
} }
if bType.Kind() != reflect.Struct { if bType.Kind() != reflect.Struct {
return Fail(t, fmt.Sprintf("Types expected to both be struct \n\t%v != %v", bType.Kind(), reflect.Struct), msgAndArgs...) return Fail(t, fmt.Sprintf("Types expected to both be struct or pointer to struct \n\t%v != %v", bType.Kind(), reflect.Struct), msgAndArgs...)
} }
expected = copyExportedFields(expected) expected = copyExportedFields(expected)
@ -620,17 +676,6 @@ func NotNil(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return Fail(t, "Expected value not to be nil.", msgAndArgs...) return Fail(t, "Expected value not to be nil.", msgAndArgs...)
} }
// containsKind checks if a specified kind in the slice of kinds.
func containsKind(kinds []reflect.Kind, kind reflect.Kind) bool {
for i := 0; i < len(kinds); i++ {
if kind == kinds[i] {
return true
}
}
return false
}
// isNil checks if a specified object is nil or not, without Failing. // isNil checks if a specified object is nil or not, without Failing.
func isNil(object interface{}) bool { func isNil(object interface{}) bool {
if object == nil { if object == nil {
@ -638,16 +683,13 @@ func isNil(object interface{}) bool {
} }
value := reflect.ValueOf(object) value := reflect.ValueOf(object)
kind := value.Kind() switch value.Kind() {
isNilableKind := containsKind( case
[]reflect.Kind{ reflect.Chan, reflect.Func,
reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
reflect.Ptr, reflect.Slice, reflect.UnsafePointer},
kind)
if isNilableKind && value.IsNil() { return value.IsNil()
return true
} }
return false return false
@ -731,16 +773,14 @@ func NotEmpty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
} }
// getLen try to get length of object. // getLen tries to get the length of an object.
// return (false, 0) if impossible. // It returns (0, false) if impossible.
func getLen(x interface{}) (ok bool, length int) { func getLen(x interface{}) (length int, ok bool) {
v := reflect.ValueOf(x) v := reflect.ValueOf(x)
defer func() { defer func() {
if e := recover(); e != nil { ok = recover() == nil
ok = false
}
}() }()
return true, v.Len() return v.Len(), true
} }
// Len asserts that the specified object has specific length. // Len asserts that the specified object has specific length.
@ -751,13 +791,13 @@ func Len(t TestingT, object interface{}, length int, msgAndArgs ...interface{})
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
} }
ok, l := getLen(object) l, ok := getLen(object)
if !ok { if !ok {
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", object), msgAndArgs...) return Fail(t, fmt.Sprintf("\"%v\" could not be applied builtin len()", object), msgAndArgs...)
} }
if l != length { if l != length {
return Fail(t, fmt.Sprintf("\"%s\" should have %d item(s), but has %d", object, length, l), msgAndArgs...) return Fail(t, fmt.Sprintf("\"%v\" should have %d item(s), but has %d", object, length, l), msgAndArgs...)
} }
return true return true
} }
@ -919,10 +959,11 @@ func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{})
} }
// Subset asserts that the specified list(array, slice...) contains all // Subset asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset(array, slice...). // elements given in the specified subset list(array, slice...) or map.
// //
// assert.Subset(t, [1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]") // assert.Subset(t, [1, 2, 3], [1, 2])
// assert.Subset(t, {"x": 1, "y": 2}, {"x": 1})
func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) { func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
@ -975,10 +1016,12 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok
return true return true
} }
// NotSubset asserts that the specified list(array, slice...) contains not all // NotSubset asserts that the specified list(array, slice...) or map does NOT
// elements given in the specified subset(array, slice...). // contain all elements given in the specified subset list(array, slice...) or
// map.
// //
// assert.NotSubset(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]") // assert.NotSubset(t, [1, 3, 4], [1, 2])
// assert.NotSubset(t, {"x": 1, "y": 2}, {"z": 3})
func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) { func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
@ -1439,7 +1482,7 @@ func InEpsilon(t TestingT, expected, actual interface{}, epsilon float64, msgAnd
h.Helper() h.Helper()
} }
if math.IsNaN(epsilon) { if math.IsNaN(epsilon) {
return Fail(t, "epsilon must not be NaN") return Fail(t, "epsilon must not be NaN", msgAndArgs...)
} }
actualEpsilon, err := calcRelativeError(expected, actual) actualEpsilon, err := calcRelativeError(expected, actual)
if err != nil { if err != nil {
@ -1458,19 +1501,26 @@ func InEpsilonSlice(t TestingT, expected, actual interface{}, epsilon float64, m
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
} }
if expected == nil || actual == nil ||
reflect.TypeOf(actual).Kind() != reflect.Slice || if expected == nil || actual == nil {
reflect.TypeOf(expected).Kind() != reflect.Slice {
return Fail(t, "Parameters must be slice", msgAndArgs...) return Fail(t, "Parameters must be slice", msgAndArgs...)
} }
actualSlice := reflect.ValueOf(actual)
expectedSlice := reflect.ValueOf(expected) expectedSlice := reflect.ValueOf(expected)
actualSlice := reflect.ValueOf(actual)
for i := 0; i < actualSlice.Len(); i++ { if expectedSlice.Type().Kind() != reflect.Slice {
result := InEpsilon(t, actualSlice.Index(i).Interface(), expectedSlice.Index(i).Interface(), epsilon) return Fail(t, "Expected value must be slice", msgAndArgs...)
if !result { }
return result
expectedLen := expectedSlice.Len()
if !IsType(t, expected, actual) || !Len(t, actual, expectedLen) {
return false
}
for i := 0; i < expectedLen; i++ {
if !InEpsilon(t, expectedSlice.Index(i).Interface(), actualSlice.Index(i).Interface(), epsilon, "at index %d", i) {
return false
} }
} }
@ -1870,23 +1920,18 @@ func (c *CollectT) Errorf(format string, args ...interface{}) {
} }
// FailNow panics. // FailNow panics.
func (c *CollectT) FailNow() { func (*CollectT) FailNow() {
panic("Assertion failed") panic("Assertion failed")
} }
// Reset clears the collected errors. // Deprecated: That was a method for internal usage that should not have been published. Now just panics.
func (c *CollectT) Reset() { func (*CollectT) Reset() {
c.errors = nil panic("Reset() is deprecated")
} }
// Copy copies the collected errors to the supplied t. // Deprecated: That was a method for internal usage that should not have been published. Now just panics.
func (c *CollectT) Copy(t TestingT) { func (*CollectT) Copy(TestingT) {
if tt, ok := t.(tHelper); ok { panic("Copy() is deprecated")
tt.Helper()
}
for _, err := range c.errors {
t.Errorf("%v", err)
}
} }
// EventuallyWithT asserts that given condition will be met in waitFor time, // EventuallyWithT asserts that given condition will be met in waitFor time,
@ -1912,8 +1957,8 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time
h.Helper() h.Helper()
} }
collect := new(CollectT) var lastFinishedTickErrs []error
ch := make(chan bool, 1) ch := make(chan []error, 1)
timer := time.NewTimer(waitFor) timer := time.NewTimer(waitFor)
defer timer.Stop() defer timer.Stop()
@ -1924,19 +1969,25 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time
for tick := ticker.C; ; { for tick := ticker.C; ; {
select { select {
case <-timer.C: case <-timer.C:
collect.Copy(t) for _, err := range lastFinishedTickErrs {
t.Errorf("%v", err)
}
return Fail(t, "Condition never satisfied", msgAndArgs...) return Fail(t, "Condition never satisfied", msgAndArgs...)
case <-tick: case <-tick:
tick = nil tick = nil
collect.Reset()
go func() { go func() {
collect := new(CollectT)
defer func() {
ch <- collect.errors
}()
condition(collect) condition(collect)
ch <- len(collect.errors) == 0
}() }()
case v := <-ch: case errs := <-ch:
if v { if len(errs) == 0 {
return true return true
} }
// Keep the errors from the last ended condition, so that they can be copied to t if timeout is reached.
lastFinishedTickErrs = errs
tick = ticker.C tick = ticker.C
} }
} }

View File

@ -12,7 +12,7 @@ import (
// an error if building a new request fails. // an error if building a new request fails.
func httpCode(handler http.HandlerFunc, method, url string, values url.Values) (int, error) { func httpCode(handler http.HandlerFunc, method, url string, values url.Values) (int, error) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
req, err := http.NewRequest(method, url, nil) req, err := http.NewRequest(method, url, http.NoBody)
if err != nil { if err != nil {
return -1, err return -1, err
} }
@ -32,12 +32,12 @@ func HTTPSuccess(t TestingT, handler http.HandlerFunc, method, url string, value
} }
code, err := httpCode(handler, method, url, values) code, err := httpCode(handler, method, url, values)
if err != nil { if err != nil {
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err)) Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err), msgAndArgs...)
} }
isSuccessCode := code >= http.StatusOK && code <= http.StatusPartialContent isSuccessCode := code >= http.StatusOK && code <= http.StatusPartialContent
if !isSuccessCode { if !isSuccessCode {
Fail(t, fmt.Sprintf("Expected HTTP success status code for %q but received %d", url+"?"+values.Encode(), code)) Fail(t, fmt.Sprintf("Expected HTTP success status code for %q but received %d", url+"?"+values.Encode(), code), msgAndArgs...)
} }
return isSuccessCode return isSuccessCode
@ -54,12 +54,12 @@ func HTTPRedirect(t TestingT, handler http.HandlerFunc, method, url string, valu
} }
code, err := httpCode(handler, method, url, values) code, err := httpCode(handler, method, url, values)
if err != nil { if err != nil {
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err)) Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err), msgAndArgs...)
} }
isRedirectCode := code >= http.StatusMultipleChoices && code <= http.StatusTemporaryRedirect isRedirectCode := code >= http.StatusMultipleChoices && code <= http.StatusTemporaryRedirect
if !isRedirectCode { if !isRedirectCode {
Fail(t, fmt.Sprintf("Expected HTTP redirect status code for %q but received %d", url+"?"+values.Encode(), code)) Fail(t, fmt.Sprintf("Expected HTTP redirect status code for %q but received %d", url+"?"+values.Encode(), code), msgAndArgs...)
} }
return isRedirectCode return isRedirectCode
@ -76,12 +76,12 @@ func HTTPError(t TestingT, handler http.HandlerFunc, method, url string, values
} }
code, err := httpCode(handler, method, url, values) code, err := httpCode(handler, method, url, values)
if err != nil { if err != nil {
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err)) Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err), msgAndArgs...)
} }
isErrorCode := code >= http.StatusBadRequest isErrorCode := code >= http.StatusBadRequest
if !isErrorCode { if !isErrorCode {
Fail(t, fmt.Sprintf("Expected HTTP error status code for %q but received %d", url+"?"+values.Encode(), code)) Fail(t, fmt.Sprintf("Expected HTTP error status code for %q but received %d", url+"?"+values.Encode(), code), msgAndArgs...)
} }
return isErrorCode return isErrorCode
@ -98,12 +98,12 @@ func HTTPStatusCode(t TestingT, handler http.HandlerFunc, method, url string, va
} }
code, err := httpCode(handler, method, url, values) code, err := httpCode(handler, method, url, values)
if err != nil { if err != nil {
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err)) Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err), msgAndArgs...)
} }
successful := code == statuscode successful := code == statuscode
if !successful { if !successful {
Fail(t, fmt.Sprintf("Expected HTTP status code %d for %q but received %d", statuscode, url+"?"+values.Encode(), code)) Fail(t, fmt.Sprintf("Expected HTTP status code %d for %q but received %d", statuscode, url+"?"+values.Encode(), code), msgAndArgs...)
} }
return successful return successful
@ -113,7 +113,10 @@ func HTTPStatusCode(t TestingT, handler http.HandlerFunc, method, url string, va
// empty string if building a new request fails. // empty string if building a new request fails.
func HTTPBody(handler http.HandlerFunc, method, url string, values url.Values) string { func HTTPBody(handler http.HandlerFunc, method, url string, values url.Values) string {
w := httptest.NewRecorder() w := httptest.NewRecorder()
req, err := http.NewRequest(method, url+"?"+values.Encode(), nil) if len(values) > 0 {
url += "?" + values.Encode()
}
req, err := http.NewRequest(method, url, http.NoBody)
if err != nil { if err != nil {
return "" return ""
} }
@ -135,7 +138,7 @@ func HTTPBodyContains(t TestingT, handler http.HandlerFunc, method, url string,
contains := strings.Contains(body, fmt.Sprint(str)) contains := strings.Contains(body, fmt.Sprint(str))
if !contains { if !contains {
Fail(t, fmt.Sprintf("Expected response body for \"%s\" to contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body)) Fail(t, fmt.Sprintf("Expected response body for \"%s\" to contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body), msgAndArgs...)
} }
return contains return contains
@ -155,7 +158,7 @@ func HTTPBodyNotContains(t TestingT, handler http.HandlerFunc, method, url strin
contains := strings.Contains(body, fmt.Sprint(str)) contains := strings.Contains(body, fmt.Sprint(str))
if contains { if contains {
Fail(t, fmt.Sprintf("Expected response body for \"%s\" to NOT contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body)) Fail(t, fmt.Sprintf("Expected response body for \"%s\" to NOT contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body), msgAndArgs...)
} }
return !contains return !contains

View File

@ -1,7 +1,4 @@
/* // Code generated with github.com/stretchr/testify/_codegen; DO NOT EDIT.
* CODE GENERATED AUTOMATICALLY WITH github.com/stretchr/testify/_codegen
* THIS FILE MUST NOT BE EDITED BY HAND
*/
package require package require
@ -235,7 +232,7 @@ func EqualExportedValuesf(t TestingT, expected interface{}, actual interface{},
t.FailNow() t.FailNow()
} }
// EqualValues asserts that two objects are equal or convertable to the same types // EqualValues asserts that two objects are equal or convertible to the same types
// and equal. // and equal.
// //
// assert.EqualValues(t, uint32(123), int32(123)) // assert.EqualValues(t, uint32(123), int32(123))
@ -249,7 +246,7 @@ func EqualValues(t TestingT, expected interface{}, actual interface{}, msgAndArg
t.FailNow() t.FailNow()
} }
// EqualValuesf asserts that two objects are equal or convertable to the same types // EqualValuesf asserts that two objects are equal or convertible to the same types
// and equal. // and equal.
// //
// assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted") // assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted")
@ -1546,6 +1543,32 @@ func NotErrorIsf(t TestingT, err error, target error, msg string, args ...interf
t.FailNow() t.FailNow()
} }
// NotImplements asserts that an object does not implement the specified interface.
//
// assert.NotImplements(t, (*MyInterface)(nil), new(MyObject))
func NotImplements(t TestingT, interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if assert.NotImplements(t, interfaceObject, object, msgAndArgs...) {
return
}
t.FailNow()
}
// NotImplementsf asserts that an object does not implement the specified interface.
//
// assert.NotImplementsf(t, (*MyInterface)(nil), new(MyObject), "error message %s", "formatted")
func NotImplementsf(t TestingT, interfaceObject interface{}, object interface{}, msg string, args ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if assert.NotImplementsf(t, interfaceObject, object, msg, args...) {
return
}
t.FailNow()
}
// NotNil asserts that the specified object is not nil. // NotNil asserts that the specified object is not nil.
// //
// assert.NotNil(t, err) // assert.NotNil(t, err)
@ -1658,10 +1681,12 @@ func NotSamef(t TestingT, expected interface{}, actual interface{}, msg string,
t.FailNow() t.FailNow()
} }
// NotSubset asserts that the specified list(array, slice...) contains not all // NotSubset asserts that the specified list(array, slice...) or map does NOT
// elements given in the specified subset(array, slice...). // contain all elements given in the specified subset list(array, slice...) or
// map.
// //
// assert.NotSubset(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]") // assert.NotSubset(t, [1, 3, 4], [1, 2])
// assert.NotSubset(t, {"x": 1, "y": 2}, {"z": 3})
func NotSubset(t TestingT, list interface{}, subset interface{}, msgAndArgs ...interface{}) { func NotSubset(t TestingT, list interface{}, subset interface{}, msgAndArgs ...interface{}) {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
@ -1672,10 +1697,12 @@ func NotSubset(t TestingT, list interface{}, subset interface{}, msgAndArgs ...i
t.FailNow() t.FailNow()
} }
// NotSubsetf asserts that the specified list(array, slice...) contains not all // NotSubsetf asserts that the specified list(array, slice...) or map does NOT
// elements given in the specified subset(array, slice...). // contain all elements given in the specified subset list(array, slice...) or
// map.
// //
// assert.NotSubsetf(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]", "error message %s", "formatted") // assert.NotSubsetf(t, [1, 3, 4], [1, 2], "error message %s", "formatted")
// assert.NotSubsetf(t, {"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted")
func NotSubsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) { func NotSubsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
@ -1880,10 +1907,11 @@ func Samef(t TestingT, expected interface{}, actual interface{}, msg string, arg
t.FailNow() t.FailNow()
} }
// Subset asserts that the specified list(array, slice...) contains all // Subset asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset(array, slice...). // elements given in the specified subset list(array, slice...) or map.
// //
// assert.Subset(t, [1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]") // assert.Subset(t, [1, 2, 3], [1, 2])
// assert.Subset(t, {"x": 1, "y": 2}, {"x": 1})
func Subset(t TestingT, list interface{}, subset interface{}, msgAndArgs ...interface{}) { func Subset(t TestingT, list interface{}, subset interface{}, msgAndArgs ...interface{}) {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
@ -1894,10 +1922,11 @@ func Subset(t TestingT, list interface{}, subset interface{}, msgAndArgs ...inte
t.FailNow() t.FailNow()
} }
// Subsetf asserts that the specified list(array, slice...) contains all // Subsetf asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset(array, slice...). // elements given in the specified subset list(array, slice...) or map.
// //
// assert.Subsetf(t, [1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]", "error message %s", "formatted") // assert.Subsetf(t, [1, 2, 3], [1, 2], "error message %s", "formatted")
// assert.Subsetf(t, {"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted")
func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) { func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()

View File

@ -1,7 +1,4 @@
/* // Code generated with github.com/stretchr/testify/_codegen; DO NOT EDIT.
* CODE GENERATED AUTOMATICALLY WITH github.com/stretchr/testify/_codegen
* THIS FILE MUST NOT BE EDITED BY HAND
*/
package require package require
@ -190,7 +187,7 @@ func (a *Assertions) EqualExportedValuesf(expected interface{}, actual interface
EqualExportedValuesf(a.t, expected, actual, msg, args...) EqualExportedValuesf(a.t, expected, actual, msg, args...)
} }
// EqualValues asserts that two objects are equal or convertable to the same types // EqualValues asserts that two objects are equal or convertible to the same types
// and equal. // and equal.
// //
// a.EqualValues(uint32(123), int32(123)) // a.EqualValues(uint32(123), int32(123))
@ -201,7 +198,7 @@ func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAn
EqualValues(a.t, expected, actual, msgAndArgs...) EqualValues(a.t, expected, actual, msgAndArgs...)
} }
// EqualValuesf asserts that two objects are equal or convertable to the same types // EqualValuesf asserts that two objects are equal or convertible to the same types
// and equal. // and equal.
// //
// a.EqualValuesf(uint32(123), int32(123), "error message %s", "formatted") // a.EqualValuesf(uint32(123), int32(123), "error message %s", "formatted")
@ -1222,6 +1219,26 @@ func (a *Assertions) NotErrorIsf(err error, target error, msg string, args ...in
NotErrorIsf(a.t, err, target, msg, args...) NotErrorIsf(a.t, err, target, msg, args...)
} }
// NotImplements asserts that an object does not implement the specified interface.
//
// a.NotImplements((*MyInterface)(nil), new(MyObject))
func (a *Assertions) NotImplements(interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotImplements(a.t, interfaceObject, object, msgAndArgs...)
}
// NotImplementsf asserts that an object does not implement the specified interface.
//
// a.NotImplementsf((*MyInterface)(nil), new(MyObject), "error message %s", "formatted")
func (a *Assertions) NotImplementsf(interfaceObject interface{}, object interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotImplementsf(a.t, interfaceObject, object, msg, args...)
}
// NotNil asserts that the specified object is not nil. // NotNil asserts that the specified object is not nil.
// //
// a.NotNil(err) // a.NotNil(err)
@ -1310,10 +1327,12 @@ func (a *Assertions) NotSamef(expected interface{}, actual interface{}, msg stri
NotSamef(a.t, expected, actual, msg, args...) NotSamef(a.t, expected, actual, msg, args...)
} }
// NotSubset asserts that the specified list(array, slice...) contains not all // NotSubset asserts that the specified list(array, slice...) or map does NOT
// elements given in the specified subset(array, slice...). // contain all elements given in the specified subset list(array, slice...) or
// map.
// //
// a.NotSubset([1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]") // a.NotSubset([1, 3, 4], [1, 2])
// a.NotSubset({"x": 1, "y": 2}, {"z": 3})
func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs ...interface{}) { func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -1321,10 +1340,12 @@ func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs
NotSubset(a.t, list, subset, msgAndArgs...) NotSubset(a.t, list, subset, msgAndArgs...)
} }
// NotSubsetf asserts that the specified list(array, slice...) contains not all // NotSubsetf asserts that the specified list(array, slice...) or map does NOT
// elements given in the specified subset(array, slice...). // contain all elements given in the specified subset list(array, slice...) or
// map.
// //
// a.NotSubsetf([1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]", "error message %s", "formatted") // a.NotSubsetf([1, 3, 4], [1, 2], "error message %s", "formatted")
// a.NotSubsetf({"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted")
func (a *Assertions) NotSubsetf(list interface{}, subset interface{}, msg string, args ...interface{}) { func (a *Assertions) NotSubsetf(list interface{}, subset interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -1484,10 +1505,11 @@ func (a *Assertions) Samef(expected interface{}, actual interface{}, msg string,
Samef(a.t, expected, actual, msg, args...) Samef(a.t, expected, actual, msg, args...)
} }
// Subset asserts that the specified list(array, slice...) contains all // Subset asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset(array, slice...). // elements given in the specified subset list(array, slice...) or map.
// //
// a.Subset([1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]") // a.Subset([1, 2, 3], [1, 2])
// a.Subset({"x": 1, "y": 2}, {"x": 1})
func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...interface{}) { func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -1495,10 +1517,11 @@ func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...
Subset(a.t, list, subset, msgAndArgs...) Subset(a.t, list, subset, msgAndArgs...)
} }
// Subsetf asserts that the specified list(array, slice...) contains all // Subsetf asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset(array, slice...). // elements given in the specified subset list(array, slice...) or map.
// //
// a.Subsetf([1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]", "error message %s", "formatted") // a.Subsetf([1, 2, 3], [1, 2], "error message %s", "formatted")
// a.Subsetf({"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted")
func (a *Assertions) Subsetf(list interface{}, subset interface{}, msg string, args ...interface{}) { func (a *Assertions) Subsetf(list interface{}, subset interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()

View File

@ -31,8 +31,8 @@ github.com/openshift/api/security/v1
# github.com/pmezard/go-difflib v1.0.0 # github.com/pmezard/go-difflib v1.0.0
## explicit ## explicit
github.com/pmezard/go-difflib/difflib github.com/pmezard/go-difflib/difflib
# github.com/stretchr/testify v1.8.4 # github.com/stretchr/testify v1.9.0
## explicit; go 1.20 ## explicit; go 1.17
github.com/stretchr/testify/assert github.com/stretchr/testify/assert
github.com/stretchr/testify/require github.com/stretchr/testify/require
# golang.org/x/net v0.19.0 # golang.org/x/net v0.19.0

View File

@ -6,4 +6,6 @@ spec:
attachRequired: false attachRequired: false
podInfoOnMount: false podInfoOnMount: false
fsGroupPolicy: {{ default "File" .Values.CSIDriver.fsGroupPolicy }} fsGroupPolicy: {{ default "File" .Values.CSIDriver.fsGroupPolicy }}
{{- if semverCompare ">= 1.25.x" .Capabilities.KubeVersion.Version }}
seLinuxMount: true seLinuxMount: true
{{- end }}

View File

@ -6,4 +6,6 @@ spec:
attachRequired: true attachRequired: true
podInfoOnMount: false podInfoOnMount: false
fsGroupPolicy: File fsGroupPolicy: File
{{- if semverCompare ">= 1.25.x" .Capabilities.KubeVersion.Version }}
seLinuxMount: true seLinuxMount: true
{{- end }}

4
go.mod
View File

@ -26,7 +26,7 @@ require (
github.com/onsi/gomega v1.31.1 github.com/onsi/gomega v1.31.1
github.com/pkg/xattr v0.4.9 github.com/pkg/xattr v0.4.9
github.com/prometheus/client_golang v1.18.0 github.com/prometheus/client_golang v1.18.0
github.com/stretchr/testify v1.8.4 github.com/stretchr/testify v1.9.0
golang.org/x/crypto v0.21.0 golang.org/x/crypto v0.21.0
golang.org/x/net v0.22.0 golang.org/x/net v0.22.0
golang.org/x/sys v0.18.0 golang.org/x/sys v0.18.0
@ -42,7 +42,7 @@ require (
k8s.io/klog/v2 v2.120.1 k8s.io/klog/v2 v2.120.1
k8s.io/kubernetes v1.29.2 k8s.io/kubernetes v1.29.2
k8s.io/mount-utils v0.29.2 k8s.io/mount-utils v0.29.2
k8s.io/pod-security-admission v0.0.0 k8s.io/pod-security-admission v0.29.2
k8s.io/utils v0.0.0-20230726121419-3b25d923346b k8s.io/utils v0.0.0-20230726121419-3b25d923346b
sigs.k8s.io/controller-runtime v0.17.2 sigs.k8s.io/controller-runtime v0.17.2
) )

3
go.sum
View File

@ -1603,8 +1603,9 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw=
github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=

View File

@ -77,7 +77,7 @@ func (cs *ControllerServer) createBackingVolume(
&volOptions.SubVolume, volOptions.ClusterID, cs.ClusterName, cs.SetMetadata) &volOptions.SubVolume, volOptions.ClusterID, cs.ClusterName, cs.SetMetadata)
if sID != nil { if sID != nil {
err = parentVolOpt.CopyEncryptionConfig(volOptions, sID.SnapshotID, vID.VolumeID) err = parentVolOpt.CopyEncryptionConfig(ctx, volOptions, sID.SnapshotID, vID.VolumeID)
if err != nil { if err != nil {
return status.Error(codes.Internal, err.Error()) return status.Error(codes.Internal, err.Error())
} }
@ -86,7 +86,7 @@ func (cs *ControllerServer) createBackingVolume(
} }
if parentVolOpt != nil { if parentVolOpt != nil {
err = parentVolOpt.CopyEncryptionConfig(volOptions, pvID.VolumeID, vID.VolumeID) err = parentVolOpt.CopyEncryptionConfig(ctx, volOptions, pvID.VolumeID, vID.VolumeID)
if err != nil { if err != nil {
return status.Error(codes.Internal, err.Error()) return status.Error(codes.Internal, err.Error())
} }
@ -596,7 +596,7 @@ func (cs *ControllerServer) cleanUpBackingVolume(
// GetSecret enabled KMS the DEKs are stored by // GetSecret enabled KMS the DEKs are stored by
// fscrypt on the volume that is going to be deleted anyway. // fscrypt on the volume that is going to be deleted anyway.
log.DebugLog(ctx, "going to remove DEK for integrated store %q (fscrypt)", volOptions.Encryption.GetID()) log.DebugLog(ctx, "going to remove DEK for integrated store %q (fscrypt)", volOptions.Encryption.GetID())
if err := volOptions.Encryption.RemoveDEK(volID.VolumeID); err != nil { if err := volOptions.Encryption.RemoveDEK(ctx, volID.VolumeID); err != nil {
log.WarningLog(ctx, "failed to clean the passphrase for volume %q (file encryption): %s", log.WarningLog(ctx, "failed to clean the passphrase for volume %q (file encryption): %s",
volOptions.VolID, err) volOptions.VolID, err)
} }
@ -907,7 +907,7 @@ func (cs *ControllerServer) CreateSnapshot(
// Use same encryption KMS than source volume and copy the passphrase. The passphrase becomes // Use same encryption KMS than source volume and copy the passphrase. The passphrase becomes
// available under the snapshot id for CreateVolume to use this snap as a backing volume // available under the snapshot id for CreateVolume to use this snap as a backing volume
snapVolOptions := store.VolumeOptions{} snapVolOptions := store.VolumeOptions{}
err = parentVolOptions.CopyEncryptionConfig(&snapVolOptions, sourceVolID, sID.SnapshotID) err = parentVolOptions.CopyEncryptionConfig(ctx, &snapVolOptions, sourceVolID, sID.SnapshotID)
if err != nil { if err != nil {
return nil, status.Error(codes.Internal, err.Error()) return nil, status.Error(codes.Internal, err.Error())
} }

View File

@ -901,7 +901,7 @@ func IsEncrypted(ctx context.Context, volOptions map[string]string) (bool, error
// CopyEncryptionConfig copies passphrases and initializes a fresh // CopyEncryptionConfig copies passphrases and initializes a fresh
// Encryption struct if necessary from (vo, vID) to (cp, cpVID). // Encryption struct if necessary from (vo, vID) to (cp, cpVID).
func (vo *VolumeOptions) CopyEncryptionConfig(cp *VolumeOptions, vID, cpVID string) error { func (vo *VolumeOptions) CopyEncryptionConfig(ctx context.Context, cp *VolumeOptions, vID, cpVID string) error {
var err error var err error
if !vo.IsEncrypted() { if !vo.IsEncrypted() {
@ -916,7 +916,7 @@ func (vo *VolumeOptions) CopyEncryptionConfig(cp *VolumeOptions, vID, cpVID stri
if cp.Encryption == nil { if cp.Encryption == nil {
cp.Encryption, err = util.NewVolumeEncryption(vo.Encryption.GetID(), vo.Encryption.KMS) cp.Encryption, err = util.NewVolumeEncryption(vo.Encryption.GetID(), vo.Encryption.KMS)
if errors.Is(err, util.ErrDEKStoreNeeded) { if errors.Is(err, util.ErrDEKStoreNeeded) {
_, err := vo.Encryption.KMS.GetSecret("") _, err := vo.Encryption.KMS.GetSecret(ctx, "")
if errors.Is(err, kmsapi.ErrGetSecretUnsupported) { if errors.Is(err, kmsapi.ErrGetSecretUnsupported) {
return err return err
} }
@ -924,13 +924,13 @@ func (vo *VolumeOptions) CopyEncryptionConfig(cp *VolumeOptions, vID, cpVID stri
} }
if vo.Encryption.KMS.RequiresDEKStore() == kmsapi.DEKStoreIntegrated { if vo.Encryption.KMS.RequiresDEKStore() == kmsapi.DEKStoreIntegrated {
passphrase, err := vo.Encryption.GetCryptoPassphrase(vID) passphrase, err := vo.Encryption.GetCryptoPassphrase(ctx, vID)
if err != nil { if err != nil {
return fmt.Errorf("failed to fetch passphrase for %q (%+v): %w", return fmt.Errorf("failed to fetch passphrase for %q (%+v): %w",
vID, vo, err) vID, vo, err)
} }
err = cp.Encryption.StoreCryptoPassphrase(cpVID, passphrase) err = cp.Encryption.StoreCryptoPassphrase(ctx, cpVID, passphrase)
if err != nil { if err != nil {
return fmt.Errorf("failed to store passphrase for %q (%+v): %w", return fmt.Errorf("failed to store passphrase for %q (%+v): %w",
cpVID, cp, err) cpVID, cp, err)
@ -962,7 +962,7 @@ func (vo *VolumeOptions) ConfigureEncryption(
// store. Since not all "metadata" KMS support // store. Since not all "metadata" KMS support
// GetSecret, test for support here. Postpone any // GetSecret, test for support here. Postpone any
// other error handling // other error handling
_, err := vo.Encryption.KMS.GetSecret("") _, err := vo.Encryption.KMS.GetSecret(ctx, "")
if errors.Is(err, kmsapi.ErrGetSecretUnsupported) { if errors.Is(err, kmsapi.ErrGetSecretUnsupported) {
return err return err
} }

View File

@ -183,7 +183,7 @@ func (kms *awsMetadataKMS) getService() (*awsKMS.KMS, error) {
} }
// EncryptDEK uses the Amazon KMS and the configured CMK to encrypt the DEK. // EncryptDEK uses the Amazon KMS and the configured CMK to encrypt the DEK.
func (kms *awsMetadataKMS) EncryptDEK(volumeID, plainDEK string) (string, error) { func (kms *awsMetadataKMS) EncryptDEK(ctx context.Context, volumeID, plainDEK string) (string, error) {
svc, err := kms.getService() svc, err := kms.getService()
if err != nil { if err != nil {
return "", fmt.Errorf("could not get KMS service: %w", err) return "", fmt.Errorf("could not get KMS service: %w", err)
@ -205,7 +205,7 @@ func (kms *awsMetadataKMS) EncryptDEK(volumeID, plainDEK string) (string, error)
} }
// DecryptDEK uses the Amazon KMS and the configured CMK to decrypt the DEK. // DecryptDEK uses the Amazon KMS and the configured CMK to decrypt the DEK.
func (kms *awsMetadataKMS) DecryptDEK(volumeID, encryptedDEK string) (string, error) { func (kms *awsMetadataKMS) DecryptDEK(ctx context.Context, volumeID, encryptedDEK string) (string, error) {
svc, err := kms.getService() svc, err := kms.getService()
if err != nil { if err != nil {
return "", fmt.Errorf("could not get KMS service: %w", err) return "", fmt.Errorf("could not get KMS service: %w", err)
@ -227,6 +227,6 @@ func (kms *awsMetadataKMS) DecryptDEK(volumeID, encryptedDEK string) (string, er
return string(result.Plaintext), nil return string(result.Plaintext), nil
} }
func (kms *awsMetadataKMS) GetSecret(volumeID string) (string, error) { func (kms *awsMetadataKMS) GetSecret(ctx context.Context, volumeID string) (string, error) {
return "", ErrGetSecretUnsupported return "", ErrGetSecretUnsupported
} }

View File

@ -193,7 +193,7 @@ func (as *awsSTSMetadataKMS) getServiceWithSTS() (*awsKMS.KMS, error) {
} }
// EncryptDEK uses the Amazon KMS and the configured CMK to encrypt the DEK. // EncryptDEK uses the Amazon KMS and the configured CMK to encrypt the DEK.
func (as *awsSTSMetadataKMS) EncryptDEK(_, plainDEK string) (string, error) { func (as *awsSTSMetadataKMS) EncryptDEK(ctx context.Context, _, plainDEK string) (string, error) {
svc, err := as.getServiceWithSTS() svc, err := as.getServiceWithSTS()
if err != nil { if err != nil {
return "", fmt.Errorf("failed to get KMS service: %w", err) return "", fmt.Errorf("failed to get KMS service: %w", err)
@ -213,7 +213,7 @@ func (as *awsSTSMetadataKMS) EncryptDEK(_, plainDEK string) (string, error) {
} }
// DecryptDEK uses the Amazon KMS and the configured CMK to decrypt the DEK. // DecryptDEK uses the Amazon KMS and the configured CMK to decrypt the DEK.
func (as *awsSTSMetadataKMS) DecryptDEK(_, encryptedDEK string) (string, error) { func (as *awsSTSMetadataKMS) DecryptDEK(ctx context.Context, _, encryptedDEK string) (string, error) {
svc, err := as.getServiceWithSTS() svc, err := as.getServiceWithSTS()
if err != nil { if err != nil {
return "", fmt.Errorf("failed to get KMS service: %w", err) return "", fmt.Errorf("failed to get KMS service: %w", err)

View File

@ -204,14 +204,14 @@ func (kms *keyProtectKMS) getService() error {
} }
// EncryptDEK uses the KeyProtect KMS and the configured CRK to encrypt the DEK. // EncryptDEK uses the KeyProtect KMS and the configured CRK to encrypt the DEK.
func (kms *keyProtectKMS) EncryptDEK(volumeID, plainDEK string) (string, error) { func (kms *keyProtectKMS) EncryptDEK(ctx context.Context, volumeID, plainDEK string) (string, error) {
if err := kms.getService(); err != nil { if err := kms.getService(); err != nil {
return "", fmt.Errorf("could not get KMS service: %w", err) return "", fmt.Errorf("could not get KMS service: %w", err)
} }
dekByteSlice := []byte(plainDEK) dekByteSlice := []byte(plainDEK)
aadVolID := []string{volumeID} aadVolID := []string{volumeID}
result, err := kms.client.Wrap(context.TODO(), kms.customerRootKey, dekByteSlice, &aadVolID) result, err := kms.client.Wrap(ctx, kms.customerRootKey, dekByteSlice, &aadVolID)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to wrap the DEK: %w", err) return "", fmt.Errorf("failed to wrap the DEK: %w", err)
} }
@ -223,7 +223,7 @@ func (kms *keyProtectKMS) EncryptDEK(volumeID, plainDEK string) (string, error)
} }
// DecryptDEK uses the Key protect KMS and the configured CRK to decrypt the DEK. // DecryptDEK uses the Key protect KMS and the configured CRK to decrypt the DEK.
func (kms *keyProtectKMS) DecryptDEK(volumeID, encryptedDEK string) (string, error) { func (kms *keyProtectKMS) DecryptDEK(ctx context.Context, volumeID, encryptedDEK string) (string, error) {
if err := kms.getService(); err != nil { if err := kms.getService(); err != nil {
return "", fmt.Errorf("could not get KMS service: %w", err) return "", fmt.Errorf("could not get KMS service: %w", err)
} }
@ -235,7 +235,7 @@ func (kms *keyProtectKMS) DecryptDEK(volumeID, encryptedDEK string) (string, err
} }
aadVolID := []string{volumeID} aadVolID := []string{volumeID}
result, err := kms.client.Unwrap(context.TODO(), kms.customerRootKey, ciphertextBlob, &aadVolID) result, err := kms.client.Unwrap(ctx, kms.customerRootKey, ciphertextBlob, &aadVolID)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to unwrap the DEK: %w", err) return "", fmt.Errorf("failed to unwrap the DEK: %w", err)
} }
@ -243,6 +243,6 @@ func (kms *keyProtectKMS) DecryptDEK(volumeID, encryptedDEK string) (string, err
return string(result), nil return string(result), nil
} }
func (kms *keyProtectKMS) GetSecret(volumeID string) (string, error) { func (kms *keyProtectKMS) GetSecret(ctx context.Context, volumeID string) (string, error) {
return "", ErrGetSecretUnsupported return "", ErrGetSecretUnsupported
} }

View File

@ -180,7 +180,7 @@ func initKMIPKMS(args ProviderInitArgs) (EncryptionKMS, error) {
} }
// EncryptDEK uses the KMIP encrypt operation to encrypt the DEK. // EncryptDEK uses the KMIP encrypt operation to encrypt the DEK.
func (kms *kmipKMS) EncryptDEK(_, plainDEK string) (string, error) { func (kms *kmipKMS) EncryptDEK(ctx context.Context, _, plainDEK string) (string, error) {
conn, err := kms.connect() conn, err := kms.connect()
if err != nil { if err != nil {
return "", err return "", err
@ -236,7 +236,7 @@ func (kms *kmipKMS) EncryptDEK(_, plainDEK string) (string, error) {
} }
// DecryptDEK uses the KMIP decrypt operation to decrypt the DEK. // DecryptDEK uses the KMIP decrypt operation to decrypt the DEK.
func (kms *kmipKMS) DecryptDEK(_, encryptedDEK string) (string, error) { func (kms *kmipKMS) DecryptDEK(ctx context.Context, _, encryptedDEK string) (string, error) {
conn, err := kms.connect() conn, err := kms.connect()
if err != nil { if err != nil {
return "", err return "", err
@ -500,7 +500,7 @@ func (kms *kmipKMS) verifyResponse(
return &batchItem, nil return &batchItem, nil
} }
func (kms *kmipKMS) GetSecret(volumeID string) (string, error) { func (kms *kmipKMS) GetSecret(ctx context.Context, volumeID string) (string, error) {
return "", ErrGetSecretUnsupported return "", ErrGetSecretUnsupported
} }

View File

@ -331,18 +331,18 @@ type EncryptionKMS interface {
// EncryptDEK provides a way for a KMS to encrypt a DEK. In case the // EncryptDEK provides a way for a KMS to encrypt a DEK. In case the
// encryption is done transparently inside the KMS service, the // encryption is done transparently inside the KMS service, the
// function can return an unencrypted value. // function can return an unencrypted value.
EncryptDEK(volumeID, plainDEK string) (string, error) EncryptDEK(ctx context.Context, volumeID, plainDEK string) (string, error)
// DecryptDEK provides a way for a KMS to decrypt a DEK. In case the // DecryptDEK provides a way for a KMS to decrypt a DEK. In case the
// encryption is done transparently inside the KMS service, the // encryption is done transparently inside the KMS service, the
// function does not need to do anything except return the encyptedDEK // function does not need to do anything except return the encyptedDEK
// as it was received. // as it was received.
DecryptDEK(volumeID, encyptedDEK string) (string, error) DecryptDEK(ctx context.Context, volumeID, encyptedDEK string) (string, error)
// GetSecret allows external key management systems to // GetSecret allows external key management systems to
// retrieve keys used in EncryptDEK / DecryptDEK to use them // retrieve keys used in EncryptDEK / DecryptDEK to use them
// directly. Example: fscrypt uses this to unlock raw protectors // directly. Example: fscrypt uses this to unlock raw protectors
GetSecret(volumeID string) (string, error) GetSecret(ctx context.Context, volumeID string) (string, error)
} }
// DEKStoreType describes what DEKStore needs to be configured when using a // DEKStoreType describes what DEKStore needs to be configured when using a
@ -364,11 +364,11 @@ const (
// the KMS can not store passphrases for volumes. // the KMS can not store passphrases for volumes.
type DEKStore interface { type DEKStore interface {
// StoreDEK saves the DEK in the configured store. // StoreDEK saves the DEK in the configured store.
StoreDEK(volumeID string, dek string) error StoreDEK(ctx context.Context, volumeID string, dek string) error
// FetchDEK reads the DEK from the configured store and returns it. // FetchDEK reads the DEK from the configured store and returns it.
FetchDEK(volumeID string) (string, error) FetchDEK(ctx context.Context, volumeID string) (string, error)
// RemoveDEK deletes the DEK from the configured store. // RemoveDEK deletes the DEK from the configured store.
RemoveDEK(volumeID string) error RemoveDEK(ctx context.Context, volumeID string) error
} }
// integratedDEK is a DEKStore that can not be configured. Either the KMS does // integratedDEK is a DEKStore that can not be configured. Either the KMS does
@ -380,15 +380,15 @@ func (i integratedDEK) RequiresDEKStore() DEKStoreType {
return DEKStoreIntegrated return DEKStoreIntegrated
} }
func (i integratedDEK) EncryptDEK(volumeID, plainDEK string) (string, error) { func (i integratedDEK) EncryptDEK(ctx context.Context, volumeID, plainDEK string) (string, error) {
return plainDEK, nil return plainDEK, nil
} }
func (i integratedDEK) DecryptDEK(volumeID, encyptedDEK string) (string, error) { func (i integratedDEK) DecryptDEK(ctx context.Context, volumeID, encyptedDEK string) (string, error) {
return encyptedDEK, nil return encyptedDEK, nil
} }
func (i integratedDEK) GetSecret(volumeID string) (string, error) { func (i integratedDEK) GetSecret(ctx context.Context, volumeID string) (string, error) {
return "", ErrGetSecretIntegrated return "", ErrGetSecretIntegrated
} }

View File

@ -78,19 +78,19 @@ func (kms secretsKMS) Destroy() {
} }
// FetchDEK returns passphrase from Kubernetes secrets. // FetchDEK returns passphrase from Kubernetes secrets.
func (kms secretsKMS) FetchDEK(key string) (string, error) { func (kms secretsKMS) FetchDEK(ctx context.Context, key string) (string, error) {
return kms.passphrase, nil return kms.passphrase, nil
} }
// StoreDEK does nothing, as there is no passphrase per key (volume), so // StoreDEK does nothing, as there is no passphrase per key (volume), so
// no need to store is anywhere. // no need to store is anywhere.
func (kms secretsKMS) StoreDEK(key, value string) error { func (kms secretsKMS) StoreDEK(ctx context.Context, key, value string) error {
return nil return nil
} }
// RemoveDEK is doing nothing as no new passphrases are saved with // RemoveDEK is doing nothing as no new passphrases are saved with
// secretsKMS. // secretsKMS.
func (kms secretsKMS) RemoveDEK(key string) error { func (kms secretsKMS) RemoveDEK(ctx context.Context, key string) error {
return nil return nil
} }
@ -206,9 +206,9 @@ type encryptedMetedataDEK struct {
// the secretsKMS and the volumeID. // the secretsKMS and the volumeID.
// The resulting encryptedDEK contains a JSON with the encrypted DEK and the // The resulting encryptedDEK contains a JSON with the encrypted DEK and the
// nonce that was used for encrypting. // nonce that was used for encrypting.
func (kms secretsMetadataKMS) EncryptDEK(volumeID, plainDEK string) (string, error) { func (kms secretsMetadataKMS) EncryptDEK(ctx context.Context, volumeID, plainDEK string) (string, error) {
// use the passphrase from the secretKMS // use the passphrase from the secretKMS
passphrase, err := kms.secretsKMS.FetchDEK(volumeID) passphrase, err := kms.secretsKMS.FetchDEK(ctx, volumeID)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to get passphrase: %w", err) return "", fmt.Errorf("failed to get passphrase: %w", err)
} }
@ -236,9 +236,9 @@ func (kms secretsMetadataKMS) EncryptDEK(volumeID, plainDEK string) (string, err
// DecryptDEK takes the JSON formatted `encryptedMetadataDEK` contents, and it // DecryptDEK takes the JSON formatted `encryptedMetadataDEK` contents, and it
// fetches secretKMS passphrase to decrypt the DEK. // fetches secretKMS passphrase to decrypt the DEK.
func (kms secretsMetadataKMS) DecryptDEK(volumeID, encryptedDEK string) (string, error) { func (kms secretsMetadataKMS) DecryptDEK(ctx context.Context, volumeID, encryptedDEK string) (string, error) {
// use the passphrase from the secretKMS // use the passphrase from the secretKMS
passphrase, err := kms.secretsKMS.FetchDEK(volumeID) passphrase, err := kms.secretsKMS.FetchDEK(ctx, volumeID)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to get passphrase: %w", err) return "", fmt.Errorf("failed to get passphrase: %w", err)
} }
@ -263,9 +263,9 @@ func (kms secretsMetadataKMS) DecryptDEK(volumeID, encryptedDEK string) (string,
return string(dek), nil return string(dek), nil
} }
func (kms secretsMetadataKMS) GetSecret(volumeID string) (string, error) { func (kms secretsMetadataKMS) GetSecret(ctx context.Context, volumeID string) (string, error) {
// use the passphrase from the secretKMS // use the passphrase from the secretKMS
return kms.secretsKMS.FetchDEK(volumeID) return kms.secretsKMS.FetchDEK(ctx, volumeID)
} }
// generateCipher returns a AEAD cipher based on a passphrase and salt // generateCipher returns a AEAD cipher based on a passphrase and salt

View File

@ -17,6 +17,7 @@ limitations under the License.
package kms package kms
import ( import (
"context"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -103,19 +104,21 @@ func TestWorkflowSecretsMetadataKMS(t *testing.T) {
// plainDEK is the (LUKS) passphrase for the volume // plainDEK is the (LUKS) passphrase for the volume
plainDEK := "usually created with generateNewEncryptionPassphrase()" plainDEK := "usually created with generateNewEncryptionPassphrase()"
encryptedDEK, err := kms.EncryptDEK(volumeID, plainDEK) ctx := context.TODO()
encryptedDEK, err := kms.EncryptDEK(ctx, volumeID, plainDEK)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotEqual(t, "", encryptedDEK) assert.NotEqual(t, "", encryptedDEK)
assert.NotEqual(t, plainDEK, encryptedDEK) assert.NotEqual(t, plainDEK, encryptedDEK)
// with an incorrect volumeID, decrypting should fail // with an incorrect volumeID, decrypting should fail
decryptedDEK, err := kms.DecryptDEK("incorrect-volumeID", encryptedDEK) decryptedDEK, err := kms.DecryptDEK(ctx, "incorrect-volumeID", encryptedDEK)
assert.Error(t, err) assert.Error(t, err)
assert.Equal(t, "", decryptedDEK) assert.Equal(t, "", decryptedDEK)
assert.NotEqual(t, plainDEK, decryptedDEK) assert.NotEqual(t, plainDEK, decryptedDEK)
// with the right volumeID, decrypting should return the plainDEK // with the right volumeID, decrypting should return the plainDEK
decryptedDEK, err = kms.DecryptDEK(volumeID, encryptedDEK) decryptedDEK, err = kms.DecryptDEK(ctx, volumeID, encryptedDEK)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotEqual(t, "", decryptedDEK) assert.NotEqual(t, "", decryptedDEK)
assert.Equal(t, plainDEK, decryptedDEK) assert.Equal(t, plainDEK, decryptedDEK)

View File

@ -17,6 +17,7 @@ limitations under the License.
package kms package kms
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"os" "os"
@ -395,7 +396,7 @@ func initVaultKMS(args ProviderInitArgs) (EncryptionKMS, error) {
// FetchDEK returns passphrase from Vault. The passphrase is stored in a // FetchDEK returns passphrase from Vault. The passphrase is stored in a
// data.data.passphrase structure. // data.data.passphrase structure.
func (kms *vaultKMS) FetchDEK(key string) (string, error) { func (kms *vaultKMS) FetchDEK(ctx context.Context, key string) (string, error) {
// Since the second return variable loss.Version is not used, there it is ignored. // Since the second return variable loss.Version is not used, there it is ignored.
s, _, err := kms.secrets.GetSecret(filepath.Join(kms.vaultPassphrasePath, key), kms.keyContext) s, _, err := kms.secrets.GetSecret(filepath.Join(kms.vaultPassphrasePath, key), kms.keyContext)
if err != nil { if err != nil {
@ -415,7 +416,7 @@ func (kms *vaultKMS) FetchDEK(key string) (string, error) {
} }
// StoreDEK saves new passphrase in Vault. // StoreDEK saves new passphrase in Vault.
func (kms *vaultKMS) StoreDEK(key, value string) error { func (kms *vaultKMS) StoreDEK(ctx context.Context, key, value string) error {
data := map[string]interface{}{ data := map[string]interface{}{
"data": map[string]string{ "data": map[string]string{
"passphrase": value, "passphrase": value,
@ -433,7 +434,7 @@ func (kms *vaultKMS) StoreDEK(key, value string) error {
} }
// RemoveDEK deletes passphrase from Vault. // RemoveDEK deletes passphrase from Vault.
func (kms *vaultKMS) RemoveDEK(key string) error { func (kms *vaultKMS) RemoveDEK(ctx context.Context, key string) error {
pathKey := filepath.Join(kms.vaultPassphrasePath, key) pathKey := filepath.Join(kms.vaultPassphrasePath, key)
err := kms.secrets.DeleteSecret(pathKey, kms.getDeleteKeyContext()) err := kms.secrets.DeleteSecret(pathKey, kms.getDeleteKeyContext())
if err != nil { if err != nil {

View File

@ -459,7 +459,7 @@ func (vtc *vaultTenantConnection) getK8sClient() (*kubernetes.Clientset, error)
// FetchDEK returns passphrase from Vault. The passphrase is stored in a // FetchDEK returns passphrase from Vault. The passphrase is stored in a
// data.data.passphrase structure. // data.data.passphrase structure.
func (vtc *vaultTenantConnection) FetchDEK(key string) (string, error) { func (vtc *vaultTenantConnection) FetchDEK(ctx context.Context, key string) (string, error) {
// Since the second return variable loss.Version is not used, there it is ignored. // Since the second return variable loss.Version is not used, there it is ignored.
s, _, err := vtc.secrets.GetSecret(key, vtc.keyContext) s, _, err := vtc.secrets.GetSecret(key, vtc.keyContext)
if err != nil { if err != nil {
@ -479,7 +479,7 @@ func (vtc *vaultTenantConnection) FetchDEK(key string) (string, error) {
} }
// StoreDEK saves new passphrase in Vault. // StoreDEK saves new passphrase in Vault.
func (vtc *vaultTenantConnection) StoreDEK(key, value string) error { func (vtc *vaultTenantConnection) StoreDEK(ctx context.Context, key, value string) error {
data := map[string]interface{}{ data := map[string]interface{}{
"data": map[string]string{ "data": map[string]string{
"passphrase": value, "passphrase": value,
@ -496,7 +496,7 @@ func (vtc *vaultTenantConnection) StoreDEK(key, value string) error {
} }
// RemoveDEK deletes passphrase from Vault. // RemoveDEK deletes passphrase from Vault.
func (vtc *vaultTenantConnection) RemoveDEK(key string) error { func (vtc *vaultTenantConnection) RemoveDEK(ctx context.Context, key string) error {
err := vtc.secrets.DeleteSecret(key, vtc.getDeleteKeyContext()) err := vtc.secrets.DeleteSecret(key, vtc.getDeleteKeyContext())
if err != nil { if err != nil {
return fmt.Errorf("delete passphrase at %s request to vault failed: %w", key, err) return fmt.Errorf("delete passphrase at %s request to vault failed: %w", key, err)

View File

@ -155,7 +155,7 @@ func (rv *rbdVolume) createCloneFromImage(ctx context.Context, parentVol *rbdVol
return err return err
} }
err = parentVol.copyEncryptionConfig(&rv.rbdImage, true) err = parentVol.copyEncryptionConfig(ctx, &rv.rbdImage, true)
if err != nil { if err != nil {
return fmt.Errorf("failed to copy encryption config for %q: %w", rv, err) return fmt.Errorf("failed to copy encryption config for %q: %w", rv, err)
} }
@ -232,7 +232,7 @@ func (rv *rbdVolume) doSnapClone(ctx context.Context, parentVol *rbdVolume) erro
return errClone return errClone
} }
err = parentVol.copyEncryptionConfig(&rv.rbdImage, true) err = parentVol.copyEncryptionConfig(ctx, &rv.rbdImage, true)
if err != nil { if err != nil {
return fmt.Errorf("failed to copy encryption config for %q: %w", rv, err) return fmt.Errorf("failed to copy encryption config for %q: %w", rv, err)
} }

View File

@ -191,7 +191,7 @@ func (cs *ControllerServer) parseVolCreateRequest(
// get the owner of the PVC which is required for few encryption related operations // get the owner of the PVC which is required for few encryption related operations
rbdVol.Owner = k8s.GetOwner(req.GetParameters()) rbdVol.Owner = k8s.GetOwner(req.GetParameters())
err = rbdVol.initKMS(req.GetParameters(), req.GetSecrets()) err = rbdVol.initKMS(ctx, req.GetParameters(), req.GetSecrets())
if err != nil { if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error()) return nil, status.Error(codes.InvalidArgument, err.Error())
} }
@ -486,7 +486,7 @@ func (cs *ControllerServer) repairExistingVolume(ctx context.Context, req *csi.C
return nil, err return nil, err
} }
err = rbdSnap.repairEncryptionConfig(&rbdVol.rbdImage) err = rbdSnap.repairEncryptionConfig(ctx, &rbdVol.rbdImage)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -677,7 +677,7 @@ func (cs *ControllerServer) createVolumeFromSnapshot(
log.DebugLog(ctx, "create volume %s from snapshot %s", rbdVol, rbdSnap) log.DebugLog(ctx, "create volume %s from snapshot %s", rbdVol, rbdSnap)
err = parentVol.copyEncryptionConfig(&rbdVol.rbdImage, true) err = parentVol.copyEncryptionConfig(ctx, &rbdVol.rbdImage, true)
if err != nil { if err != nil {
return fmt.Errorf("failed to copy encryption config for %q: %w", rbdVol, err) return fmt.Errorf("failed to copy encryption config for %q: %w", rbdVol, err)
} }
@ -1229,7 +1229,7 @@ func cloneFromSnapshot(
} }
defer vol.Destroy() defer vol.Destroy()
err = rbdVol.copyEncryptionConfig(&vol.rbdImage, false) err = rbdVol.copyEncryptionConfig(ctx, &vol.rbdImage, false)
if err != nil { if err != nil {
return nil, status.Error(codes.Internal, err.Error()) return nil, status.Error(codes.Internal, err.Error())
} }
@ -1332,7 +1332,7 @@ func (cs *ControllerServer) doSnapshotClone(
} }
}() }()
err = parentVol.copyEncryptionConfig(&cloneRbd.rbdImage, false) err = parentVol.copyEncryptionConfig(ctx, &cloneRbd.rbdImage, false)
if err != nil { if err != nil {
log.ErrorLog(ctx, "failed to copy encryption "+ log.ErrorLog(ctx, "failed to copy encryption "+
"config for %q: %v", cloneRbd, err) "config for %q: %v", cloneRbd, err)

View File

@ -116,7 +116,7 @@ func IsFileEncrypted(ctx context.Context, volOptions map[string]string) (bool, e
// - the Data-Encryption-Key (DEK) will be generated stored for use by the KMS; // - the Data-Encryption-Key (DEK) will be generated stored for use by the KMS;
// - the RBD image will be marked to support encryption in its metadata. // - the RBD image will be marked to support encryption in its metadata.
func (ri *rbdImage) setupBlockEncryption(ctx context.Context) error { func (ri *rbdImage) setupBlockEncryption(ctx context.Context) error {
err := ri.blockEncryption.StoreNewCryptoPassphrase(ri.VolID, encryptionPassphraseSize) err := ri.blockEncryption.StoreNewCryptoPassphrase(ctx, ri.VolID, encryptionPassphraseSize)
if err != nil { if err != nil {
log.ErrorLog(ctx, "failed to save encryption passphrase for "+ log.ErrorLog(ctx, "failed to save encryption passphrase for "+
"image %s: %s", ri, err) "image %s: %s", ri, err)
@ -144,7 +144,7 @@ func (ri *rbdImage) setupBlockEncryption(ctx context.Context) error {
// destination rbdImage's VolumeEncryption object which needs to be initialized // destination rbdImage's VolumeEncryption object which needs to be initialized
// beforehand and is possibly different from the source VolumeEncryption // beforehand and is possibly different from the source VolumeEncryption
// (Usecase: Restoring snapshot into a storageclass with different encryption config). // (Usecase: Restoring snapshot into a storageclass with different encryption config).
func (ri *rbdImage) copyEncryptionConfig(cp *rbdImage, copyOnlyPassphrase bool) error { func (ri *rbdImage) copyEncryptionConfig(ctx context.Context, cp *rbdImage, copyOnlyPassphrase bool) error {
// nothing to do if parent image is not encrypted. // nothing to do if parent image is not encrypted.
if !ri.isBlockEncrypted() && !ri.isFileEncrypted() { if !ri.isBlockEncrypted() && !ri.isFileEncrypted() {
return nil return nil
@ -157,7 +157,7 @@ func (ri *rbdImage) copyEncryptionConfig(cp *rbdImage, copyOnlyPassphrase bool)
if ri.isBlockEncrypted() { if ri.isBlockEncrypted() {
// get the unencrypted passphrase // get the unencrypted passphrase
passphrase, err := ri.blockEncryption.GetCryptoPassphrase(ri.VolID) passphrase, err := ri.blockEncryption.GetCryptoPassphrase(ctx, ri.VolID)
if err != nil { if err != nil {
return fmt.Errorf("failed to fetch passphrase for %q: %w", return fmt.Errorf("failed to fetch passphrase for %q: %w",
ri, err) ri, err)
@ -171,7 +171,7 @@ func (ri *rbdImage) copyEncryptionConfig(cp *rbdImage, copyOnlyPassphrase bool)
} }
// re-encrypt the plain passphrase for the cloned volume // re-encrypt the plain passphrase for the cloned volume
err = cp.blockEncryption.StoreCryptoPassphrase(cp.VolID, passphrase) err = cp.blockEncryption.StoreCryptoPassphrase(ctx, cp.VolID, passphrase)
if err != nil { if err != nil {
return fmt.Errorf("failed to store passphrase for %q: %w", return fmt.Errorf("failed to store passphrase for %q: %w",
cp, err) cp, err)
@ -182,7 +182,7 @@ func (ri *rbdImage) copyEncryptionConfig(cp *rbdImage, copyOnlyPassphrase bool)
var err error var err error
cp.fileEncryption, err = util.NewVolumeEncryption(ri.fileEncryption.GetID(), ri.fileEncryption.KMS) cp.fileEncryption, err = util.NewVolumeEncryption(ri.fileEncryption.GetID(), ri.fileEncryption.KMS)
if errors.Is(err, util.ErrDEKStoreNeeded) { if errors.Is(err, util.ErrDEKStoreNeeded) {
_, err := ri.fileEncryption.KMS.GetSecret("") _, err := ri.fileEncryption.KMS.GetSecret(ctx, "")
if errors.Is(err, kmsapi.ErrGetSecretUnsupported) { if errors.Is(err, kmsapi.ErrGetSecretUnsupported) {
return err return err
} }
@ -191,14 +191,14 @@ func (ri *rbdImage) copyEncryptionConfig(cp *rbdImage, copyOnlyPassphrase bool)
if ri.isFileEncrypted() && ri.fileEncryption.KMS.RequiresDEKStore() == kmsapi.DEKStoreIntegrated { if ri.isFileEncrypted() && ri.fileEncryption.KMS.RequiresDEKStore() == kmsapi.DEKStoreIntegrated {
// get the unencrypted passphrase // get the unencrypted passphrase
passphrase, err := ri.fileEncryption.GetCryptoPassphrase(ri.VolID) passphrase, err := ri.fileEncryption.GetCryptoPassphrase(ctx, ri.VolID)
if err != nil { if err != nil {
return fmt.Errorf("failed to fetch passphrase for %q: %w", return fmt.Errorf("failed to fetch passphrase for %q: %w",
ri, err) ri, err)
} }
// re-encrypt the plain passphrase for the cloned volume // re-encrypt the plain passphrase for the cloned volume
err = cp.fileEncryption.StoreCryptoPassphrase(cp.VolID, passphrase) err = cp.fileEncryption.StoreCryptoPassphrase(ctx, cp.VolID, passphrase)
if err != nil { if err != nil {
return fmt.Errorf("failed to store passphrase for %q: %w", return fmt.Errorf("failed to store passphrase for %q: %w",
cp, err) cp, err)
@ -223,7 +223,7 @@ func (ri *rbdImage) copyEncryptionConfig(cp *rbdImage, copyOnlyPassphrase bool)
// repairEncryptionConfig checks the encryption state of the current rbdImage, // repairEncryptionConfig checks the encryption state of the current rbdImage,
// and makes sure that the destination rbdImage has the same configuration. // and makes sure that the destination rbdImage has the same configuration.
func (ri *rbdImage) repairEncryptionConfig(dest *rbdImage) error { func (ri *rbdImage) repairEncryptionConfig(ctx context.Context, dest *rbdImage) error {
if !ri.isBlockEncrypted() && !ri.isFileEncrypted() { if !ri.isBlockEncrypted() && !ri.isFileEncrypted() {
return nil return nil
} }
@ -236,14 +236,14 @@ func (ri *rbdImage) repairEncryptionConfig(dest *rbdImage) error {
dest.conn = ri.conn.Copy() dest.conn = ri.conn.Copy()
} }
return ri.copyEncryptionConfig(dest, true) return ri.copyEncryptionConfig(ctx, dest, true)
} }
return nil return nil
} }
func (ri *rbdImage) encryptDevice(ctx context.Context, devicePath string) error { func (ri *rbdImage) encryptDevice(ctx context.Context, devicePath string) error {
passphrase, err := ri.blockEncryption.GetCryptoPassphrase(ri.VolID) passphrase, err := ri.blockEncryption.GetCryptoPassphrase(ctx, ri.VolID)
if err != nil { if err != nil {
log.ErrorLog(ctx, "failed to get crypto passphrase for %s: %v", log.ErrorLog(ctx, "failed to get crypto passphrase for %s: %v",
ri, err) ri, err)
@ -269,7 +269,7 @@ func (ri *rbdImage) encryptDevice(ctx context.Context, devicePath string) error
} }
func (rv *rbdVolume) openEncryptedDevice(ctx context.Context, devicePath string) (string, error) { func (rv *rbdVolume) openEncryptedDevice(ctx context.Context, devicePath string) (string, error) {
passphrase, err := rv.blockEncryption.GetCryptoPassphrase(rv.VolID) passphrase, err := rv.blockEncryption.GetCryptoPassphrase(ctx, rv.VolID)
if err != nil { if err != nil {
log.ErrorLog(ctx, "failed to get passphrase for encrypted device %s: %v", log.ErrorLog(ctx, "failed to get passphrase for encrypted device %s: %v",
rv, err) rv, err)
@ -300,7 +300,7 @@ func (rv *rbdVolume) openEncryptedDevice(ctx context.Context, devicePath string)
return mapperFilePath, nil return mapperFilePath, nil
} }
func (ri *rbdImage) initKMS(volOptions, credentials map[string]string) error { func (ri *rbdImage) initKMS(ctx context.Context, volOptions, credentials map[string]string) error {
kmsID, encType, err := ParseEncryptionOpts(volOptions, rbdDefaultEncryptionType) kmsID, encType, err := ParseEncryptionOpts(volOptions, rbdDefaultEncryptionType)
if err != nil { if err != nil {
return err return err
@ -310,7 +310,7 @@ func (ri *rbdImage) initKMS(volOptions, credentials map[string]string) error {
case util.EncryptionTypeBlock: case util.EncryptionTypeBlock:
err = ri.configureBlockEncryption(kmsID, credentials) err = ri.configureBlockEncryption(kmsID, credentials)
case util.EncryptionTypeFile: case util.EncryptionTypeFile:
err = ri.configureFileEncryption(kmsID, credentials) err = ri.configureFileEncryption(ctx, kmsID, credentials)
case util.EncryptionTypeInvalid: case util.EncryptionTypeInvalid:
return fmt.Errorf("invalid encryption type") return fmt.Errorf("invalid encryption type")
case util.EncryptionTypeNone: case util.EncryptionTypeNone:
@ -376,7 +376,7 @@ func (ri *rbdImage) configureBlockEncryption(kmsID string, credentials map[strin
// configureBlockDeviceEncryption sets up the VolumeEncryption for this rbdImage. Once // configureBlockDeviceEncryption sets up the VolumeEncryption for this rbdImage. Once
// configured, use isEncrypted() to see if the volume supports encryption. // configured, use isEncrypted() to see if the volume supports encryption.
func (ri *rbdImage) configureFileEncryption(kmsID string, credentials map[string]string) error { func (ri *rbdImage) configureFileEncryption(ctx context.Context, kmsID string, credentials map[string]string) error {
kms, err := kmsapi.GetKMS(ri.Owner, kmsID, credentials) kms, err := kmsapi.GetKMS(ri.Owner, kmsID, credentials)
if err != nil { if err != nil {
return err return err
@ -390,7 +390,7 @@ func (ri *rbdImage) configureFileEncryption(kmsID string, credentials map[string
// store. Since not all "metadata" KMS support // store. Since not all "metadata" KMS support
// GetSecret, test for support here. Postpone any // GetSecret, test for support here. Postpone any
// other error handling // other error handling
_, err := ri.fileEncryption.KMS.GetSecret("") _, err := ri.fileEncryption.KMS.GetSecret(ctx, "")
if errors.Is(err, kmsapi.ErrGetSecretUnsupported) { if errors.Is(err, kmsapi.ErrGetSecretUnsupported) {
return err return err
} }
@ -400,7 +400,7 @@ func (ri *rbdImage) configureFileEncryption(kmsID string, credentials map[string
} }
// StoreDEK saves the DEK in the metadata, overwrites any existing contents. // StoreDEK saves the DEK in the metadata, overwrites any existing contents.
func (ri *rbdImage) StoreDEK(volumeID, dek string) error { func (ri *rbdImage) StoreDEK(ctx context.Context, volumeID, dek string) error {
if ri.VolID == "" { if ri.VolID == "" {
return fmt.Errorf("BUG: %q does not have VolID set, call "+ return fmt.Errorf("BUG: %q does not have VolID set, call "+
"stack: %s", ri, util.CallStack()) "stack: %s", ri, util.CallStack())
@ -413,7 +413,7 @@ func (ri *rbdImage) StoreDEK(volumeID, dek string) error {
} }
// FetchDEK reads the DEK from the image metadata. // FetchDEK reads the DEK from the image metadata.
func (ri *rbdImage) FetchDEK(volumeID string) (string, error) { func (ri *rbdImage) FetchDEK(ctx context.Context, volumeID string) (string, error) {
if ri.VolID == "" { if ri.VolID == "" {
return "", fmt.Errorf("BUG: %q does not have VolID set, call "+ return "", fmt.Errorf("BUG: %q does not have VolID set, call "+
"stack: %s", ri, util.CallStack()) "stack: %s", ri, util.CallStack())
@ -426,7 +426,7 @@ func (ri *rbdImage) FetchDEK(volumeID string) (string, error) {
// RemoveDEK does not need to remove the DEK from the metadata, the image is // RemoveDEK does not need to remove the DEK from the metadata, the image is
// most likely getting removed. // most likely getting removed.
func (ri *rbdImage) RemoveDEK(volumeID string) error { func (ri *rbdImage) RemoveDEK(ctx context.Context, volumeID string) error {
if ri.VolID == "" { if ri.VolID == "" {
return fmt.Errorf("BUG: %q does not have VolID set, call "+ return fmt.Errorf("BUG: %q does not have VolID set, call "+
"stack: %s", ri, util.CallStack()) "stack: %s", ri, util.CallStack())

View File

@ -232,7 +232,7 @@ func (ns *NodeServer) populateRbdVol(
return nil, status.Error(codes.Internal, err.Error()) return nil, status.Error(codes.Internal, err.Error())
} }
err = rv.initKMS(req.GetVolumeContext(), req.GetSecrets()) err = rv.initKMS(ctx, req.GetVolumeContext(), req.GetSecrets())
if err != nil { if err != nil {
return nil, status.Error(codes.Internal, err.Error()) return nil, status.Error(codes.Internal, err.Error())
} }

View File

@ -334,7 +334,7 @@ func (rv *rbdVolume) Exists(ctx context.Context, parentVol *rbdVolume) (bool, er
} }
if parentVol != nil { if parentVol != nil {
err = parentVol.copyEncryptionConfig(&rv.rbdImage, true) err = parentVol.copyEncryptionConfig(ctx, &rv.rbdImage, true)
if err != nil { if err != nil {
log.ErrorLog(ctx, err.Error()) log.ErrorLog(ctx, err.Error())

View File

@ -635,14 +635,14 @@ func (ri *rbdImage) deleteImage(ctx context.Context) error {
if ri.isBlockEncrypted() { if ri.isBlockEncrypted() {
log.DebugLog(ctx, "rbd: going to remove DEK for %q (block encryption)", ri) log.DebugLog(ctx, "rbd: going to remove DEK for %q (block encryption)", ri)
if err = ri.blockEncryption.RemoveDEK(ri.VolID); err != nil { if err = ri.blockEncryption.RemoveDEK(ctx, ri.VolID); err != nil {
log.WarningLog(ctx, "failed to clean the passphrase for volume %s (block encryption): %s", ri.VolID, err) log.WarningLog(ctx, "failed to clean the passphrase for volume %s (block encryption): %s", ri.VolID, err)
} }
} }
if ri.isFileEncrypted() { if ri.isFileEncrypted() {
log.DebugLog(ctx, "rbd: going to remove DEK for %q (file encryption)", ri) log.DebugLog(ctx, "rbd: going to remove DEK for %q (file encryption)", ri)
if err = ri.fileEncryption.RemoveDEK(ri.VolID); err != nil { if err = ri.fileEncryption.RemoveDEK(ctx, ri.VolID); err != nil {
log.WarningLog(ctx, "failed to clean the passphrase for volume %s (file encryption): %s", ri.VolID, err) log.WarningLog(ctx, "failed to clean the passphrase for volume %s (file encryption): %s", ri.VolID, err)
} }
} }
@ -1032,7 +1032,7 @@ func genSnapFromSnapID(
} }
} }
if imageAttributes.KmsID != "" && imageAttributes.EncryptionType == util.EncryptionTypeFile { if imageAttributes.KmsID != "" && imageAttributes.EncryptionType == util.EncryptionTypeFile {
err = rbdSnap.configureFileEncryption(imageAttributes.KmsID, secrets) err = rbdSnap.configureFileEncryption(ctx, imageAttributes.KmsID, secrets)
if err != nil { if err != nil {
return fmt.Errorf("failed to configure file encryption for "+ return fmt.Errorf("failed to configure file encryption for "+
"%q: %w", rbdSnap, err) "%q: %w", rbdSnap, err)
@ -1133,7 +1133,7 @@ func generateVolumeFromVolumeID(
} }
} }
if imageAttributes.KmsID != "" && imageAttributes.EncryptionType == util.EncryptionTypeFile { if imageAttributes.KmsID != "" && imageAttributes.EncryptionType == util.EncryptionTypeFile {
err = rbdVol.configureFileEncryption(imageAttributes.KmsID, secrets) err = rbdVol.configureFileEncryption(ctx, imageAttributes.KmsID, secrets)
if err != nil { if err != nil {
return rbdVol, err return rbdVol, err
} }

View File

@ -189,12 +189,12 @@ func (ve *VolumeEncryption) Destroy() {
// RemoveDEK deletes the DEK for a particular volumeID from the DEKStore linked // RemoveDEK deletes the DEK for a particular volumeID from the DEKStore linked
// with this VolumeEncryption instance. // with this VolumeEncryption instance.
func (ve *VolumeEncryption) RemoveDEK(volumeID string) error { func (ve *VolumeEncryption) RemoveDEK(ctx context.Context, volumeID string) error {
if ve.dekStore == nil { if ve.dekStore == nil {
return ErrDEKStoreNotFound return ErrDEKStoreNotFound
} }
return ve.dekStore.RemoveDEK(volumeID) return ve.dekStore.RemoveDEK(ctx, volumeID)
} }
func (ve *VolumeEncryption) GetID() string { func (ve *VolumeEncryption) GetID() string {
@ -203,13 +203,13 @@ func (ve *VolumeEncryption) GetID() string {
// StoreCryptoPassphrase takes an unencrypted passphrase, encrypts it and saves // StoreCryptoPassphrase takes an unencrypted passphrase, encrypts it and saves
// it in the DEKStore. // it in the DEKStore.
func (ve *VolumeEncryption) StoreCryptoPassphrase(volumeID, passphrase string) error { func (ve *VolumeEncryption) StoreCryptoPassphrase(ctx context.Context, volumeID, passphrase string) error {
encryptedPassphrase, err := ve.KMS.EncryptDEK(volumeID, passphrase) encryptedPassphrase, err := ve.KMS.EncryptDEK(ctx, volumeID, passphrase)
if err != nil { if err != nil {
return fmt.Errorf("failed encrypt the passphrase for %s: %w", volumeID, err) return fmt.Errorf("failed encrypt the passphrase for %s: %w", volumeID, err)
} }
err = ve.dekStore.StoreDEK(volumeID, encryptedPassphrase) err = ve.dekStore.StoreDEK(ctx, volumeID, encryptedPassphrase)
if err != nil { if err != nil {
return fmt.Errorf("failed to save the passphrase for %s: %w", volumeID, err) return fmt.Errorf("failed to save the passphrase for %s: %w", volumeID, err)
} }
@ -218,23 +218,23 @@ func (ve *VolumeEncryption) StoreCryptoPassphrase(volumeID, passphrase string) e
} }
// StoreNewCryptoPassphrase generates a new passphrase and saves it in the KMS. // StoreNewCryptoPassphrase generates a new passphrase and saves it in the KMS.
func (ve *VolumeEncryption) StoreNewCryptoPassphrase(volumeID string, length int) error { func (ve *VolumeEncryption) StoreNewCryptoPassphrase(ctx context.Context, volumeID string, length int) error {
passphrase, err := generateNewEncryptionPassphrase(length) passphrase, err := generateNewEncryptionPassphrase(length)
if err != nil { if err != nil {
return fmt.Errorf("failed to generate passphrase for %s: %w", volumeID, err) return fmt.Errorf("failed to generate passphrase for %s: %w", volumeID, err)
} }
return ve.StoreCryptoPassphrase(volumeID, passphrase) return ve.StoreCryptoPassphrase(ctx, volumeID, passphrase)
} }
// GetCryptoPassphrase Retrieves passphrase to encrypt volume. // GetCryptoPassphrase Retrieves passphrase to encrypt volume.
func (ve *VolumeEncryption) GetCryptoPassphrase(volumeID string) (string, error) { func (ve *VolumeEncryption) GetCryptoPassphrase(ctx context.Context, volumeID string) (string, error) {
passphrase, err := ve.dekStore.FetchDEK(volumeID) passphrase, err := ve.dekStore.FetchDEK(ctx, volumeID)
if err != nil { if err != nil {
return "", err return "", err
} }
return ve.KMS.DecryptDEK(volumeID, passphrase) return ve.KMS.DecryptDEK(ctx, volumeID, passphrase)
} }
// generateNewEncryptionPassphrase generates a random passphrase for encryption. // generateNewEncryptionPassphrase generates a random passphrase for encryption.

View File

@ -17,6 +17,7 @@ limitations under the License.
package util package util
import ( import (
"context"
"encoding/base64" "encoding/base64"
"testing" "testing"
@ -55,11 +56,12 @@ func TestKMSWorkflow(t *testing.T) {
assert.Equal(t, kms.DefaultKMSType, ve.GetID()) assert.Equal(t, kms.DefaultKMSType, ve.GetID())
volumeID := "volume-id" volumeID := "volume-id"
ctx := context.TODO()
err = ve.StoreNewCryptoPassphrase(volumeID, defaultEncryptionPassphraseSize) err = ve.StoreNewCryptoPassphrase(ctx, volumeID, defaultEncryptionPassphraseSize)
assert.NoError(t, err) assert.NoError(t, err)
passphrase, err := ve.GetCryptoPassphrase(volumeID) passphrase, err := ve.GetCryptoPassphrase(ctx, volumeID)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, secrets["encryptionPassphrase"], passphrase) assert.Equal(t, secrets["encryptionPassphrase"], passphrase)
} }

View File

@ -76,14 +76,14 @@ func getPassphrase(ctx context.Context, encryption util.VolumeEncryption, volID
switch encryption.KMS.RequiresDEKStore() { switch encryption.KMS.RequiresDEKStore() {
case kms.DEKStoreIntegrated: case kms.DEKStoreIntegrated:
passphrase, err = encryption.GetCryptoPassphrase(volID) passphrase, err = encryption.GetCryptoPassphrase(ctx, volID)
if err != nil { if err != nil {
log.ErrorLog(ctx, "fscrypt: failed to get passphrase from KMS: %v", err) log.ErrorLog(ctx, "fscrypt: failed to get passphrase from KMS: %v", err)
return "", err return "", err
} }
case kms.DEKStoreMetadata: case kms.DEKStoreMetadata:
passphrase, err = encryption.KMS.GetSecret(volID) passphrase, err = encryption.KMS.GetSecret(ctx, volID)
if err != nil { if err != nil {
log.ErrorLog(ctx, "fscrypt: failed to GetSecret: %v", err) log.ErrorLog(ctx, "fscrypt: failed to GetSecret: %v", err)
@ -453,7 +453,7 @@ func Unlock(
if !kernelPolicyExists && !metadataDirExists { if !kernelPolicyExists && !metadataDirExists {
log.DebugLog(ctx, "fscrypt: Creating new protector and policy") log.DebugLog(ctx, "fscrypt: Creating new protector and policy")
if volEncryption.KMS.RequiresDEKStore() == kms.DEKStoreIntegrated { if volEncryption.KMS.RequiresDEKStore() == kms.DEKStoreIntegrated {
if err := volEncryption.StoreNewCryptoPassphrase(volID, encryptionPassphraseSize); err != nil { if err := volEncryption.StoreNewCryptoPassphrase(ctx, volID, encryptionPassphraseSize); err != nil {
log.ErrorLog(ctx, "fscrypt: store new crypto passphrase failed: %v", err) log.ErrorLog(ctx, "fscrypt: store new crypto passphrase failed: %v", err)
return err return err

View File

@ -14,6 +14,7 @@ limitations under the License.
package util package util
import ( import (
"context"
"errors" "errors"
"testing" "testing"
@ -34,7 +35,7 @@ func TestGetPassphraseFromKMS(t *testing.T) {
volEnc, err := NewVolumeEncryption(provider.UniqueID, kms) volEnc, err := NewVolumeEncryption(provider.UniqueID, kms)
if errors.Is(err, ErrDEKStoreNeeded) { if errors.Is(err, ErrDEKStoreNeeded) {
_, err = volEnc.KMS.GetSecret("") _, err = volEnc.KMS.GetSecret(context.TODO(), "")
if errors.Is(err, kmsapi.ErrGetSecretUnsupported) { if errors.Is(err, kmsapi.ErrGetSecretUnsupported) {
continue // currently unsupported by fscrypt integration continue // currently unsupported by fscrypt integration
} }
@ -45,7 +46,7 @@ func TestGetPassphraseFromKMS(t *testing.T) {
continue continue
} }
secret, err := kms.GetSecret("") secret, err := kms.GetSecret(context.TODO(), "")
assert.NoError(t, err, provider.UniqueID) assert.NoError(t, err, provider.UniqueID)
assert.NotEmpty(t, secret, provider.UniqueID) assert.NotEmpty(t, secret, provider.UniqueID)
} }

View File

@ -28,6 +28,8 @@ var (
uint32Type = reflect.TypeOf(uint32(1)) uint32Type = reflect.TypeOf(uint32(1))
uint64Type = reflect.TypeOf(uint64(1)) uint64Type = reflect.TypeOf(uint64(1))
uintptrType = reflect.TypeOf(uintptr(1))
float32Type = reflect.TypeOf(float32(1)) float32Type = reflect.TypeOf(float32(1))
float64Type = reflect.TypeOf(float64(1)) float64Type = reflect.TypeOf(float64(1))
@ -308,11 +310,11 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
case reflect.Struct: case reflect.Struct:
{ {
// All structs enter here. We're not interested in most types. // All structs enter here. We're not interested in most types.
if !canConvert(obj1Value, timeType) { if !obj1Value.CanConvert(timeType) {
break break
} }
// time.Time can compared! // time.Time can be compared!
timeObj1, ok := obj1.(time.Time) timeObj1, ok := obj1.(time.Time)
if !ok { if !ok {
timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time) timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time)
@ -328,7 +330,7 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
case reflect.Slice: case reflect.Slice:
{ {
// We only care about the []byte type. // We only care about the []byte type.
if !canConvert(obj1Value, bytesType) { if !obj1Value.CanConvert(bytesType) {
break break
} }
@ -345,6 +347,26 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
return CompareType(bytes.Compare(bytesObj1, bytesObj2)), true return CompareType(bytes.Compare(bytesObj1, bytesObj2)), true
} }
case reflect.Uintptr:
{
uintptrObj1, ok := obj1.(uintptr)
if !ok {
uintptrObj1 = obj1Value.Convert(uintptrType).Interface().(uintptr)
}
uintptrObj2, ok := obj2.(uintptr)
if !ok {
uintptrObj2 = obj2Value.Convert(uintptrType).Interface().(uintptr)
}
if uintptrObj1 > uintptrObj2 {
return compareGreater, true
}
if uintptrObj1 == uintptrObj2 {
return compareEqual, true
}
if uintptrObj1 < uintptrObj2 {
return compareLess, true
}
}
} }
return compareEqual, false return compareEqual, false

View File

@ -1,16 +0,0 @@
//go:build go1.17
// +build go1.17
// TODO: once support for Go 1.16 is dropped, this file can be
// merged/removed with assertion_compare_go1.17_test.go and
// assertion_compare_legacy.go
package assert
import "reflect"
// Wrapper around reflect.Value.CanConvert, for compatibility
// reasons.
func canConvert(value reflect.Value, to reflect.Type) bool {
return value.CanConvert(to)
}

View File

@ -1,16 +0,0 @@
//go:build !go1.17
// +build !go1.17
// TODO: once support for Go 1.16 is dropped, this file can be
// merged/removed with assertion_compare_go1.17_test.go and
// assertion_compare_can_convert.go
package assert
import "reflect"
// Older versions of Go does not have the reflect.Value.CanConvert
// method.
func canConvert(value reflect.Value, to reflect.Type) bool {
return false
}

View File

@ -1,7 +1,4 @@
/* // Code generated with github.com/stretchr/testify/_codegen; DO NOT EDIT.
* CODE GENERATED AUTOMATICALLY WITH github.com/stretchr/testify/_codegen
* THIS FILE MUST NOT BE EDITED BY HAND
*/
package assert package assert
@ -107,7 +104,7 @@ func EqualExportedValuesf(t TestingT, expected interface{}, actual interface{},
return EqualExportedValues(t, expected, actual, append([]interface{}{msg}, args...)...) return EqualExportedValues(t, expected, actual, append([]interface{}{msg}, args...)...)
} }
// EqualValuesf asserts that two objects are equal or convertable to the same types // EqualValuesf asserts that two objects are equal or convertible to the same types
// and equal. // and equal.
// //
// assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted") // assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted")
@ -616,6 +613,16 @@ func NotErrorIsf(t TestingT, err error, target error, msg string, args ...interf
return NotErrorIs(t, err, target, append([]interface{}{msg}, args...)...) return NotErrorIs(t, err, target, append([]interface{}{msg}, args...)...)
} }
// NotImplementsf asserts that an object does not implement the specified interface.
//
// assert.NotImplementsf(t, (*MyInterface)(nil), new(MyObject), "error message %s", "formatted")
func NotImplementsf(t TestingT, interfaceObject interface{}, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return NotImplements(t, interfaceObject, object, append([]interface{}{msg}, args...)...)
}
// NotNilf asserts that the specified object is not nil. // NotNilf asserts that the specified object is not nil.
// //
// assert.NotNilf(t, err, "error message %s", "formatted") // assert.NotNilf(t, err, "error message %s", "formatted")
@ -660,10 +667,12 @@ func NotSamef(t TestingT, expected interface{}, actual interface{}, msg string,
return NotSame(t, expected, actual, append([]interface{}{msg}, args...)...) return NotSame(t, expected, actual, append([]interface{}{msg}, args...)...)
} }
// NotSubsetf asserts that the specified list(array, slice...) contains not all // NotSubsetf asserts that the specified list(array, slice...) or map does NOT
// elements given in the specified subset(array, slice...). // contain all elements given in the specified subset list(array, slice...) or
// map.
// //
// assert.NotSubsetf(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]", "error message %s", "formatted") // assert.NotSubsetf(t, [1, 3, 4], [1, 2], "error message %s", "formatted")
// assert.NotSubsetf(t, {"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted")
func NotSubsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool { func NotSubsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
@ -747,10 +756,11 @@ func Samef(t TestingT, expected interface{}, actual interface{}, msg string, arg
return Same(t, expected, actual, append([]interface{}{msg}, args...)...) return Same(t, expected, actual, append([]interface{}{msg}, args...)...)
} }
// Subsetf asserts that the specified list(array, slice...) contains all // Subsetf asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset(array, slice...). // elements given in the specified subset list(array, slice...) or map.
// //
// assert.Subsetf(t, [1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]", "error message %s", "formatted") // assert.Subsetf(t, [1, 2, 3], [1, 2], "error message %s", "formatted")
// assert.Subsetf(t, {"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted")
func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool { func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()

View File

@ -1,7 +1,4 @@
/* // Code generated with github.com/stretchr/testify/_codegen; DO NOT EDIT.
* CODE GENERATED AUTOMATICALLY WITH github.com/stretchr/testify/_codegen
* THIS FILE MUST NOT BE EDITED BY HAND
*/
package assert package assert
@ -189,7 +186,7 @@ func (a *Assertions) EqualExportedValuesf(expected interface{}, actual interface
return EqualExportedValuesf(a.t, expected, actual, msg, args...) return EqualExportedValuesf(a.t, expected, actual, msg, args...)
} }
// EqualValues asserts that two objects are equal or convertable to the same types // EqualValues asserts that two objects are equal or convertible to the same types
// and equal. // and equal.
// //
// a.EqualValues(uint32(123), int32(123)) // a.EqualValues(uint32(123), int32(123))
@ -200,7 +197,7 @@ func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAn
return EqualValues(a.t, expected, actual, msgAndArgs...) return EqualValues(a.t, expected, actual, msgAndArgs...)
} }
// EqualValuesf asserts that two objects are equal or convertable to the same types // EqualValuesf asserts that two objects are equal or convertible to the same types
// and equal. // and equal.
// //
// a.EqualValuesf(uint32(123), int32(123), "error message %s", "formatted") // a.EqualValuesf(uint32(123), int32(123), "error message %s", "formatted")
@ -1221,6 +1218,26 @@ func (a *Assertions) NotErrorIsf(err error, target error, msg string, args ...in
return NotErrorIsf(a.t, err, target, msg, args...) return NotErrorIsf(a.t, err, target, msg, args...)
} }
// NotImplements asserts that an object does not implement the specified interface.
//
// a.NotImplements((*MyInterface)(nil), new(MyObject))
func (a *Assertions) NotImplements(interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotImplements(a.t, interfaceObject, object, msgAndArgs...)
}
// NotImplementsf asserts that an object does not implement the specified interface.
//
// a.NotImplementsf((*MyInterface)(nil), new(MyObject), "error message %s", "formatted")
func (a *Assertions) NotImplementsf(interfaceObject interface{}, object interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotImplementsf(a.t, interfaceObject, object, msg, args...)
}
// NotNil asserts that the specified object is not nil. // NotNil asserts that the specified object is not nil.
// //
// a.NotNil(err) // a.NotNil(err)
@ -1309,10 +1326,12 @@ func (a *Assertions) NotSamef(expected interface{}, actual interface{}, msg stri
return NotSamef(a.t, expected, actual, msg, args...) return NotSamef(a.t, expected, actual, msg, args...)
} }
// NotSubset asserts that the specified list(array, slice...) contains not all // NotSubset asserts that the specified list(array, slice...) or map does NOT
// elements given in the specified subset(array, slice...). // contain all elements given in the specified subset list(array, slice...) or
// map.
// //
// a.NotSubset([1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]") // a.NotSubset([1, 3, 4], [1, 2])
// a.NotSubset({"x": 1, "y": 2}, {"z": 3})
func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -1320,10 +1339,12 @@ func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs
return NotSubset(a.t, list, subset, msgAndArgs...) return NotSubset(a.t, list, subset, msgAndArgs...)
} }
// NotSubsetf asserts that the specified list(array, slice...) contains not all // NotSubsetf asserts that the specified list(array, slice...) or map does NOT
// elements given in the specified subset(array, slice...). // contain all elements given in the specified subset list(array, slice...) or
// map.
// //
// a.NotSubsetf([1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]", "error message %s", "formatted") // a.NotSubsetf([1, 3, 4], [1, 2], "error message %s", "formatted")
// a.NotSubsetf({"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted")
func (a *Assertions) NotSubsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool { func (a *Assertions) NotSubsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -1483,10 +1504,11 @@ func (a *Assertions) Samef(expected interface{}, actual interface{}, msg string,
return Samef(a.t, expected, actual, msg, args...) return Samef(a.t, expected, actual, msg, args...)
} }
// Subset asserts that the specified list(array, slice...) contains all // Subset asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset(array, slice...). // elements given in the specified subset list(array, slice...) or map.
// //
// a.Subset([1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]") // a.Subset([1, 2, 3], [1, 2])
// a.Subset({"x": 1, "y": 2}, {"x": 1})
func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -1494,10 +1516,11 @@ func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...
return Subset(a.t, list, subset, msgAndArgs...) return Subset(a.t, list, subset, msgAndArgs...)
} }
// Subsetf asserts that the specified list(array, slice...) contains all // Subsetf asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset(array, slice...). // elements given in the specified subset list(array, slice...) or map.
// //
// a.Subsetf([1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]", "error message %s", "formatted") // a.Subsetf([1, 2, 3], [1, 2], "error message %s", "formatted")
// a.Subsetf({"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted")
func (a *Assertions) Subsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool { func (a *Assertions) Subsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()

View File

@ -19,7 +19,7 @@ import (
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/pmezard/go-difflib/difflib" "github.com/pmezard/go-difflib/difflib"
yaml "gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
//go:generate sh -c "cd ../_codegen && go build && cd - && ../_codegen/_codegen -output-package=assert -template=assertion_format.go.tmpl" //go:generate sh -c "cd ../_codegen && go build && cd - && ../_codegen/_codegen -output-package=assert -template=assertion_format.go.tmpl"
@ -110,7 +110,12 @@ func copyExportedFields(expected interface{}) interface{} {
return result.Interface() return result.Interface()
case reflect.Array, reflect.Slice: case reflect.Array, reflect.Slice:
result := reflect.MakeSlice(expectedType, expectedValue.Len(), expectedValue.Len()) var result reflect.Value
if expectedKind == reflect.Array {
result = reflect.New(reflect.ArrayOf(expectedValue.Len(), expectedType.Elem())).Elem()
} else {
result = reflect.MakeSlice(expectedType, expectedValue.Len(), expectedValue.Len())
}
for i := 0; i < expectedValue.Len(); i++ { for i := 0; i < expectedValue.Len(); i++ {
index := expectedValue.Index(i) index := expectedValue.Index(i)
if isNil(index) { if isNil(index) {
@ -140,6 +145,8 @@ func copyExportedFields(expected interface{}) interface{} {
// structures. // structures.
// //
// This function does no assertion of any kind. // This function does no assertion of any kind.
//
// Deprecated: Use [EqualExportedValues] instead.
func ObjectsExportedFieldsAreEqual(expected, actual interface{}) bool { func ObjectsExportedFieldsAreEqual(expected, actual interface{}) bool {
expectedCleaned := copyExportedFields(expected) expectedCleaned := copyExportedFields(expected)
actualCleaned := copyExportedFields(actual) actualCleaned := copyExportedFields(actual)
@ -153,17 +160,40 @@ func ObjectsAreEqualValues(expected, actual interface{}) bool {
return true return true
} }
actualType := reflect.TypeOf(actual) expectedValue := reflect.ValueOf(expected)
if actualType == nil { actualValue := reflect.ValueOf(actual)
if !expectedValue.IsValid() || !actualValue.IsValid() {
return false return false
} }
expectedValue := reflect.ValueOf(expected)
if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) { expectedType := expectedValue.Type()
// Attempt comparison after type conversion actualType := actualValue.Type()
return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual) if !expectedType.ConvertibleTo(actualType) {
return false
} }
return false if !isNumericType(expectedType) || !isNumericType(actualType) {
// Attempt comparison after type conversion
return reflect.DeepEqual(
expectedValue.Convert(actualType).Interface(), actual,
)
}
// If BOTH values are numeric, there are chances of false positives due
// to overflow or underflow. So, we need to make sure to always convert
// the smaller type to a larger type before comparing.
if expectedType.Size() >= actualType.Size() {
return actualValue.Convert(expectedType).Interface() == expected
}
return expectedValue.Convert(actualType).Interface() == actual
}
// isNumericType returns true if the type is one of:
// int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64,
// float32, float64, complex64, complex128
func isNumericType(t reflect.Type) bool {
return t.Kind() >= reflect.Int && t.Kind() <= reflect.Complex128
} }
/* CallerInfo is necessary because the assert functions use the testing object /* CallerInfo is necessary because the assert functions use the testing object
@ -266,7 +296,7 @@ func messageFromMsgAndArgs(msgAndArgs ...interface{}) string {
// Aligns the provided message so that all lines after the first line start at the same location as the first line. // Aligns the provided message so that all lines after the first line start at the same location as the first line.
// Assumes that the first line starts at the correct location (after carriage return, tab, label, spacer and tab). // Assumes that the first line starts at the correct location (after carriage return, tab, label, spacer and tab).
// The longestLabelLen parameter specifies the length of the longest label in the output (required becaues this is the // The longestLabelLen parameter specifies the length of the longest label in the output (required because this is the
// basis on which the alignment occurs). // basis on which the alignment occurs).
func indentMessageLines(message string, longestLabelLen int) string { func indentMessageLines(message string, longestLabelLen int) string {
outBuf := new(bytes.Buffer) outBuf := new(bytes.Buffer)
@ -382,6 +412,25 @@ func Implements(t TestingT, interfaceObject interface{}, object interface{}, msg
return true return true
} }
// NotImplements asserts that an object does not implement the specified interface.
//
// assert.NotImplements(t, (*MyInterface)(nil), new(MyObject))
func NotImplements(t TestingT, interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
interfaceType := reflect.TypeOf(interfaceObject).Elem()
if object == nil {
return Fail(t, fmt.Sprintf("Cannot check if nil does not implement %v", interfaceType), msgAndArgs...)
}
if reflect.TypeOf(object).Implements(interfaceType) {
return Fail(t, fmt.Sprintf("%T implements %v", object, interfaceType), msgAndArgs...)
}
return true
}
// IsType asserts that the specified objects are of the same type. // IsType asserts that the specified objects are of the same type.
func IsType(t TestingT, expectedType interface{}, object interface{}, msgAndArgs ...interface{}) bool { func IsType(t TestingT, expectedType interface{}, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
@ -496,7 +545,7 @@ func samePointers(first, second interface{}) bool {
// representations appropriate to be presented to the user. // representations appropriate to be presented to the user.
// //
// If the values are not of like type, the returned strings will be prefixed // If the values are not of like type, the returned strings will be prefixed
// with the type name, and the value will be enclosed in parenthesis similar // with the type name, and the value will be enclosed in parentheses similar
// to a type conversion in the Go grammar. // to a type conversion in the Go grammar.
func formatUnequalValues(expected, actual interface{}) (e string, a string) { func formatUnequalValues(expected, actual interface{}) (e string, a string) {
if reflect.TypeOf(expected) != reflect.TypeOf(actual) { if reflect.TypeOf(expected) != reflect.TypeOf(actual) {
@ -523,7 +572,7 @@ func truncatingFormat(data interface{}) string {
return value return value
} }
// EqualValues asserts that two objects are equal or convertable to the same types // EqualValues asserts that two objects are equal or convertible to the same types
// and equal. // and equal.
// //
// assert.EqualValues(t, uint32(123), int32(123)) // assert.EqualValues(t, uint32(123), int32(123))
@ -566,12 +615,19 @@ func EqualExportedValues(t TestingT, expected, actual interface{}, msgAndArgs ..
return Fail(t, fmt.Sprintf("Types expected to match exactly\n\t%v != %v", aType, bType), msgAndArgs...) return Fail(t, fmt.Sprintf("Types expected to match exactly\n\t%v != %v", aType, bType), msgAndArgs...)
} }
if aType.Kind() == reflect.Ptr {
aType = aType.Elem()
}
if bType.Kind() == reflect.Ptr {
bType = bType.Elem()
}
if aType.Kind() != reflect.Struct { if aType.Kind() != reflect.Struct {
return Fail(t, fmt.Sprintf("Types expected to both be struct \n\t%v != %v", aType.Kind(), reflect.Struct), msgAndArgs...) return Fail(t, fmt.Sprintf("Types expected to both be struct or pointer to struct \n\t%v != %v", aType.Kind(), reflect.Struct), msgAndArgs...)
} }
if bType.Kind() != reflect.Struct { if bType.Kind() != reflect.Struct {
return Fail(t, fmt.Sprintf("Types expected to both be struct \n\t%v != %v", bType.Kind(), reflect.Struct), msgAndArgs...) return Fail(t, fmt.Sprintf("Types expected to both be struct or pointer to struct \n\t%v != %v", bType.Kind(), reflect.Struct), msgAndArgs...)
} }
expected = copyExportedFields(expected) expected = copyExportedFields(expected)
@ -620,17 +676,6 @@ func NotNil(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return Fail(t, "Expected value not to be nil.", msgAndArgs...) return Fail(t, "Expected value not to be nil.", msgAndArgs...)
} }
// containsKind checks if a specified kind in the slice of kinds.
func containsKind(kinds []reflect.Kind, kind reflect.Kind) bool {
for i := 0; i < len(kinds); i++ {
if kind == kinds[i] {
return true
}
}
return false
}
// isNil checks if a specified object is nil or not, without Failing. // isNil checks if a specified object is nil or not, without Failing.
func isNil(object interface{}) bool { func isNil(object interface{}) bool {
if object == nil { if object == nil {
@ -638,16 +683,13 @@ func isNil(object interface{}) bool {
} }
value := reflect.ValueOf(object) value := reflect.ValueOf(object)
kind := value.Kind() switch value.Kind() {
isNilableKind := containsKind( case
[]reflect.Kind{ reflect.Chan, reflect.Func,
reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
reflect.Ptr, reflect.Slice, reflect.UnsafePointer},
kind)
if isNilableKind && value.IsNil() { return value.IsNil()
return true
} }
return false return false
@ -731,16 +773,14 @@ func NotEmpty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
} }
// getLen try to get length of object. // getLen tries to get the length of an object.
// return (false, 0) if impossible. // It returns (0, false) if impossible.
func getLen(x interface{}) (ok bool, length int) { func getLen(x interface{}) (length int, ok bool) {
v := reflect.ValueOf(x) v := reflect.ValueOf(x)
defer func() { defer func() {
if e := recover(); e != nil { ok = recover() == nil
ok = false
}
}() }()
return true, v.Len() return v.Len(), true
} }
// Len asserts that the specified object has specific length. // Len asserts that the specified object has specific length.
@ -751,13 +791,13 @@ func Len(t TestingT, object interface{}, length int, msgAndArgs ...interface{})
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
} }
ok, l := getLen(object) l, ok := getLen(object)
if !ok { if !ok {
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", object), msgAndArgs...) return Fail(t, fmt.Sprintf("\"%v\" could not be applied builtin len()", object), msgAndArgs...)
} }
if l != length { if l != length {
return Fail(t, fmt.Sprintf("\"%s\" should have %d item(s), but has %d", object, length, l), msgAndArgs...) return Fail(t, fmt.Sprintf("\"%v\" should have %d item(s), but has %d", object, length, l), msgAndArgs...)
} }
return true return true
} }
@ -919,10 +959,11 @@ func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{})
} }
// Subset asserts that the specified list(array, slice...) contains all // Subset asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset(array, slice...). // elements given in the specified subset list(array, slice...) or map.
// //
// assert.Subset(t, [1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]") // assert.Subset(t, [1, 2, 3], [1, 2])
// assert.Subset(t, {"x": 1, "y": 2}, {"x": 1})
func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) { func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
@ -975,10 +1016,12 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok
return true return true
} }
// NotSubset asserts that the specified list(array, slice...) contains not all // NotSubset asserts that the specified list(array, slice...) or map does NOT
// elements given in the specified subset(array, slice...). // contain all elements given in the specified subset list(array, slice...) or
// map.
// //
// assert.NotSubset(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]") // assert.NotSubset(t, [1, 3, 4], [1, 2])
// assert.NotSubset(t, {"x": 1, "y": 2}, {"z": 3})
func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) { func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
@ -1439,7 +1482,7 @@ func InEpsilon(t TestingT, expected, actual interface{}, epsilon float64, msgAnd
h.Helper() h.Helper()
} }
if math.IsNaN(epsilon) { if math.IsNaN(epsilon) {
return Fail(t, "epsilon must not be NaN") return Fail(t, "epsilon must not be NaN", msgAndArgs...)
} }
actualEpsilon, err := calcRelativeError(expected, actual) actualEpsilon, err := calcRelativeError(expected, actual)
if err != nil { if err != nil {
@ -1458,19 +1501,26 @@ func InEpsilonSlice(t TestingT, expected, actual interface{}, epsilon float64, m
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
} }
if expected == nil || actual == nil ||
reflect.TypeOf(actual).Kind() != reflect.Slice || if expected == nil || actual == nil {
reflect.TypeOf(expected).Kind() != reflect.Slice {
return Fail(t, "Parameters must be slice", msgAndArgs...) return Fail(t, "Parameters must be slice", msgAndArgs...)
} }
actualSlice := reflect.ValueOf(actual)
expectedSlice := reflect.ValueOf(expected) expectedSlice := reflect.ValueOf(expected)
actualSlice := reflect.ValueOf(actual)
for i := 0; i < actualSlice.Len(); i++ { if expectedSlice.Type().Kind() != reflect.Slice {
result := InEpsilon(t, actualSlice.Index(i).Interface(), expectedSlice.Index(i).Interface(), epsilon) return Fail(t, "Expected value must be slice", msgAndArgs...)
if !result { }
return result
expectedLen := expectedSlice.Len()
if !IsType(t, expected, actual) || !Len(t, actual, expectedLen) {
return false
}
for i := 0; i < expectedLen; i++ {
if !InEpsilon(t, expectedSlice.Index(i).Interface(), actualSlice.Index(i).Interface(), epsilon, "at index %d", i) {
return false
} }
} }
@ -1870,23 +1920,18 @@ func (c *CollectT) Errorf(format string, args ...interface{}) {
} }
// FailNow panics. // FailNow panics.
func (c *CollectT) FailNow() { func (*CollectT) FailNow() {
panic("Assertion failed") panic("Assertion failed")
} }
// Reset clears the collected errors. // Deprecated: That was a method for internal usage that should not have been published. Now just panics.
func (c *CollectT) Reset() { func (*CollectT) Reset() {
c.errors = nil panic("Reset() is deprecated")
} }
// Copy copies the collected errors to the supplied t. // Deprecated: That was a method for internal usage that should not have been published. Now just panics.
func (c *CollectT) Copy(t TestingT) { func (*CollectT) Copy(TestingT) {
if tt, ok := t.(tHelper); ok { panic("Copy() is deprecated")
tt.Helper()
}
for _, err := range c.errors {
t.Errorf("%v", err)
}
} }
// EventuallyWithT asserts that given condition will be met in waitFor time, // EventuallyWithT asserts that given condition will be met in waitFor time,
@ -1912,8 +1957,8 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time
h.Helper() h.Helper()
} }
collect := new(CollectT) var lastFinishedTickErrs []error
ch := make(chan bool, 1) ch := make(chan []error, 1)
timer := time.NewTimer(waitFor) timer := time.NewTimer(waitFor)
defer timer.Stop() defer timer.Stop()
@ -1924,19 +1969,25 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time
for tick := ticker.C; ; { for tick := ticker.C; ; {
select { select {
case <-timer.C: case <-timer.C:
collect.Copy(t) for _, err := range lastFinishedTickErrs {
t.Errorf("%v", err)
}
return Fail(t, "Condition never satisfied", msgAndArgs...) return Fail(t, "Condition never satisfied", msgAndArgs...)
case <-tick: case <-tick:
tick = nil tick = nil
collect.Reset()
go func() { go func() {
collect := new(CollectT)
defer func() {
ch <- collect.errors
}()
condition(collect) condition(collect)
ch <- len(collect.errors) == 0
}() }()
case v := <-ch: case errs := <-ch:
if v { if len(errs) == 0 {
return true return true
} }
// Keep the errors from the last ended condition, so that they can be copied to t if timeout is reached.
lastFinishedTickErrs = errs
tick = ticker.C tick = ticker.C
} }
} }

View File

@ -12,7 +12,7 @@ import (
// an error if building a new request fails. // an error if building a new request fails.
func httpCode(handler http.HandlerFunc, method, url string, values url.Values) (int, error) { func httpCode(handler http.HandlerFunc, method, url string, values url.Values) (int, error) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
req, err := http.NewRequest(method, url, nil) req, err := http.NewRequest(method, url, http.NoBody)
if err != nil { if err != nil {
return -1, err return -1, err
} }
@ -32,12 +32,12 @@ func HTTPSuccess(t TestingT, handler http.HandlerFunc, method, url string, value
} }
code, err := httpCode(handler, method, url, values) code, err := httpCode(handler, method, url, values)
if err != nil { if err != nil {
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err)) Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err), msgAndArgs...)
} }
isSuccessCode := code >= http.StatusOK && code <= http.StatusPartialContent isSuccessCode := code >= http.StatusOK && code <= http.StatusPartialContent
if !isSuccessCode { if !isSuccessCode {
Fail(t, fmt.Sprintf("Expected HTTP success status code for %q but received %d", url+"?"+values.Encode(), code)) Fail(t, fmt.Sprintf("Expected HTTP success status code for %q but received %d", url+"?"+values.Encode(), code), msgAndArgs...)
} }
return isSuccessCode return isSuccessCode
@ -54,12 +54,12 @@ func HTTPRedirect(t TestingT, handler http.HandlerFunc, method, url string, valu
} }
code, err := httpCode(handler, method, url, values) code, err := httpCode(handler, method, url, values)
if err != nil { if err != nil {
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err)) Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err), msgAndArgs...)
} }
isRedirectCode := code >= http.StatusMultipleChoices && code <= http.StatusTemporaryRedirect isRedirectCode := code >= http.StatusMultipleChoices && code <= http.StatusTemporaryRedirect
if !isRedirectCode { if !isRedirectCode {
Fail(t, fmt.Sprintf("Expected HTTP redirect status code for %q but received %d", url+"?"+values.Encode(), code)) Fail(t, fmt.Sprintf("Expected HTTP redirect status code for %q but received %d", url+"?"+values.Encode(), code), msgAndArgs...)
} }
return isRedirectCode return isRedirectCode
@ -76,12 +76,12 @@ func HTTPError(t TestingT, handler http.HandlerFunc, method, url string, values
} }
code, err := httpCode(handler, method, url, values) code, err := httpCode(handler, method, url, values)
if err != nil { if err != nil {
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err)) Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err), msgAndArgs...)
} }
isErrorCode := code >= http.StatusBadRequest isErrorCode := code >= http.StatusBadRequest
if !isErrorCode { if !isErrorCode {
Fail(t, fmt.Sprintf("Expected HTTP error status code for %q but received %d", url+"?"+values.Encode(), code)) Fail(t, fmt.Sprintf("Expected HTTP error status code for %q but received %d", url+"?"+values.Encode(), code), msgAndArgs...)
} }
return isErrorCode return isErrorCode
@ -98,12 +98,12 @@ func HTTPStatusCode(t TestingT, handler http.HandlerFunc, method, url string, va
} }
code, err := httpCode(handler, method, url, values) code, err := httpCode(handler, method, url, values)
if err != nil { if err != nil {
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err)) Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err), msgAndArgs...)
} }
successful := code == statuscode successful := code == statuscode
if !successful { if !successful {
Fail(t, fmt.Sprintf("Expected HTTP status code %d for %q but received %d", statuscode, url+"?"+values.Encode(), code)) Fail(t, fmt.Sprintf("Expected HTTP status code %d for %q but received %d", statuscode, url+"?"+values.Encode(), code), msgAndArgs...)
} }
return successful return successful
@ -113,7 +113,10 @@ func HTTPStatusCode(t TestingT, handler http.HandlerFunc, method, url string, va
// empty string if building a new request fails. // empty string if building a new request fails.
func HTTPBody(handler http.HandlerFunc, method, url string, values url.Values) string { func HTTPBody(handler http.HandlerFunc, method, url string, values url.Values) string {
w := httptest.NewRecorder() w := httptest.NewRecorder()
req, err := http.NewRequest(method, url+"?"+values.Encode(), nil) if len(values) > 0 {
url += "?" + values.Encode()
}
req, err := http.NewRequest(method, url, http.NoBody)
if err != nil { if err != nil {
return "" return ""
} }
@ -135,7 +138,7 @@ func HTTPBodyContains(t TestingT, handler http.HandlerFunc, method, url string,
contains := strings.Contains(body, fmt.Sprint(str)) contains := strings.Contains(body, fmt.Sprint(str))
if !contains { if !contains {
Fail(t, fmt.Sprintf("Expected response body for \"%s\" to contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body)) Fail(t, fmt.Sprintf("Expected response body for \"%s\" to contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body), msgAndArgs...)
} }
return contains return contains
@ -155,7 +158,7 @@ func HTTPBodyNotContains(t TestingT, handler http.HandlerFunc, method, url strin
contains := strings.Contains(body, fmt.Sprint(str)) contains := strings.Contains(body, fmt.Sprint(str))
if contains { if contains {
Fail(t, fmt.Sprintf("Expected response body for \"%s\" to NOT contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body)) Fail(t, fmt.Sprintf("Expected response body for \"%s\" to NOT contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body), msgAndArgs...)
} }
return !contains return !contains

View File

@ -1,7 +1,4 @@
/* // Code generated with github.com/stretchr/testify/_codegen; DO NOT EDIT.
* CODE GENERATED AUTOMATICALLY WITH github.com/stretchr/testify/_codegen
* THIS FILE MUST NOT BE EDITED BY HAND
*/
package require package require
@ -235,7 +232,7 @@ func EqualExportedValuesf(t TestingT, expected interface{}, actual interface{},
t.FailNow() t.FailNow()
} }
// EqualValues asserts that two objects are equal or convertable to the same types // EqualValues asserts that two objects are equal or convertible to the same types
// and equal. // and equal.
// //
// assert.EqualValues(t, uint32(123), int32(123)) // assert.EqualValues(t, uint32(123), int32(123))
@ -249,7 +246,7 @@ func EqualValues(t TestingT, expected interface{}, actual interface{}, msgAndArg
t.FailNow() t.FailNow()
} }
// EqualValuesf asserts that two objects are equal or convertable to the same types // EqualValuesf asserts that two objects are equal or convertible to the same types
// and equal. // and equal.
// //
// assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted") // assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted")
@ -1546,6 +1543,32 @@ func NotErrorIsf(t TestingT, err error, target error, msg string, args ...interf
t.FailNow() t.FailNow()
} }
// NotImplements asserts that an object does not implement the specified interface.
//
// assert.NotImplements(t, (*MyInterface)(nil), new(MyObject))
func NotImplements(t TestingT, interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if assert.NotImplements(t, interfaceObject, object, msgAndArgs...) {
return
}
t.FailNow()
}
// NotImplementsf asserts that an object does not implement the specified interface.
//
// assert.NotImplementsf(t, (*MyInterface)(nil), new(MyObject), "error message %s", "formatted")
func NotImplementsf(t TestingT, interfaceObject interface{}, object interface{}, msg string, args ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if assert.NotImplementsf(t, interfaceObject, object, msg, args...) {
return
}
t.FailNow()
}
// NotNil asserts that the specified object is not nil. // NotNil asserts that the specified object is not nil.
// //
// assert.NotNil(t, err) // assert.NotNil(t, err)
@ -1658,10 +1681,12 @@ func NotSamef(t TestingT, expected interface{}, actual interface{}, msg string,
t.FailNow() t.FailNow()
} }
// NotSubset asserts that the specified list(array, slice...) contains not all // NotSubset asserts that the specified list(array, slice...) or map does NOT
// elements given in the specified subset(array, slice...). // contain all elements given in the specified subset list(array, slice...) or
// map.
// //
// assert.NotSubset(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]") // assert.NotSubset(t, [1, 3, 4], [1, 2])
// assert.NotSubset(t, {"x": 1, "y": 2}, {"z": 3})
func NotSubset(t TestingT, list interface{}, subset interface{}, msgAndArgs ...interface{}) { func NotSubset(t TestingT, list interface{}, subset interface{}, msgAndArgs ...interface{}) {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
@ -1672,10 +1697,12 @@ func NotSubset(t TestingT, list interface{}, subset interface{}, msgAndArgs ...i
t.FailNow() t.FailNow()
} }
// NotSubsetf asserts that the specified list(array, slice...) contains not all // NotSubsetf asserts that the specified list(array, slice...) or map does NOT
// elements given in the specified subset(array, slice...). // contain all elements given in the specified subset list(array, slice...) or
// map.
// //
// assert.NotSubsetf(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]", "error message %s", "formatted") // assert.NotSubsetf(t, [1, 3, 4], [1, 2], "error message %s", "formatted")
// assert.NotSubsetf(t, {"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted")
func NotSubsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) { func NotSubsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
@ -1880,10 +1907,11 @@ func Samef(t TestingT, expected interface{}, actual interface{}, msg string, arg
t.FailNow() t.FailNow()
} }
// Subset asserts that the specified list(array, slice...) contains all // Subset asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset(array, slice...). // elements given in the specified subset list(array, slice...) or map.
// //
// assert.Subset(t, [1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]") // assert.Subset(t, [1, 2, 3], [1, 2])
// assert.Subset(t, {"x": 1, "y": 2}, {"x": 1})
func Subset(t TestingT, list interface{}, subset interface{}, msgAndArgs ...interface{}) { func Subset(t TestingT, list interface{}, subset interface{}, msgAndArgs ...interface{}) {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
@ -1894,10 +1922,11 @@ func Subset(t TestingT, list interface{}, subset interface{}, msgAndArgs ...inte
t.FailNow() t.FailNow()
} }
// Subsetf asserts that the specified list(array, slice...) contains all // Subsetf asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset(array, slice...). // elements given in the specified subset list(array, slice...) or map.
// //
// assert.Subsetf(t, [1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]", "error message %s", "formatted") // assert.Subsetf(t, [1, 2, 3], [1, 2], "error message %s", "formatted")
// assert.Subsetf(t, {"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted")
func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) { func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()

View File

@ -1,7 +1,4 @@
/* // Code generated with github.com/stretchr/testify/_codegen; DO NOT EDIT.
* CODE GENERATED AUTOMATICALLY WITH github.com/stretchr/testify/_codegen
* THIS FILE MUST NOT BE EDITED BY HAND
*/
package require package require
@ -190,7 +187,7 @@ func (a *Assertions) EqualExportedValuesf(expected interface{}, actual interface
EqualExportedValuesf(a.t, expected, actual, msg, args...) EqualExportedValuesf(a.t, expected, actual, msg, args...)
} }
// EqualValues asserts that two objects are equal or convertable to the same types // EqualValues asserts that two objects are equal or convertible to the same types
// and equal. // and equal.
// //
// a.EqualValues(uint32(123), int32(123)) // a.EqualValues(uint32(123), int32(123))
@ -201,7 +198,7 @@ func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAn
EqualValues(a.t, expected, actual, msgAndArgs...) EqualValues(a.t, expected, actual, msgAndArgs...)
} }
// EqualValuesf asserts that two objects are equal or convertable to the same types // EqualValuesf asserts that two objects are equal or convertible to the same types
// and equal. // and equal.
// //
// a.EqualValuesf(uint32(123), int32(123), "error message %s", "formatted") // a.EqualValuesf(uint32(123), int32(123), "error message %s", "formatted")
@ -1222,6 +1219,26 @@ func (a *Assertions) NotErrorIsf(err error, target error, msg string, args ...in
NotErrorIsf(a.t, err, target, msg, args...) NotErrorIsf(a.t, err, target, msg, args...)
} }
// NotImplements asserts that an object does not implement the specified interface.
//
// a.NotImplements((*MyInterface)(nil), new(MyObject))
func (a *Assertions) NotImplements(interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotImplements(a.t, interfaceObject, object, msgAndArgs...)
}
// NotImplementsf asserts that an object does not implement the specified interface.
//
// a.NotImplementsf((*MyInterface)(nil), new(MyObject), "error message %s", "formatted")
func (a *Assertions) NotImplementsf(interfaceObject interface{}, object interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotImplementsf(a.t, interfaceObject, object, msg, args...)
}
// NotNil asserts that the specified object is not nil. // NotNil asserts that the specified object is not nil.
// //
// a.NotNil(err) // a.NotNil(err)
@ -1310,10 +1327,12 @@ func (a *Assertions) NotSamef(expected interface{}, actual interface{}, msg stri
NotSamef(a.t, expected, actual, msg, args...) NotSamef(a.t, expected, actual, msg, args...)
} }
// NotSubset asserts that the specified list(array, slice...) contains not all // NotSubset asserts that the specified list(array, slice...) or map does NOT
// elements given in the specified subset(array, slice...). // contain all elements given in the specified subset list(array, slice...) or
// map.
// //
// a.NotSubset([1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]") // a.NotSubset([1, 3, 4], [1, 2])
// a.NotSubset({"x": 1, "y": 2}, {"z": 3})
func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs ...interface{}) { func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -1321,10 +1340,12 @@ func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs
NotSubset(a.t, list, subset, msgAndArgs...) NotSubset(a.t, list, subset, msgAndArgs...)
} }
// NotSubsetf asserts that the specified list(array, slice...) contains not all // NotSubsetf asserts that the specified list(array, slice...) or map does NOT
// elements given in the specified subset(array, slice...). // contain all elements given in the specified subset list(array, slice...) or
// map.
// //
// a.NotSubsetf([1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]", "error message %s", "formatted") // a.NotSubsetf([1, 3, 4], [1, 2], "error message %s", "formatted")
// a.NotSubsetf({"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted")
func (a *Assertions) NotSubsetf(list interface{}, subset interface{}, msg string, args ...interface{}) { func (a *Assertions) NotSubsetf(list interface{}, subset interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -1484,10 +1505,11 @@ func (a *Assertions) Samef(expected interface{}, actual interface{}, msg string,
Samef(a.t, expected, actual, msg, args...) Samef(a.t, expected, actual, msg, args...)
} }
// Subset asserts that the specified list(array, slice...) contains all // Subset asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset(array, slice...). // elements given in the specified subset list(array, slice...) or map.
// //
// a.Subset([1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]") // a.Subset([1, 2, 3], [1, 2])
// a.Subset({"x": 1, "y": 2}, {"x": 1})
func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...interface{}) { func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -1495,10 +1517,11 @@ func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...
Subset(a.t, list, subset, msgAndArgs...) Subset(a.t, list, subset, msgAndArgs...)
} }
// Subsetf asserts that the specified list(array, slice...) contains all // Subsetf asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset(array, slice...). // elements given in the specified subset list(array, slice...) or map.
// //
// a.Subsetf([1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]", "error message %s", "formatted") // a.Subsetf([1, 2, 3], [1, 2], "error message %s", "formatted")
// a.Subsetf({"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted")
func (a *Assertions) Subsetf(list interface{}, subset interface{}, msg string, args ...interface{}) { func (a *Assertions) Subsetf(list interface{}, subset interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()

6
vendor/modules.txt vendored
View File

@ -539,8 +539,8 @@ github.com/spf13/pflag
# github.com/stoewer/go-strcase v1.2.0 # github.com/stoewer/go-strcase v1.2.0
## explicit; go 1.11 ## explicit; go 1.11
github.com/stoewer/go-strcase github.com/stoewer/go-strcase
# github.com/stretchr/testify v1.8.4 # github.com/stretchr/testify v1.9.0
## explicit; go 1.20 ## explicit; go 1.17
github.com/stretchr/testify/assert github.com/stretchr/testify/assert
github.com/stretchr/testify/require github.com/stretchr/testify/require
# go.etcd.io/etcd/api/v3 v3.5.10 # go.etcd.io/etcd/api/v3 v3.5.10
@ -1576,7 +1576,7 @@ k8s.io/kubernetes/test/utils/kubeconfig
# k8s.io/mount-utils v0.29.2 => k8s.io/mount-utils v0.29.2 # k8s.io/mount-utils v0.29.2 => k8s.io/mount-utils v0.29.2
## explicit; go 1.21 ## explicit; go 1.21
k8s.io/mount-utils k8s.io/mount-utils
# k8s.io/pod-security-admission v0.0.0 => k8s.io/pod-security-admission v0.29.2 # k8s.io/pod-security-admission v0.29.2 => k8s.io/pod-security-admission v0.29.2
## explicit; go 1.21 ## explicit; go 1.21
k8s.io/pod-security-admission/api k8s.io/pod-security-admission/api
k8s.io/pod-security-admission/policy k8s.io/pod-security-admission/policy