rebase: update kubernetes to latest

updating the kubernetes release to the
latest in main go.mod

Signed-off-by: Madhu Rajanna <madhupr007@gmail.com>
This commit is contained in:
Madhu Rajanna
2024-08-19 10:01:33 +02:00
committed by mergify[bot]
parent 63c4c05b35
commit 5a66991bb3
2173 changed files with 98906 additions and 61334 deletions

View File

@ -7,7 +7,9 @@ package(
go_library(
name = "go_default_library",
srcs = [
"bindings.go",
"encoders.go",
"formatting.go",
"guards.go",
"lists.go",
"math.go",
@ -21,14 +23,14 @@ go_library(
deps = [
"//cel:go_default_library",
"//checker:go_default_library",
"//checker/decls:go_default_library",
"//common/ast:go_default_library",
"//common/overloads:go_default_library",
"//common/operators:go_default_library",
"//common/types:go_default_library",
"//common/types/pb:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"//interpreter:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//reflect/protoreflect:go_default_library",
"@org_golang_google_protobuf//types/known/structpb",
@ -61,7 +63,6 @@ go_test(
"//test:go_default_library",
"//test/proto2pb:go_default_library",
"//test/proto3pb:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//types/known/wrapperspb:go_default_library",
"@org_golang_google_protobuf//encoding/protojson:go_default_library",

View File

@ -414,3 +414,17 @@ Examples:
'TacoCat'.upperAscii() // returns 'TACOCAT'
'TacoCÆt Xii'.upperAscii() // returns 'TACOCÆT XII'
### Reverse
Returns a new string whose characters are the same as the target string, only formatted in
reverse order.
This function relies on converting strings to rune arrays in order to reverse.
It can be located in Version 3 of strings.
<string>.reverse() -> <string>
Examples:
'gums'.reverse() // returns 'smug'
'John Smith'.reverse() // returns 'htimS nhoJ'

View File

@ -16,8 +16,8 @@ package ext
import (
"github.com/google/cel-go/cel"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
)
// Bindings returns a cel.EnvOption to configure support for local variable
@ -61,7 +61,7 @@ func (celBindings) CompileOptions() []cel.EnvOption {
return []cel.EnvOption{
cel.Macros(
// cel.bind(var, <init>, <expr>)
cel.NewReceiverMacro(bindMacro, 3, celBind),
cel.ReceiverMacro(bindMacro, 3, celBind),
),
}
}
@ -70,27 +70,27 @@ func (celBindings) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
func celBind(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) {
func celBind(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
if !macroTargetMatchesNamespace(celNamespace, target) {
return nil, nil
}
varIdent := args[0]
varName := ""
switch varIdent.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
varName = varIdent.GetIdentExpr().GetName()
switch varIdent.Kind() {
case ast.IdentKind:
varName = varIdent.AsIdent()
default:
return nil, meh.NewError(varIdent.GetId(), "cel.bind() variable names must be simple identifiers")
return nil, mef.NewError(varIdent.ID(), "cel.bind() variable names must be simple identifiers")
}
varInit := args[1]
resultExpr := args[2]
return meh.Fold(
return mef.NewComprehension(
mef.NewList(),
unusedIterVar,
meh.NewList(),
varName,
varInit,
meh.LiteralBool(false),
meh.Ident(varName),
mef.NewLiteral(types.False),
mef.NewIdent(varName),
resultExpr,
), nil
}

904
vendor/github.com/google/cel-go/ext/formatting.go generated vendored Normal file
View File

@ -0,0 +1,904 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ext
import (
"errors"
"fmt"
"math"
"sort"
"strconv"
"strings"
"unicode"
"golang.org/x/text/language"
"golang.org/x/text/message"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
)
type clauseImpl func(ref.Val, string) (string, error)
func clauseForType(argType ref.Type) (clauseImpl, error) {
switch argType {
case types.IntType, types.UintType:
return formatDecimal, nil
case types.StringType, types.BytesType, types.BoolType, types.NullType, types.TypeType:
return FormatString, nil
case types.TimestampType, types.DurationType:
// special case to ensure timestamps/durations get printed as CEL literals
return func(arg ref.Val, locale string) (string, error) {
argStrVal := arg.ConvertToType(types.StringType)
argStr := argStrVal.Value().(string)
if arg.Type() == types.TimestampType {
return fmt.Sprintf("timestamp(%q)", argStr), nil
}
if arg.Type() == types.DurationType {
return fmt.Sprintf("duration(%q)", argStr), nil
}
return "", fmt.Errorf("cannot convert argument of type %s to timestamp/duration", arg.Type().TypeName())
}, nil
case types.ListType:
return formatList, nil
case types.MapType:
return formatMap, nil
case types.DoubleType:
// avoid formatFixed so we can output a period as the decimal separator in order
// to always be a valid CEL literal
return func(arg ref.Val, locale string) (string, error) {
argDouble, ok := arg.Value().(float64)
if !ok {
return "", fmt.Errorf("couldn't convert %s to float64", arg.Type().TypeName())
}
fmtStr := fmt.Sprintf("%%.%df", defaultPrecision)
return fmt.Sprintf(fmtStr, argDouble), nil
}, nil
case types.TypeType:
return func(arg ref.Val, locale string) (string, error) {
return fmt.Sprintf("type(%s)", arg.Value().(string)), nil
}, nil
default:
return nil, fmt.Errorf("no formatting function for %s", argType.TypeName())
}
}
func formatList(arg ref.Val, locale string) (string, error) {
argList := arg.(traits.Lister)
argIterator := argList.Iterator()
var listStrBuilder strings.Builder
_, err := listStrBuilder.WriteRune('[')
if err != nil {
return "", fmt.Errorf("error writing to list string: %w", err)
}
for argIterator.HasNext() == types.True {
member := argIterator.Next()
memberFormat, err := clauseForType(member.Type())
if err != nil {
return "", err
}
unquotedStr, err := memberFormat(member, locale)
if err != nil {
return "", err
}
str := quoteForCEL(member, unquotedStr)
_, err = listStrBuilder.WriteString(str)
if err != nil {
return "", fmt.Errorf("error writing to list string: %w", err)
}
if argIterator.HasNext() == types.True {
_, err = listStrBuilder.WriteString(", ")
if err != nil {
return "", fmt.Errorf("error writing to list string: %w", err)
}
}
}
_, err = listStrBuilder.WriteRune(']')
if err != nil {
return "", fmt.Errorf("error writing to list string: %w", err)
}
return listStrBuilder.String(), nil
}
func formatMap(arg ref.Val, locale string) (string, error) {
argMap := arg.(traits.Mapper)
argIterator := argMap.Iterator()
type mapPair struct {
key string
value string
}
argPairs := make([]mapPair, argMap.Size().Value().(int64))
i := 0
for argIterator.HasNext() == types.True {
key := argIterator.Next()
var keyFormat clauseImpl
switch key.Type() {
case types.StringType, types.BoolType:
keyFormat = FormatString
case types.IntType, types.UintType:
keyFormat = formatDecimal
default:
return "", fmt.Errorf("no formatting function for map key of type %s", key.Type().TypeName())
}
unquotedKeyStr, err := keyFormat(key, locale)
if err != nil {
return "", err
}
keyStr := quoteForCEL(key, unquotedKeyStr)
value, found := argMap.Find(key)
if !found {
return "", fmt.Errorf("could not find key: %q", key)
}
valueFormat, err := clauseForType(value.Type())
if err != nil {
return "", err
}
unquotedValueStr, err := valueFormat(value, locale)
if err != nil {
return "", err
}
valueStr := quoteForCEL(value, unquotedValueStr)
argPairs[i] = mapPair{keyStr, valueStr}
i++
}
sort.SliceStable(argPairs, func(x, y int) bool {
return argPairs[x].key < argPairs[y].key
})
var mapStrBuilder strings.Builder
_, err := mapStrBuilder.WriteRune('{')
if err != nil {
return "", fmt.Errorf("error writing to map string: %w", err)
}
for i, entry := range argPairs {
_, err = mapStrBuilder.WriteString(fmt.Sprintf("%s:%s", entry.key, entry.value))
if err != nil {
return "", fmt.Errorf("error writing to map string: %w", err)
}
if i < len(argPairs)-1 {
_, err = mapStrBuilder.WriteString(", ")
if err != nil {
return "", fmt.Errorf("error writing to map string: %w", err)
}
}
}
_, err = mapStrBuilder.WriteRune('}')
if err != nil {
return "", fmt.Errorf("error writing to map string: %w", err)
}
return mapStrBuilder.String(), nil
}
// quoteForCEL takes a formatted, unquoted value and quotes it in a manner suitable
// for embedding directly in CEL.
func quoteForCEL(refVal ref.Val, unquotedValue string) string {
switch refVal.Type() {
case types.StringType:
return fmt.Sprintf("%q", unquotedValue)
case types.BytesType:
return fmt.Sprintf("b%q", unquotedValue)
case types.DoubleType:
// special case to handle infinity/NaN
num := refVal.Value().(float64)
if math.IsInf(num, 1) || math.IsInf(num, -1) || math.IsNaN(num) {
return fmt.Sprintf("%q", unquotedValue)
}
return unquotedValue
default:
return unquotedValue
}
}
// FormatString returns the string representation of a CEL value.
//
// It is used to implement the %s specifier in the (string).format() extension function.
func FormatString(arg ref.Val, locale string) (string, error) {
switch arg.Type() {
case types.ListType:
return formatList(arg, locale)
case types.MapType:
return formatMap(arg, locale)
case types.IntType, types.UintType, types.DoubleType,
types.BoolType, types.StringType, types.TimestampType, types.BytesType, types.DurationType, types.TypeType:
argStrVal := arg.ConvertToType(types.StringType)
argStr, ok := argStrVal.Value().(string)
if !ok {
return "", fmt.Errorf("could not convert argument %q to string", argStrVal)
}
return argStr, nil
case types.NullType:
return "null", nil
default:
return "", stringFormatError(runtimeID, arg.Type().TypeName())
}
}
func formatDecimal(arg ref.Val, locale string) (string, error) {
switch arg.Type() {
case types.IntType:
argInt, ok := arg.ConvertToType(types.IntType).Value().(int64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to int64", arg.Value())
}
return fmt.Sprintf("%d", argInt), nil
case types.UintType:
argInt, ok := arg.ConvertToType(types.UintType).Value().(uint64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to uint64", arg.Value())
}
return fmt.Sprintf("%d", argInt), nil
default:
return "", decimalFormatError(runtimeID, arg.Type().TypeName())
}
}
func matchLanguage(locale string) (language.Tag, error) {
matcher, err := makeMatcher(locale)
if err != nil {
return language.Und, err
}
tag, _ := language.MatchStrings(matcher, locale)
return tag, nil
}
func makeMatcher(locale string) (language.Matcher, error) {
tags := make([]language.Tag, 0)
tag, err := language.Parse(locale)
if err != nil {
return nil, err
}
tags = append(tags, tag)
return language.NewMatcher(tags), nil
}
type stringFormatter struct{}
func (c *stringFormatter) String(arg ref.Val, locale string) (string, error) {
return FormatString(arg, locale)
}
func (c *stringFormatter) Decimal(arg ref.Val, locale string) (string, error) {
return formatDecimal(arg, locale)
}
func (c *stringFormatter) Fixed(precision *int) func(ref.Val, string) (string, error) {
if precision == nil {
precision = new(int)
*precision = defaultPrecision
}
return func(arg ref.Val, locale string) (string, error) {
strException := false
if arg.Type() == types.StringType {
argStr := arg.Value().(string)
if argStr == "NaN" || argStr == "Infinity" || argStr == "-Infinity" {
strException = true
}
}
if arg.Type() != types.DoubleType && !strException {
return "", fixedPointFormatError(runtimeID, arg.Type().TypeName())
}
argFloatVal := arg.ConvertToType(types.DoubleType)
argFloat, ok := argFloatVal.Value().(float64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to float64", argFloatVal.Value())
}
fmtStr := fmt.Sprintf("%%.%df", *precision)
matchedLocale, err := matchLanguage(locale)
if err != nil {
return "", fmt.Errorf("error matching locale: %w", err)
}
return message.NewPrinter(matchedLocale).Sprintf(fmtStr, argFloat), nil
}
}
func (c *stringFormatter) Scientific(precision *int) func(ref.Val, string) (string, error) {
if precision == nil {
precision = new(int)
*precision = defaultPrecision
}
return func(arg ref.Val, locale string) (string, error) {
strException := false
if arg.Type() == types.StringType {
argStr := arg.Value().(string)
if argStr == "NaN" || argStr == "Infinity" || argStr == "-Infinity" {
strException = true
}
}
if arg.Type() != types.DoubleType && !strException {
return "", scientificFormatError(runtimeID, arg.Type().TypeName())
}
argFloatVal := arg.ConvertToType(types.DoubleType)
argFloat, ok := argFloatVal.Value().(float64)
if !ok {
return "", fmt.Errorf("could not convert \"%v\" to float64", argFloatVal.Value())
}
matchedLocale, err := matchLanguage(locale)
if err != nil {
return "", fmt.Errorf("error matching locale: %w", err)
}
fmtStr := fmt.Sprintf("%%%de", *precision)
return message.NewPrinter(matchedLocale).Sprintf(fmtStr, argFloat), nil
}
}
func (c *stringFormatter) Binary(arg ref.Val, locale string) (string, error) {
switch arg.Type() {
case types.IntType:
argInt := arg.Value().(int64)
// locale is intentionally unused as integers formatted as binary
// strings are locale-independent
return fmt.Sprintf("%b", argInt), nil
case types.UintType:
argInt := arg.Value().(uint64)
return fmt.Sprintf("%b", argInt), nil
case types.BoolType:
argBool := arg.Value().(bool)
if argBool {
return "1", nil
}
return "0", nil
default:
return "", binaryFormatError(runtimeID, arg.Type().TypeName())
}
}
func (c *stringFormatter) Hex(useUpper bool) func(ref.Val, string) (string, error) {
return func(arg ref.Val, locale string) (string, error) {
fmtStr := "%x"
if useUpper {
fmtStr = "%X"
}
switch arg.Type() {
case types.StringType, types.BytesType:
if arg.Type() == types.BytesType {
return fmt.Sprintf(fmtStr, arg.Value().([]byte)), nil
}
return fmt.Sprintf(fmtStr, arg.Value().(string)), nil
case types.IntType:
argInt, ok := arg.Value().(int64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to int64", arg.Value())
}
return fmt.Sprintf(fmtStr, argInt), nil
case types.UintType:
argInt, ok := arg.Value().(uint64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to uint64", arg.Value())
}
return fmt.Sprintf(fmtStr, argInt), nil
default:
return "", hexFormatError(runtimeID, arg.Type().TypeName())
}
}
}
func (c *stringFormatter) Octal(arg ref.Val, locale string) (string, error) {
switch arg.Type() {
case types.IntType:
argInt := arg.Value().(int64)
return fmt.Sprintf("%o", argInt), nil
case types.UintType:
argInt := arg.Value().(uint64)
return fmt.Sprintf("%o", argInt), nil
default:
return "", octalFormatError(runtimeID, arg.Type().TypeName())
}
}
// stringFormatValidator implements the cel.ASTValidator interface allowing for static validation
// of string.format calls.
type stringFormatValidator struct{}
// Name returns the name of the validator.
func (stringFormatValidator) Name() string {
return "cel.lib.ext.validate.functions.string.format"
}
// Configure implements the ASTValidatorConfigurer interface and augments the list of functions to skip
// during homogeneous aggregate literal type-checks.
func (stringFormatValidator) Configure(config cel.MutableValidatorConfig) error {
functions := config.GetOrDefault(cel.HomogeneousAggregateLiteralExemptFunctions, []string{}).([]string)
functions = append(functions, "format")
return config.Set(cel.HomogeneousAggregateLiteralExemptFunctions, functions)
}
// Validate parses all literal format strings and type checks the format clause against the argument
// at the corresponding ordinal within the list literal argument to the function, if one is specified.
func (stringFormatValidator) Validate(env *cel.Env, _ cel.ValidatorConfig, a *ast.AST, iss *cel.Issues) {
root := ast.NavigateAST(a)
formatCallExprs := ast.MatchDescendants(root, matchConstantFormatStringWithListLiteralArgs(a))
for _, e := range formatCallExprs {
call := e.AsCall()
formatStr := call.Target().AsLiteral().Value().(string)
args := call.Args()[0].AsList().Elements()
formatCheck := &stringFormatChecker{
args: args,
ast: a,
}
// use a placeholder locale, since locale doesn't affect syntax
_, err := parseFormatString(formatStr, formatCheck, formatCheck, "en_US")
if err != nil {
iss.ReportErrorAtID(getErrorExprID(e.ID(), err), err.Error())
continue
}
seenArgs := formatCheck.argsRequested
if len(args) > seenArgs {
iss.ReportErrorAtID(e.ID(),
"too many arguments supplied to string.format (expected %d, got %d)", seenArgs, len(args))
}
}
}
// getErrorExprID determines which list literal argument triggered a type-disagreement for the
// purposes of more accurate error message reports.
func getErrorExprID(id int64, err error) int64 {
fmtErr, ok := err.(formatError)
if ok {
return fmtErr.id
}
wrapped := errors.Unwrap(err)
if wrapped != nil {
return getErrorExprID(id, wrapped)
}
return id
}
// matchConstantFormatStringWithListLiteralArgs matches all valid expression nodes for string
// format checking.
func matchConstantFormatStringWithListLiteralArgs(a *ast.AST) ast.ExprMatcher {
return func(e ast.NavigableExpr) bool {
if e.Kind() != ast.CallKind {
return false
}
call := e.AsCall()
if !call.IsMemberFunction() || call.FunctionName() != "format" {
return false
}
overloadIDs := a.GetOverloadIDs(e.ID())
if len(overloadIDs) != 0 {
found := false
for _, overload := range overloadIDs {
if overload == overloads.ExtFormatString {
found = true
break
}
}
if !found {
return false
}
}
formatString := call.Target()
if formatString.Kind() != ast.LiteralKind && formatString.AsLiteral().Type() != cel.StringType {
return false
}
args := call.Args()
if len(args) != 1 {
return false
}
formatArgs := args[0]
return formatArgs.Kind() == ast.ListKind
}
}
// stringFormatChecker implements the formatStringInterpolater interface
type stringFormatChecker struct {
args []ast.Expr
argsRequested int
currArgIndex int64
ast *ast.AST
}
func (c *stringFormatChecker) String(arg ref.Val, locale string) (string, error) {
formatArg := c.args[c.currArgIndex]
valid, badID := c.verifyString(formatArg)
if !valid {
return "", stringFormatError(badID, c.typeOf(badID).TypeName())
}
return "", nil
}
func (c *stringFormatChecker) Decimal(arg ref.Val, locale string) (string, error) {
id := c.args[c.currArgIndex].ID()
valid := c.verifyTypeOneOf(id, types.IntType, types.UintType)
if !valid {
return "", decimalFormatError(id, c.typeOf(id).TypeName())
}
return "", nil
}
func (c *stringFormatChecker) Fixed(precision *int) func(ref.Val, string) (string, error) {
return func(arg ref.Val, locale string) (string, error) {
id := c.args[c.currArgIndex].ID()
// we allow StringType since "NaN", "Infinity", and "-Infinity" are also valid values
valid := c.verifyTypeOneOf(id, types.DoubleType, types.StringType)
if !valid {
return "", fixedPointFormatError(id, c.typeOf(id).TypeName())
}
return "", nil
}
}
func (c *stringFormatChecker) Scientific(precision *int) func(ref.Val, string) (string, error) {
return func(arg ref.Val, locale string) (string, error) {
id := c.args[c.currArgIndex].ID()
valid := c.verifyTypeOneOf(id, types.DoubleType, types.StringType)
if !valid {
return "", scientificFormatError(id, c.typeOf(id).TypeName())
}
return "", nil
}
}
func (c *stringFormatChecker) Binary(arg ref.Val, locale string) (string, error) {
id := c.args[c.currArgIndex].ID()
valid := c.verifyTypeOneOf(id, types.IntType, types.UintType, types.BoolType)
if !valid {
return "", binaryFormatError(id, c.typeOf(id).TypeName())
}
return "", nil
}
func (c *stringFormatChecker) Hex(useUpper bool) func(ref.Val, string) (string, error) {
return func(arg ref.Val, locale string) (string, error) {
id := c.args[c.currArgIndex].ID()
valid := c.verifyTypeOneOf(id, types.IntType, types.UintType, types.StringType, types.BytesType)
if !valid {
return "", hexFormatError(id, c.typeOf(id).TypeName())
}
return "", nil
}
}
func (c *stringFormatChecker) Octal(arg ref.Val, locale string) (string, error) {
id := c.args[c.currArgIndex].ID()
valid := c.verifyTypeOneOf(id, types.IntType, types.UintType)
if !valid {
return "", octalFormatError(id, c.typeOf(id).TypeName())
}
return "", nil
}
func (c *stringFormatChecker) Arg(index int64) (ref.Val, error) {
c.argsRequested++
c.currArgIndex = index
// return a dummy value - this is immediately passed to back to us
// through one of the FormatCallback functions, so anything will do
return types.Int(0), nil
}
func (c *stringFormatChecker) Size() int64 {
return int64(len(c.args))
}
func (c *stringFormatChecker) typeOf(id int64) *cel.Type {
return c.ast.GetType(id)
}
func (c *stringFormatChecker) verifyTypeOneOf(id int64, validTypes ...*cel.Type) bool {
t := c.typeOf(id)
if t == cel.DynType {
return true
}
for _, vt := range validTypes {
// Only check runtime type compatibility without delving deeper into parameterized types
if t.Kind() == vt.Kind() {
return true
}
}
return false
}
func (c *stringFormatChecker) verifyString(sub ast.Expr) (bool, int64) {
paramA := cel.TypeParamType("A")
paramB := cel.TypeParamType("B")
subVerified := c.verifyTypeOneOf(sub.ID(),
cel.ListType(paramA), cel.MapType(paramA, paramB),
cel.IntType, cel.UintType, cel.DoubleType, cel.BoolType, cel.StringType,
cel.TimestampType, cel.BytesType, cel.DurationType, cel.TypeType, cel.NullType)
if !subVerified {
return false, sub.ID()
}
switch sub.Kind() {
case ast.ListKind:
for _, e := range sub.AsList().Elements() {
// recursively verify if we're dealing with a list/map
verified, id := c.verifyString(e)
if !verified {
return false, id
}
}
return true, sub.ID()
case ast.MapKind:
for _, e := range sub.AsMap().Entries() {
// recursively verify if we're dealing with a list/map
entry := e.AsMapEntry()
verified, id := c.verifyString(entry.Key())
if !verified {
return false, id
}
verified, id = c.verifyString(entry.Value())
if !verified {
return false, id
}
}
return true, sub.ID()
default:
return true, sub.ID()
}
}
// helper routines for reporting common errors during string formatting static validation and
// runtime execution.
func binaryFormatError(id int64, badType string) error {
return newFormatError(id, "only integers and bools can be formatted as binary, was given %s", badType)
}
func decimalFormatError(id int64, badType string) error {
return newFormatError(id, "decimal clause can only be used on integers, was given %s", badType)
}
func fixedPointFormatError(id int64, badType string) error {
return newFormatError(id, "fixed-point clause can only be used on doubles, was given %s", badType)
}
func hexFormatError(id int64, badType string) error {
return newFormatError(id, "only integers, byte buffers, and strings can be formatted as hex, was given %s", badType)
}
func octalFormatError(id int64, badType string) error {
return newFormatError(id, "octal clause can only be used on integers, was given %s", badType)
}
func scientificFormatError(id int64, badType string) error {
return newFormatError(id, "scientific clause can only be used on doubles, was given %s", badType)
}
func stringFormatError(id int64, badType string) error {
return newFormatError(id, "string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps, was given %s", badType)
}
type formatError struct {
id int64
msg string
}
func newFormatError(id int64, msg string, args ...any) error {
return formatError{
id: id,
msg: fmt.Sprintf(msg, args...),
}
}
func (e formatError) Error() string {
return e.msg
}
func (e formatError) Is(target error) bool {
return e.msg == target.Error()
}
// stringArgList implements the formatListArgs interface.
type stringArgList struct {
args traits.Lister
}
func (c *stringArgList) Arg(index int64) (ref.Val, error) {
if index >= c.args.Size().Value().(int64) {
return nil, fmt.Errorf("index %d out of range", index)
}
return c.args.Get(types.Int(index)), nil
}
func (c *stringArgList) Size() int64 {
return c.args.Size().Value().(int64)
}
// formatStringInterpolator is an interface that allows user-defined behavior
// for formatting clause implementations, as well as argument retrieval.
// Each function is expected to support the appropriate types as laid out in
// the string.format documentation, and to return an error if given an inappropriate type.
type formatStringInterpolator interface {
// String takes a ref.Val and a string representing the current locale identifier
// and returns the Val formatted as a string, or an error if one occurred.
String(ref.Val, string) (string, error)
// Decimal takes a ref.Val and a string representing the current locale identifier
// and returns the Val formatted as a decimal integer, or an error if one occurred.
Decimal(ref.Val, string) (string, error)
// Fixed takes an int pointer representing precision (or nil if none was given) and
// returns a function operating in a similar manner to String and Decimal, taking a
// ref.Val and locale and returning the appropriate string. A closure is returned
// so precision can be set without needing an additional function call/configuration.
Fixed(*int) func(ref.Val, string) (string, error)
// Scientific functions identically to Fixed, except the string returned from the closure
// is expected to be in scientific notation.
Scientific(*int) func(ref.Val, string) (string, error)
// Binary takes a ref.Val and a string representing the current locale identifier
// and returns the Val formatted as a binary integer, or an error if one occurred.
Binary(ref.Val, string) (string, error)
// Hex takes a boolean that, if true, indicates the hex string output by the returned
// closure should use uppercase letters for A-F.
Hex(bool) func(ref.Val, string) (string, error)
// Octal takes a ref.Val and a string representing the current locale identifier and
// returns the Val formatted in octal, or an error if one occurred.
Octal(ref.Val, string) (string, error)
}
// formatListArgs is an interface that allows user-defined list-like datatypes to be used
// for formatting clause implementations.
type formatListArgs interface {
// Arg returns the ref.Val at the given index, or an error if one occurred.
Arg(int64) (ref.Val, error)
// Size returns the length of the argument list.
Size() int64
}
// parseFormatString formats a string according to the string.format syntax, taking the clause implementations
// from the provided FormatCallback and the args from the given FormatList.
func parseFormatString(formatStr string, callback formatStringInterpolator, list formatListArgs, locale string) (string, error) {
i := 0
argIndex := 0
var builtStr strings.Builder
for i < len(formatStr) {
if formatStr[i] == '%' {
if i+1 < len(formatStr) && formatStr[i+1] == '%' {
err := builtStr.WriteByte('%')
if err != nil {
return "", fmt.Errorf("error writing format string: %w", err)
}
i += 2
continue
} else {
argAny, err := list.Arg(int64(argIndex))
if err != nil {
return "", err
}
if i+1 >= len(formatStr) {
return "", errors.New("unexpected end of string")
}
if int64(argIndex) >= list.Size() {
return "", fmt.Errorf("index %d out of range", argIndex)
}
numRead, val, refErr := parseAndFormatClause(formatStr[i:], argAny, callback, list, locale)
if refErr != nil {
return "", refErr
}
_, err = builtStr.WriteString(val)
if err != nil {
return "", fmt.Errorf("error writing format string: %w", err)
}
i += numRead
argIndex++
}
} else {
err := builtStr.WriteByte(formatStr[i])
if err != nil {
return "", fmt.Errorf("error writing format string: %w", err)
}
i++
}
}
return builtStr.String(), nil
}
// parseAndFormatClause parses the format clause at the start of the given string with val, and returns
// how many characters were consumed and the substituted string form of val, or an error if one occurred.
func parseAndFormatClause(formatStr string, val ref.Val, callback formatStringInterpolator, list formatListArgs, locale string) (int, string, error) {
i := 1
read, formatter, err := parseFormattingClause(formatStr[i:], callback)
i += read
if err != nil {
return -1, "", newParseFormatError("could not parse formatting clause", err)
}
valStr, err := formatter(val, locale)
if err != nil {
return -1, "", newParseFormatError("error during formatting", err)
}
return i, valStr, nil
}
func parseFormattingClause(formatStr string, callback formatStringInterpolator) (int, clauseImpl, error) {
i := 0
read, precision, err := parsePrecision(formatStr[i:])
i += read
if err != nil {
return -1, nil, fmt.Errorf("error while parsing precision: %w", err)
}
r := rune(formatStr[i])
i++
switch r {
case 's':
return i, callback.String, nil
case 'd':
return i, callback.Decimal, nil
case 'f':
return i, callback.Fixed(precision), nil
case 'e':
return i, callback.Scientific(precision), nil
case 'b':
return i, callback.Binary, nil
case 'x', 'X':
return i, callback.Hex(unicode.IsUpper(r)), nil
case 'o':
return i, callback.Octal, nil
default:
return -1, nil, fmt.Errorf("unrecognized formatting clause \"%c\"", r)
}
}
func parsePrecision(formatStr string) (int, *int, error) {
i := 0
if formatStr[i] != '.' {
return i, nil, nil
}
i++
var buffer strings.Builder
for {
if i >= len(formatStr) {
return -1, nil, errors.New("could not find end of precision specifier")
}
if !isASCIIDigit(rune(formatStr[i])) {
break
}
buffer.WriteByte(formatStr[i])
i++
}
precision, err := strconv.Atoi(buffer.String())
if err != nil {
return -1, nil, fmt.Errorf("error while converting precision to integer: %w", err)
}
return i, &precision, nil
}
func isASCIIDigit(r rune) bool {
return r <= unicode.MaxASCII && unicode.IsDigit(r)
}
type parseFormatError struct {
msg string
wrapped error
}
func newParseFormatError(msg string, wrapped error) error {
return parseFormatError{msg: msg, wrapped: wrapped}
}
func (e parseFormatError) Error() string {
return fmt.Sprintf("%s: %s", e.msg, e.wrapped.Error())
}
func (e parseFormatError) Is(target error) bool {
return e.Error() == target.Error()
}
func (e parseFormatError) Unwrap() error {
return e.wrapped
}
const (
runtimeID = int64(-1)
)

View File

@ -15,10 +15,9 @@
package ext
import (
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// function invocation guards for common call signatures within extension functions.
@ -51,10 +50,10 @@ func listStringOrError(strs []string, err error) ref.Val {
return types.DefaultTypeAdapter.NativeToValue(strs)
}
func macroTargetMatchesNamespace(ns string, target *exprpb.Expr) bool {
switch target.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
if target.GetIdentExpr().GetName() != ns {
func macroTargetMatchesNamespace(ns string, target ast.Expr) bool {
switch target.Kind() {
case ast.IdentKind:
if target.AsIdent() != ns {
return false
}
return true

View File

@ -19,11 +19,10 @@ import (
"strings"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// Math returns a cel.EnvOption to configure namespaced math helper macros and
@ -111,9 +110,9 @@ func (mathLib) CompileOptions() []cel.EnvOption {
return []cel.EnvOption{
cel.Macros(
// math.least(num, ...)
cel.NewReceiverVarArgMacro(leastMacro, mathLeast),
cel.ReceiverVarArgMacro(leastMacro, mathLeast),
// math.greatest(num, ...)
cel.NewReceiverVarArgMacro(greatestMacro, mathGreatest),
cel.ReceiverVarArgMacro(greatestMacro, mathGreatest),
),
cel.Function(minFunc,
cel.Overload("math_@min_double", []*cel.Type{cel.DoubleType}, cel.DoubleType,
@ -187,57 +186,57 @@ func (mathLib) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
func mathLeast(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) {
func mathLeast(meh cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
if !macroTargetMatchesNamespace(mathNamespace, target) {
return nil, nil
}
switch len(args) {
case 0:
return nil, meh.NewError(target.GetId(), "math.least() requires at least one argument")
return nil, meh.NewError(target.ID(), "math.least() requires at least one argument")
case 1:
if isListLiteralWithValidArgs(args[0]) || isValidArgType(args[0]) {
return meh.GlobalCall(minFunc, args[0]), nil
return meh.NewCall(minFunc, args[0]), nil
}
return nil, meh.NewError(args[0].GetId(), "math.least() invalid single argument value")
return nil, meh.NewError(args[0].ID(), "math.least() invalid single argument value")
case 2:
err := checkInvalidArgs(meh, "math.least()", args)
if err != nil {
return nil, err
}
return meh.GlobalCall(minFunc, args...), nil
return meh.NewCall(minFunc, args...), nil
default:
err := checkInvalidArgs(meh, "math.least()", args)
if err != nil {
return nil, err
}
return meh.GlobalCall(minFunc, meh.NewList(args...)), nil
return meh.NewCall(minFunc, meh.NewList(args...)), nil
}
}
func mathGreatest(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) {
func mathGreatest(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
if !macroTargetMatchesNamespace(mathNamespace, target) {
return nil, nil
}
switch len(args) {
case 0:
return nil, meh.NewError(target.GetId(), "math.greatest() requires at least one argument")
return nil, mef.NewError(target.ID(), "math.greatest() requires at least one argument")
case 1:
if isListLiteralWithValidArgs(args[0]) || isValidArgType(args[0]) {
return meh.GlobalCall(maxFunc, args[0]), nil
return mef.NewCall(maxFunc, args[0]), nil
}
return nil, meh.NewError(args[0].GetId(), "math.greatest() invalid single argument value")
return nil, mef.NewError(args[0].ID(), "math.greatest() invalid single argument value")
case 2:
err := checkInvalidArgs(meh, "math.greatest()", args)
err := checkInvalidArgs(mef, "math.greatest()", args)
if err != nil {
return nil, err
}
return meh.GlobalCall(maxFunc, args...), nil
return mef.NewCall(maxFunc, args...), nil
default:
err := checkInvalidArgs(meh, "math.greatest()", args)
err := checkInvalidArgs(mef, "math.greatest()", args)
if err != nil {
return nil, err
}
return meh.GlobalCall(maxFunc, meh.NewList(args...)), nil
return mef.NewCall(maxFunc, mef.NewList(args...)), nil
}
}
@ -311,48 +310,48 @@ func maxList(numList ref.Val) ref.Val {
}
}
func checkInvalidArgs(meh cel.MacroExprHelper, funcName string, args []*exprpb.Expr) *cel.Error {
func checkInvalidArgs(meh cel.MacroExprFactory, funcName string, args []ast.Expr) *cel.Error {
for _, arg := range args {
err := checkInvalidArgLiteral(funcName, arg)
if err != nil {
return meh.NewError(arg.GetId(), err.Error())
return meh.NewError(arg.ID(), err.Error())
}
}
return nil
}
func checkInvalidArgLiteral(funcName string, arg *exprpb.Expr) error {
func checkInvalidArgLiteral(funcName string, arg ast.Expr) error {
if !isValidArgType(arg) {
return fmt.Errorf("%s simple literal arguments must be numeric", funcName)
}
return nil
}
func isValidArgType(arg *exprpb.Expr) bool {
switch arg.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
c := arg.GetConstExpr()
switch c.GetConstantKind().(type) {
case *exprpb.Constant_DoubleValue, *exprpb.Constant_Int64Value, *exprpb.Constant_Uint64Value:
func isValidArgType(arg ast.Expr) bool {
switch arg.Kind() {
case ast.LiteralKind:
c := ref.Val(arg.AsLiteral())
switch c.(type) {
case types.Double, types.Int, types.Uint:
return true
default:
return false
}
case *exprpb.Expr_ListExpr, *exprpb.Expr_StructExpr:
case ast.ListKind, ast.MapKind, ast.StructKind:
return false
default:
return true
}
}
func isListLiteralWithValidArgs(arg *exprpb.Expr) bool {
switch arg.GetExprKind().(type) {
case *exprpb.Expr_ListExpr:
list := arg.GetListExpr()
if len(list.GetElements()) == 0 {
func isListLiteralWithValidArgs(arg ast.Expr) bool {
switch arg.Kind() {
case ast.ListKind:
list := arg.AsList()
if list.Size() == 0 {
return false
}
for _, e := range list.GetElements() {
for _, e := range list.Elements() {
if !isValidArgType(e) {
return false
}

View File

@ -96,17 +96,21 @@ func newNativeTypeProvider(adapter types.Adapter, provider types.Provider, refTy
for _, refType := range refTypes {
switch rt := refType.(type) {
case reflect.Type:
t, err := newNativeType(rt)
result, err := newNativeTypes(rt)
if err != nil {
return nil, err
}
nativeTypes[t.TypeName()] = t
for idx := range result {
nativeTypes[result[idx].TypeName()] = result[idx]
}
case reflect.Value:
t, err := newNativeType(rt.Type())
result, err := newNativeTypes(rt.Type())
if err != nil {
return nil, err
}
nativeTypes[t.TypeName()] = t
for idx := range result {
nativeTypes[result[idx].TypeName()] = result[idx]
}
default:
return nil, fmt.Errorf("unsupported native type: %v (%T) must be reflect.Type or reflect.Value", rt, rt)
}
@ -151,6 +155,24 @@ func (tp *nativeTypeProvider) FindStructType(typeName string) (*types.Type, bool
return tp.baseProvider.FindStructType(typeName)
}
// FindStructFieldNames looks up the type definition first from the native types, then from
// the backing provider type set. If found, a set of field names corresponding to the type
// will be returned.
func (tp *nativeTypeProvider) FindStructFieldNames(typeName string) ([]string, bool) {
if t, found := tp.nativeTypes[typeName]; found {
fieldCount := t.refType.NumField()
fields := make([]string, fieldCount)
for i := 0; i < fieldCount; i++ {
fields[i] = t.refType.Field(i).Name
}
return fields, true
}
if celTypeFields, found := tp.baseProvider.FindStructFieldNames(typeName); found {
return celTypeFields, true
}
return tp.baseProvider.FindStructFieldNames(typeName)
}
// FindStructFieldType looks up a native type's field definition, and if the type name is not a native
// type then proxies to the composed types.Provider
func (tp *nativeTypeProvider) FindStructFieldType(typeName, fieldName string) (*types.FieldType, bool) {
@ -447,6 +469,42 @@ func (o *nativeObj) Value() any {
return o.val
}
func newNativeTypes(rawType reflect.Type) ([]*nativeType, error) {
nt, err := newNativeType(rawType)
if err != nil {
return nil, err
}
result := []*nativeType{nt}
alreadySeen := make(map[string]struct{})
var iterateStructMembers func(reflect.Type)
iterateStructMembers = func(t reflect.Type) {
if k := t.Kind(); k == reflect.Pointer || k == reflect.Slice || k == reflect.Array || k == reflect.Map {
t = t.Elem()
}
if t.Kind() != reflect.Struct {
return
}
if _, seen := alreadySeen[t.String()]; seen {
return
}
alreadySeen[t.String()] = struct{}{}
nt, ntErr := newNativeType(t)
if ntErr != nil {
err = ntErr
return
}
result = append(result, nt)
for idx := 0; idx < t.NumField(); idx++ {
iterateStructMembers(t.Field(idx).Type)
}
}
iterateStructMembers(rawType)
return result, err
}
func newNativeType(rawType reflect.Type) (*nativeType, error) {
refType := rawType
if refType.Kind() == reflect.Pointer {

View File

@ -16,8 +16,7 @@ package ext
import (
"github.com/google/cel-go/cel"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"github.com/google/cel-go/common/ast"
)
// Protos returns a cel.EnvOption to configure extended macros and functions for
@ -72,9 +71,9 @@ func (protoLib) CompileOptions() []cel.EnvOption {
return []cel.EnvOption{
cel.Macros(
// proto.getExt(msg, select_expression)
cel.NewReceiverMacro(getExtension, 2, getProtoExt),
cel.ReceiverMacro(getExtension, 2, getProtoExt),
// proto.hasExt(msg, select_expression)
cel.NewReceiverMacro(hasExtension, 2, hasProtoExt),
cel.ReceiverMacro(hasExtension, 2, hasProtoExt),
),
}
}
@ -85,56 +84,56 @@ func (protoLib) ProgramOptions() []cel.ProgramOption {
}
// hasProtoExt generates a test-only select expression for a fully-qualified extension name on a protobuf message.
func hasProtoExt(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) {
func hasProtoExt(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
if !macroTargetMatchesNamespace(protoNamespace, target) {
return nil, nil
}
extensionField, err := getExtFieldName(meh, args[1])
extensionField, err := getExtFieldName(mef, args[1])
if err != nil {
return nil, err
}
return meh.PresenceTest(args[0], extensionField), nil
return mef.NewPresenceTest(args[0], extensionField), nil
}
// getProtoExt generates a select expression for a fully-qualified extension name on a protobuf message.
func getProtoExt(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) {
func getProtoExt(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
if !macroTargetMatchesNamespace(protoNamespace, target) {
return nil, nil
}
extFieldName, err := getExtFieldName(meh, args[1])
extFieldName, err := getExtFieldName(mef, args[1])
if err != nil {
return nil, err
}
return meh.Select(args[0], extFieldName), nil
return mef.NewSelect(args[0], extFieldName), nil
}
func getExtFieldName(meh cel.MacroExprHelper, expr *exprpb.Expr) (string, *cel.Error) {
func getExtFieldName(mef cel.MacroExprFactory, expr ast.Expr) (string, *cel.Error) {
isValid := false
extensionField := ""
switch expr.GetExprKind().(type) {
case *exprpb.Expr_SelectExpr:
switch expr.Kind() {
case ast.SelectKind:
extensionField, isValid = validateIdentifier(expr)
}
if !isValid {
return "", meh.NewError(expr.GetId(), "invalid extension field")
return "", mef.NewError(expr.ID(), "invalid extension field")
}
return extensionField, nil
}
func validateIdentifier(expr *exprpb.Expr) (string, bool) {
switch expr.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
return expr.GetIdentExpr().GetName(), true
case *exprpb.Expr_SelectExpr:
sel := expr.GetSelectExpr()
if sel.GetTestOnly() {
func validateIdentifier(expr ast.Expr) (string, bool) {
switch expr.Kind() {
case ast.IdentKind:
return expr.AsIdent(), true
case ast.SelectKind:
sel := expr.AsSelect()
if sel.IsTestOnly() {
return "", false
}
opStr, isIdent := validateIdentifier(sel.GetOperand())
opStr, isIdent := validateIdentifier(sel.Operand())
if !isIdent {
return "", false
}
return opStr + "." + sel.GetField(), true
return opStr + "." + sel.FieldName(), true
default:
return "", false
}

View File

@ -19,6 +19,8 @@ import (
"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
@ -119,6 +121,68 @@ func (setsLib) ProgramOptions() []cel.ProgramOption {
}
}
// NewSetMembershipOptimizer rewrites set membership tests using the `in` operator against a list
// of constant values of enum, int, uint, string, or boolean type into a set membership test against
// a map where the map keys are the elements of the list.
func NewSetMembershipOptimizer() (cel.ASTOptimizer, error) {
return setsLib{}, nil
}
func (setsLib) Optimize(ctx *cel.OptimizerContext, a *ast.AST) *ast.AST {
root := ast.NavigateAST(a)
matches := ast.MatchDescendants(root, matchInConstantList(a))
for _, match := range matches {
call := match.AsCall()
listArg := call.Args()[1]
entries := make([]ast.EntryExpr, len(listArg.AsList().Elements()))
for i, elem := range listArg.AsList().Elements() {
var entry ast.EntryExpr
if r, found := a.ReferenceMap()[elem.ID()]; found && r.Value != nil {
entry = ctx.NewMapEntry(ctx.NewLiteral(r.Value), ctx.NewLiteral(types.True), false)
} else {
entry = ctx.NewMapEntry(elem, ctx.NewLiteral(types.True), false)
}
entries[i] = entry
}
mapArg := ctx.NewMap(entries)
ctx.UpdateExpr(listArg, mapArg)
}
return a
}
func matchInConstantList(a *ast.AST) ast.ExprMatcher {
return func(e ast.NavigableExpr) bool {
if e.Kind() != ast.CallKind {
return false
}
call := e.AsCall()
if call.FunctionName() != operators.In {
return false
}
aggregateVal := call.Args()[1]
if aggregateVal.Kind() != ast.ListKind {
return false
}
listVal := aggregateVal.AsList()
for _, elem := range listVal.Elements() {
if r, found := a.ReferenceMap()[elem.ID()]; found {
if r.Value != nil {
continue
}
}
if elem.Kind() != ast.LiteralKind {
return false
}
lit := elem.AsLiteral()
if !(lit.Type() == cel.StringType || lit.Type() == cel.IntType ||
lit.Type() == cel.UintType || lit.Type() == cel.BoolType) {
return false
}
}
return true
}
}
func setsIntersects(listA, listB ref.Val) ref.Val {
lA := listA.(traits.Lister)
lB := listB.(traits.Lister)

View File

@ -21,19 +21,16 @@ import (
"fmt"
"math"
"reflect"
"sort"
"strings"
"unicode"
"unicode/utf8"
"golang.org/x/text/language"
"golang.org/x/text/message"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/interpreter"
)
const (
@ -99,7 +96,7 @@ const (
// "a map inside a list: %s".format([[1, 2, 3, {"a": "x", "b": "y", "c": "z"}]]) // returns "a map inside a list: [1, 2, 3, {"a":"x", "b":"y", "c":"d"}]"
// "true bool: %s - false bool: %s\nbinary bool: %b".format([true, false, true]) // returns "true bool: true - false bool: false\nbinary bool: 1"
//
// Passing an incorrect type (an integer to `%s`) is considered an error, as well as attempting
// Passing an incorrect type (a string to `%b`) is considered an error, as well as attempting
// to use more formatting clauses than there are arguments (`%d %d %d` while passing two ints, for instance).
// If compile-time checking is enabled, and the formatting string is a constant, and the argument list is a literal,
// then letting any arguments go unused/unformatted is also considered an error.
@ -205,6 +202,8 @@ const (
// 'hello hello'.replace('he', 'we', -1) // returns 'wello wello'
// 'hello hello'.replace('he', 'we', 1) // returns 'wello hello'
// 'hello hello'.replace('he', 'we', 0) // returns 'hello hello'
// 'hello hello'.replace('', '_') // returns '_h_e_l_l_o_ _h_e_l_l_o_'
// 'hello hello'.replace('h', '') // returns 'ello ello'
//
// # Split
//
@ -270,8 +269,26 @@ const (
//
// 'TacoCat'.upperAscii() // returns 'TACOCAT'
// 'TacoCÆt Xii'.upperAscii() // returns 'TACOCÆT XII'
//
// # Reverse
//
// Introduced at version: 3
//
// Returns a new string whose characters are the same as the target string, only formatted in
// reverse order.
// This function relies on converting strings to rune arrays in order to reverse
//
// <string>.reverse() -> <string>
//
// Examples:
//
// 'gums'.reverse() // returns 'smug'
// 'John Smith'.reverse() // returns 'htimS nhoJ'
func Strings(options ...StringsOption) cel.EnvOption {
s := &stringLib{version: math.MaxUint32}
s := &stringLib{
version: math.MaxUint32,
validateFormat: true,
}
for _, o := range options {
s = o(s)
}
@ -279,8 +296,9 @@ func Strings(options ...StringsOption) cel.EnvOption {
}
type stringLib struct {
locale string
version uint32
locale string
version uint32
validateFormat bool
}
// LibraryName implements the SingletonLibrary interface method.
@ -317,6 +335,17 @@ func StringsVersion(version uint32) StringsOption {
}
}
// StringsValidateFormatCalls validates type-checked ASTs to ensure that string.format() calls have
// valid formatting clauses and valid argument types for each clause.
//
// Enabled by default.
func StringsValidateFormatCalls(value bool) StringsOption {
return func(s *stringLib) *stringLib {
s.validateFormat = value
return s
}
}
// CompileOptions implements the Library interface method.
func (lib *stringLib) CompileOptions() []cel.EnvOption {
formatLocale := "en_US"
@ -440,13 +469,15 @@ func (lib *stringLib) CompileOptions() []cel.EnvOption {
cel.FunctionBinding(func(args ...ref.Val) ref.Val {
s := string(args[0].(types.String))
formatArgs := args[1].(traits.Lister)
return stringOrError(interpreter.ParseFormatString(s, &stringFormatter{}, &stringArgList{formatArgs}, formatLocale))
return stringOrError(parseFormatString(s, &stringFormatter{}, &stringArgList{formatArgs}, formatLocale))
}))),
cel.Function("strings.quote", cel.Overload("strings_quote", []*cel.Type{cel.StringType}, cel.StringType,
cel.UnaryBinding(func(str ref.Val) ref.Val {
s := str.(types.String)
return stringOrError(quote(string(s)))
}))))
}))),
cel.ASTValidators(stringFormatValidator{}))
}
if lib.version >= 2 {
@ -471,7 +502,7 @@ func (lib *stringLib) CompileOptions() []cel.EnvOption {
cel.UnaryBinding(func(list ref.Val) ref.Val {
l, err := list.ConvertToNative(stringListType)
if err != nil {
return types.NewErr(err.Error())
return types.WrapErr(err)
}
return stringOrError(join(l.([]string)))
})),
@ -479,13 +510,26 @@ func (lib *stringLib) CompileOptions() []cel.EnvOption {
cel.BinaryBinding(func(list, delim ref.Val) ref.Val {
l, err := list.ConvertToNative(stringListType)
if err != nil {
return types.NewErr(err.Error())
return types.WrapErr(err)
}
d := delim.(types.String)
return stringOrError(joinSeparator(l.([]string), string(d)))
}))),
)
}
if lib.version >= 3 {
opts = append(opts,
cel.Function("reverse",
cel.MemberOverload("reverse", []*cel.Type{cel.StringType}, cel.StringType,
cel.UnaryBinding(func(str ref.Val) ref.Val {
s := str.(types.String)
return stringOrError(reverse(string(s)))
}))),
)
}
if lib.validateFormat {
opts = append(opts, cel.ASTValidators(stringFormatValidator{}))
}
return opts
}
@ -636,6 +680,14 @@ func upperASCII(str string) (string, error) {
return string(runes), nil
}
func reverse(str string) (string, error) {
chars := []rune(str)
for i, j := 0, len(chars)-1; i < j; i, j = i+1, j-1 {
chars[i], chars[j] = chars[j], chars[i]
}
return string(chars), nil
}
func joinSeparator(strs []string, separator string) (string, error) {
return strings.Join(strs, separator), nil
}
@ -661,238 +713,6 @@ func joinValSeparator(strs traits.Lister, separator string) (string, error) {
return sb.String(), nil
}
type clauseImpl func(ref.Val, string) (string, error)
func clauseForType(argType ref.Type) (clauseImpl, error) {
switch argType {
case types.IntType, types.UintType:
return formatDecimal, nil
case types.StringType, types.BytesType, types.BoolType, types.NullType, types.TypeType:
return FormatString, nil
case types.TimestampType, types.DurationType:
// special case to ensure timestamps/durations get printed as CEL literals
return func(arg ref.Val, locale string) (string, error) {
argStrVal := arg.ConvertToType(types.StringType)
argStr := argStrVal.Value().(string)
if arg.Type() == types.TimestampType {
return fmt.Sprintf("timestamp(%q)", argStr), nil
}
if arg.Type() == types.DurationType {
return fmt.Sprintf("duration(%q)", argStr), nil
}
return "", fmt.Errorf("cannot convert argument of type %s to timestamp/duration", arg.Type().TypeName())
}, nil
case types.ListType:
return formatList, nil
case types.MapType:
return formatMap, nil
case types.DoubleType:
// avoid formatFixed so we can output a period as the decimal separator in order
// to always be a valid CEL literal
return func(arg ref.Val, locale string) (string, error) {
argDouble, ok := arg.Value().(float64)
if !ok {
return "", fmt.Errorf("couldn't convert %s to float64", arg.Type().TypeName())
}
fmtStr := fmt.Sprintf("%%.%df", defaultPrecision)
return fmt.Sprintf(fmtStr, argDouble), nil
}, nil
case types.TypeType:
return func(arg ref.Val, locale string) (string, error) {
return fmt.Sprintf("type(%s)", arg.Value().(string)), nil
}, nil
default:
return nil, fmt.Errorf("no formatting function for %s", argType.TypeName())
}
}
func formatList(arg ref.Val, locale string) (string, error) {
argList := arg.(traits.Lister)
argIterator := argList.Iterator()
var listStrBuilder strings.Builder
_, err := listStrBuilder.WriteRune('[')
if err != nil {
return "", fmt.Errorf("error writing to list string: %w", err)
}
for argIterator.HasNext() == types.True {
member := argIterator.Next()
memberFormat, err := clauseForType(member.Type())
if err != nil {
return "", err
}
unquotedStr, err := memberFormat(member, locale)
if err != nil {
return "", err
}
str := quoteForCEL(member, unquotedStr)
_, err = listStrBuilder.WriteString(str)
if err != nil {
return "", fmt.Errorf("error writing to list string: %w", err)
}
if argIterator.HasNext() == types.True {
_, err = listStrBuilder.WriteString(", ")
if err != nil {
return "", fmt.Errorf("error writing to list string: %w", err)
}
}
}
_, err = listStrBuilder.WriteRune(']')
if err != nil {
return "", fmt.Errorf("error writing to list string: %w", err)
}
return listStrBuilder.String(), nil
}
func formatMap(arg ref.Val, locale string) (string, error) {
argMap := arg.(traits.Mapper)
argIterator := argMap.Iterator()
type mapPair struct {
key string
value string
}
argPairs := make([]mapPair, argMap.Size().Value().(int64))
i := 0
for argIterator.HasNext() == types.True {
key := argIterator.Next()
var keyFormat clauseImpl
switch key.Type() {
case types.StringType, types.BoolType:
keyFormat = FormatString
case types.IntType, types.UintType:
keyFormat = formatDecimal
default:
return "", fmt.Errorf("no formatting function for map key of type %s", key.Type().TypeName())
}
unquotedKeyStr, err := keyFormat(key, locale)
if err != nil {
return "", err
}
keyStr := quoteForCEL(key, unquotedKeyStr)
value, found := argMap.Find(key)
if !found {
return "", fmt.Errorf("could not find key: %q", key)
}
valueFormat, err := clauseForType(value.Type())
if err != nil {
return "", err
}
unquotedValueStr, err := valueFormat(value, locale)
if err != nil {
return "", err
}
valueStr := quoteForCEL(value, unquotedValueStr)
argPairs[i] = mapPair{keyStr, valueStr}
i++
}
sort.SliceStable(argPairs, func(x, y int) bool {
return argPairs[x].key < argPairs[y].key
})
var mapStrBuilder strings.Builder
_, err := mapStrBuilder.WriteRune('{')
if err != nil {
return "", fmt.Errorf("error writing to map string: %w", err)
}
for i, entry := range argPairs {
_, err = mapStrBuilder.WriteString(fmt.Sprintf("%s:%s", entry.key, entry.value))
if err != nil {
return "", fmt.Errorf("error writing to map string: %w", err)
}
if i < len(argPairs)-1 {
_, err = mapStrBuilder.WriteString(", ")
if err != nil {
return "", fmt.Errorf("error writing to map string: %w", err)
}
}
}
_, err = mapStrBuilder.WriteRune('}')
if err != nil {
return "", fmt.Errorf("error writing to map string: %w", err)
}
return mapStrBuilder.String(), nil
}
// quoteForCEL takes a formatted, unquoted value and quotes it in a manner
// suitable for embedding directly in CEL.
func quoteForCEL(refVal ref.Val, unquotedValue string) string {
switch refVal.Type() {
case types.StringType:
return fmt.Sprintf("%q", unquotedValue)
case types.BytesType:
return fmt.Sprintf("b%q", unquotedValue)
case types.DoubleType:
// special case to handle infinity/NaN
num := refVal.Value().(float64)
if math.IsInf(num, 1) || math.IsInf(num, -1) || math.IsNaN(num) {
return fmt.Sprintf("%q", unquotedValue)
}
return unquotedValue
default:
return unquotedValue
}
}
// FormatString returns the string representation of a CEL value.
// It is used to implement the %s specifier in the (string).format() extension
// function.
func FormatString(arg ref.Val, locale string) (string, error) {
switch arg.Type() {
case types.ListType:
return formatList(arg, locale)
case types.MapType:
return formatMap(arg, locale)
case types.IntType, types.UintType, types.DoubleType,
types.BoolType, types.StringType, types.TimestampType, types.BytesType, types.DurationType, types.TypeType:
argStrVal := arg.ConvertToType(types.StringType)
argStr, ok := argStrVal.Value().(string)
if !ok {
return "", fmt.Errorf("could not convert argument %q to string", argStrVal)
}
return argStr, nil
case types.NullType:
return "null", nil
default:
return "", fmt.Errorf("string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps, was given %s", arg.Type().TypeName())
}
}
func formatDecimal(arg ref.Val, locale string) (string, error) {
switch arg.Type() {
case types.IntType:
argInt, ok := arg.ConvertToType(types.IntType).Value().(int64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to int64", arg.Value())
}
return fmt.Sprintf("%d", argInt), nil
case types.UintType:
argInt, ok := arg.ConvertToType(types.UintType).Value().(uint64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to uint64", arg.Value())
}
return fmt.Sprintf("%d", argInt), nil
default:
return "", fmt.Errorf("decimal clause can only be used on integers, was given %s", arg.Type().TypeName())
}
}
func matchLanguage(locale string) (language.Tag, error) {
matcher, err := makeMatcher(locale)
if err != nil {
return language.Und, err
}
tag, _ := language.MatchStrings(matcher, locale)
return tag, nil
}
func makeMatcher(locale string) (language.Matcher, error) {
tags := make([]language.Tag, 0)
tag, err := language.Parse(locale)
if err != nil {
return nil, err
}
tags = append(tags, tag)
return language.NewMatcher(tags), nil
}
// quote implements a string quoting function. The string will be wrapped in
// double quotes, and all valid CEL escape sequences will be escaped to show up
// literally if printed. If the input contains any invalid UTF-8, the invalid runes
@ -940,156 +760,6 @@ func sanitize(s string) string {
return sanitizedStringBuilder.String()
}
type stringFormatter struct{}
func (c *stringFormatter) String(arg ref.Val, locale string) (string, error) {
return FormatString(arg, locale)
}
func (c *stringFormatter) Decimal(arg ref.Val, locale string) (string, error) {
return formatDecimal(arg, locale)
}
func (c *stringFormatter) Fixed(precision *int) func(ref.Val, string) (string, error) {
if precision == nil {
precision = new(int)
*precision = defaultPrecision
}
return func(arg ref.Val, locale string) (string, error) {
strException := false
if arg.Type() == types.StringType {
argStr := arg.Value().(string)
if argStr == "NaN" || argStr == "Infinity" || argStr == "-Infinity" {
strException = true
}
}
if arg.Type() != types.DoubleType && !strException {
return "", fmt.Errorf("fixed-point clause can only be used on doubles, was given %s", arg.Type().TypeName())
}
argFloatVal := arg.ConvertToType(types.DoubleType)
argFloat, ok := argFloatVal.Value().(float64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to float64", argFloatVal.Value())
}
fmtStr := fmt.Sprintf("%%.%df", *precision)
matchedLocale, err := matchLanguage(locale)
if err != nil {
return "", fmt.Errorf("error matching locale: %w", err)
}
return message.NewPrinter(matchedLocale).Sprintf(fmtStr, argFloat), nil
}
}
func (c *stringFormatter) Scientific(precision *int) func(ref.Val, string) (string, error) {
if precision == nil {
precision = new(int)
*precision = defaultPrecision
}
return func(arg ref.Val, locale string) (string, error) {
strException := false
if arg.Type() == types.StringType {
argStr := arg.Value().(string)
if argStr == "NaN" || argStr == "Infinity" || argStr == "-Infinity" {
strException = true
}
}
if arg.Type() != types.DoubleType && !strException {
return "", fmt.Errorf("scientific clause can only be used on doubles, was given %s", arg.Type().TypeName())
}
argFloatVal := arg.ConvertToType(types.DoubleType)
argFloat, ok := argFloatVal.Value().(float64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to float64", argFloatVal.Value())
}
matchedLocale, err := matchLanguage(locale)
if err != nil {
return "", fmt.Errorf("error matching locale: %w", err)
}
fmtStr := fmt.Sprintf("%%%de", *precision)
return message.NewPrinter(matchedLocale).Sprintf(fmtStr, argFloat), nil
}
}
func (c *stringFormatter) Binary(arg ref.Val, locale string) (string, error) {
switch arg.Type() {
case types.IntType:
argInt := arg.Value().(int64)
// locale is intentionally unused as integers formatted as binary
// strings are locale-independent
return fmt.Sprintf("%b", argInt), nil
case types.UintType:
argInt := arg.Value().(uint64)
return fmt.Sprintf("%b", argInt), nil
case types.BoolType:
argBool := arg.Value().(bool)
if argBool {
return "1", nil
}
return "0", nil
default:
return "", fmt.Errorf("only integers and bools can be formatted as binary, was given %s", arg.Type().TypeName())
}
}
func (c *stringFormatter) Hex(useUpper bool) func(ref.Val, string) (string, error) {
return func(arg ref.Val, locale string) (string, error) {
fmtStr := "%x"
if useUpper {
fmtStr = "%X"
}
switch arg.Type() {
case types.StringType, types.BytesType:
if arg.Type() == types.BytesType {
return fmt.Sprintf(fmtStr, arg.Value().([]byte)), nil
}
return fmt.Sprintf(fmtStr, arg.Value().(string)), nil
case types.IntType:
argInt, ok := arg.Value().(int64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to int64", arg.Value())
}
return fmt.Sprintf(fmtStr, argInt), nil
case types.UintType:
argInt, ok := arg.Value().(uint64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to uint64", arg.Value())
}
return fmt.Sprintf(fmtStr, argInt), nil
default:
return "", fmt.Errorf("only integers, byte buffers, and strings can be formatted as hex, was given %s", arg.Type().TypeName())
}
}
}
func (c *stringFormatter) Octal(arg ref.Val, locale string) (string, error) {
switch arg.Type() {
case types.IntType:
argInt := arg.Value().(int64)
return fmt.Sprintf("%o", argInt), nil
case types.UintType:
argInt := arg.Value().(uint64)
return fmt.Sprintf("%o", argInt), nil
default:
return "", fmt.Errorf("octal clause can only be used on integers, was given %s", arg.Type().TypeName())
}
}
type stringArgList struct {
args traits.Lister
}
func (c *stringArgList) Arg(index int64) (ref.Val, error) {
if index >= c.args.Size().Value().(int64) {
return nil, fmt.Errorf("index %d out of range", index)
}
return c.args.Get(types.Int(index)), nil
}
func (c *stringArgList) ArgSize() int64 {
return c.args.Size().Value().(int64)
}
var (
stringListType = reflect.TypeOf([]string{})
)