2019-05-31 09:45:11 +00:00
|
|
|
package matchers
|
|
|
|
|
|
|
|
import (
|
2021-06-04 08:59:18 +00:00
|
|
|
"errors"
|
2019-05-31 09:45:11 +00:00
|
|
|
"fmt"
|
|
|
|
"reflect"
|
|
|
|
|
|
|
|
"github.com/onsi/gomega/format"
|
|
|
|
)
|
|
|
|
|
|
|
|
// MatchErrorMatcher compares an actual error against Expected.
type MatchErrorMatcher struct {
	// Expected is what the actual error is compared against. Match accepts
	// an error, a string-kinded value, a function taking a single
	// error-implementing argument and returning bool, or a matcher that can
	// match on strings.
	Expected any

	// FuncErrDescription describes the predicate function when Expected is
	// one. Match reports an error if it is empty in that case, and the
	// failure messages print its first element.
	FuncErrDescription []any

	// isFunc records whether the most recent Match call evaluated Expected
	// as a predicate function, so failure messages can describe it.
	isFunc bool
}
|
|
|
|
|
2023-10-26 09:28:31 +00:00
|
|
|
func (matcher *MatchErrorMatcher) Match(actual any) (success bool, err error) {
|
|
|
|
matcher.isFunc = false
|
|
|
|
|
2019-05-31 09:45:11 +00:00
|
|
|
if isNil(actual) {
|
|
|
|
return false, fmt.Errorf("Expected an error, got nil")
|
|
|
|
}
|
|
|
|
|
|
|
|
if !isError(actual) {
|
|
|
|
return false, fmt.Errorf("Expected an error. Got:\n%s", format.Object(actual, 1))
|
|
|
|
}
|
|
|
|
|
|
|
|
actualErr := actual.(error)
|
2020-01-14 10:38:55 +00:00
|
|
|
expected := matcher.Expected
|
2019-05-31 09:45:11 +00:00
|
|
|
|
2020-01-14 10:38:55 +00:00
|
|
|
if isError(expected) {
|
2023-01-30 20:05:01 +00:00
|
|
|
// first try the built-in errors.Is
|
|
|
|
if errors.Is(actualErr, expected.(error)) {
|
|
|
|
return true, nil
|
|
|
|
}
|
|
|
|
// if not, try DeepEqual along the error chain
|
|
|
|
for unwrapped := actualErr; unwrapped != nil; unwrapped = errors.Unwrap(unwrapped) {
|
|
|
|
if reflect.DeepEqual(unwrapped, expected) {
|
|
|
|
return true, nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return false, nil
|
2019-05-31 09:45:11 +00:00
|
|
|
}
|
|
|
|
|
2020-01-14 10:38:55 +00:00
|
|
|
if isString(expected) {
|
|
|
|
return actualErr.Error() == expected, nil
|
2019-05-31 09:45:11 +00:00
|
|
|
}
|
|
|
|
|
2023-10-26 09:28:31 +00:00
|
|
|
v := reflect.ValueOf(expected)
|
|
|
|
t := v.Type()
|
|
|
|
errorInterface := reflect.TypeOf((*error)(nil)).Elem()
|
|
|
|
if t.Kind() == reflect.Func && t.NumIn() == 1 && t.In(0).Implements(errorInterface) && t.NumOut() == 1 && t.Out(0).Kind() == reflect.Bool {
|
|
|
|
if len(matcher.FuncErrDescription) == 0 {
|
|
|
|
return false, fmt.Errorf("MatchError requires an additional description when passed a function")
|
|
|
|
}
|
|
|
|
matcher.isFunc = true
|
|
|
|
return v.Call([]reflect.Value{reflect.ValueOf(actualErr)})[0].Bool(), nil
|
|
|
|
}
|
|
|
|
|
2019-05-31 09:45:11 +00:00
|
|
|
var subMatcher omegaMatcher
|
|
|
|
var hasSubMatcher bool
|
2020-01-14 10:38:55 +00:00
|
|
|
if expected != nil {
|
|
|
|
subMatcher, hasSubMatcher = (expected).(omegaMatcher)
|
2019-05-31 09:45:11 +00:00
|
|
|
if hasSubMatcher {
|
|
|
|
return subMatcher.Match(actualErr.Error())
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-01-14 10:38:55 +00:00
|
|
|
return false, fmt.Errorf(
|
|
|
|
"MatchError must be passed an error, a string, or a Matcher that can match on strings. Got:\n%s",
|
|
|
|
format.Object(expected, 1))
|
2019-05-31 09:45:11 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
func (matcher *MatchErrorMatcher) FailureMessage(actual interface{}) (message string) {
|
2023-10-26 09:28:31 +00:00
|
|
|
if matcher.isFunc {
|
|
|
|
return format.Message(actual, fmt.Sprintf("to match error function %s", matcher.FuncErrDescription[0]))
|
|
|
|
}
|
2019-05-31 09:45:11 +00:00
|
|
|
return format.Message(actual, "to match error", matcher.Expected)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (matcher *MatchErrorMatcher) NegatedFailureMessage(actual interface{}) (message string) {
|
2023-10-26 09:28:31 +00:00
|
|
|
if matcher.isFunc {
|
|
|
|
return format.Message(actual, fmt.Sprintf("not to match error function %s", matcher.FuncErrDescription[0]))
|
|
|
|
}
|
2019-05-31 09:45:11 +00:00
|
|
|
return format.Message(actual, "not to match error", matcher.Expected)
|
|
|
|
}
|