2021-11-08 20:26:59 +00:00
package matchers
import (
"fmt"
"reflect"
"strings"
"github.com/onsi/gomega/format"
)
2022-08-19 16:42:45 +00:00
// missingFieldError signals that a requested struct field or method does not
// exist at all. It is a distinct type so that HaveExistingFieldMatcher can
// recognize and ignore it, while still surfacing severe extraction failures
// such as nil pointers, non-struct values, or methods with an invalid
// signature.
type missingFieldError string

// Error implements the error interface by returning the stored message
// unchanged.
func (msg missingFieldError) Error() string {
	text := string(msg)
	return text
}
2024-11-04 20:10:45 +00:00
func extractField ( actual interface { } , field string , matchername string ) ( any , error ) {
2021-11-08 20:26:59 +00:00
fields := strings . SplitN ( field , "." , 2 )
actualValue := reflect . ValueOf ( actual )
2022-02-07 20:17:55 +00:00
if actualValue . Kind ( ) == reflect . Ptr {
actualValue = actualValue . Elem ( )
}
if actualValue == ( reflect . Value { } ) {
2022-08-19 16:42:45 +00:00
return nil , fmt . Errorf ( "%s encountered nil while dereferencing a pointer of type %T." , matchername , actual )
2022-02-07 20:17:55 +00:00
}
2021-11-08 20:26:59 +00:00
if actualValue . Kind ( ) != reflect . Struct {
2022-08-19 16:42:45 +00:00
return nil , fmt . Errorf ( "%s encountered:\n%s\nWhich is not a struct." , matchername , format . Object ( actual , 1 ) )
2021-11-08 20:26:59 +00:00
}
var extractedValue reflect . Value
if strings . HasSuffix ( fields [ 0 ] , "()" ) {
extractedValue = actualValue . MethodByName ( strings . TrimSuffix ( fields [ 0 ] , "()" ) )
2022-08-19 16:42:45 +00:00
if extractedValue == ( reflect . Value { } ) && actualValue . CanAddr ( ) {
extractedValue = actualValue . Addr ( ) . MethodByName ( strings . TrimSuffix ( fields [ 0 ] , "()" ) )
}
2021-11-08 20:26:59 +00:00
if extractedValue == ( reflect . Value { } ) {
2024-12-16 20:47:08 +00:00
ptr := reflect . New ( actualValue . Type ( ) )
ptr . Elem ( ) . Set ( actualValue )
extractedValue = ptr . MethodByName ( strings . TrimSuffix ( fields [ 0 ] , "()" ) )
if extractedValue == ( reflect . Value { } ) {
return nil , missingFieldError ( fmt . Sprintf ( "%s could not find method named '%s' in struct of type %T." , matchername , fields [ 0 ] , actual ) )
}
2021-11-08 20:26:59 +00:00
}
t := extractedValue . Type ( )
if t . NumIn ( ) != 0 || t . NumOut ( ) != 1 {
2022-08-19 16:42:45 +00:00
return nil , fmt . Errorf ( "%s found an invalid method named '%s' in struct of type %T.\nMethods must take no arguments and return exactly one value." , matchername , fields [ 0 ] , actual )
2021-11-08 20:26:59 +00:00
}
extractedValue = extractedValue . Call ( [ ] reflect . Value { } ) [ 0 ]
} else {
extractedValue = actualValue . FieldByName ( fields [ 0 ] )
if extractedValue == ( reflect . Value { } ) {
2022-08-19 16:42:45 +00:00
return nil , missingFieldError ( fmt . Sprintf ( "%s could not find field named '%s' in struct:\n%s" , matchername , fields [ 0 ] , format . Object ( actual , 1 ) ) )
2021-11-08 20:26:59 +00:00
}
}
if len ( fields ) == 1 {
return extractedValue . Interface ( ) , nil
} else {
2022-08-19 16:42:45 +00:00
return extractField ( extractedValue . Interface ( ) , fields [ 1 ] , matchername )
2021-11-08 20:26:59 +00:00
}
}
// HaveFieldMatcher succeeds when the value found at Field inside the actual
// struct satisfies Expected.
//
// Field is a dot-separated path. Each segment names a struct field or, with
// a trailing "()", a method taking no arguments and returning one value.
// Expected may itself be a matcher; any other value is wrapped in an
// EqualMatcher.
type HaveFieldMatcher struct {
	Field    string
	Expected any
}
2021-11-08 20:26:59 +00:00
2024-11-04 20:10:45 +00:00
func ( matcher * HaveFieldMatcher ) expectedMatcher ( ) omegaMatcher {
var isMatcher bool
expectedMatcher , isMatcher := matcher . Expected . ( omegaMatcher )
if ! isMatcher {
expectedMatcher = & EqualMatcher { Expected : matcher . Expected }
}
return expectedMatcher
2021-11-08 20:26:59 +00:00
}
func ( matcher * HaveFieldMatcher ) Match ( actual interface { } ) ( success bool , err error ) {
2024-11-04 20:10:45 +00:00
extractedField , err := extractField ( actual , matcher . Field , "HaveField" )
2021-11-08 20:26:59 +00:00
if err != nil {
return false , err
}
2024-11-04 20:10:45 +00:00
return matcher . expectedMatcher ( ) . Match ( extractedField )
2021-11-08 20:26:59 +00:00
}
func ( matcher * HaveFieldMatcher ) FailureMessage ( actual interface { } ) ( message string ) {
2024-11-04 20:10:45 +00:00
extractedField , err := extractField ( actual , matcher . Field , "HaveField" )
if err != nil {
// this really shouldn't happen
return fmt . Sprintf ( "Failed to extract field '%s': %s" , matcher . Field , err )
}
2021-11-08 20:26:59 +00:00
message = fmt . Sprintf ( "Value for field '%s' failed to satisfy matcher.\n" , matcher . Field )
2024-11-04 20:10:45 +00:00
message += matcher . expectedMatcher ( ) . FailureMessage ( extractedField )
2021-11-08 20:26:59 +00:00
return message
}
func ( matcher * HaveFieldMatcher ) NegatedFailureMessage ( actual interface { } ) ( message string ) {
2024-11-04 20:10:45 +00:00
extractedField , err := extractField ( actual , matcher . Field , "HaveField" )
if err != nil {
// this really shouldn't happen
return fmt . Sprintf ( "Failed to extract field '%s': %s" , matcher . Field , err )
}
2021-11-08 20:26:59 +00:00
message = fmt . Sprintf ( "Value for field '%s' satisfied matcher, but should not have.\n" , matcher . Field )
2024-11-04 20:10:45 +00:00
message += matcher . expectedMatcher ( ) . NegatedFailureMessage ( extractedField )
2021-11-08 20:26:59 +00:00
return message
}