mirror of
https://github.com/ceph/ceph-csi.git
synced 2024-11-15 02:40:23 +00:00
5a66991bb3
updating the kubernetes release to the latest in main go.mod Signed-off-by: Madhu Rajanna <madhupr007@gmail.com>
905 lines
28 KiB
Go
905 lines
28 KiB
Go
// Copyright 2023 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package ext
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"math"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"unicode"
|
|
|
|
"golang.org/x/text/language"
|
|
"golang.org/x/text/message"
|
|
|
|
"github.com/google/cel-go/cel"
|
|
"github.com/google/cel-go/common/ast"
|
|
"github.com/google/cel-go/common/overloads"
|
|
"github.com/google/cel-go/common/types"
|
|
"github.com/google/cel-go/common/types/ref"
|
|
"github.com/google/cel-go/common/types/traits"
|
|
)
|
|
|
|
type clauseImpl func(ref.Val, string) (string, error)
|
|
|
|
func clauseForType(argType ref.Type) (clauseImpl, error) {
|
|
switch argType {
|
|
case types.IntType, types.UintType:
|
|
return formatDecimal, nil
|
|
case types.StringType, types.BytesType, types.BoolType, types.NullType, types.TypeType:
|
|
return FormatString, nil
|
|
case types.TimestampType, types.DurationType:
|
|
// special case to ensure timestamps/durations get printed as CEL literals
|
|
return func(arg ref.Val, locale string) (string, error) {
|
|
argStrVal := arg.ConvertToType(types.StringType)
|
|
argStr := argStrVal.Value().(string)
|
|
if arg.Type() == types.TimestampType {
|
|
return fmt.Sprintf("timestamp(%q)", argStr), nil
|
|
}
|
|
if arg.Type() == types.DurationType {
|
|
return fmt.Sprintf("duration(%q)", argStr), nil
|
|
}
|
|
return "", fmt.Errorf("cannot convert argument of type %s to timestamp/duration", arg.Type().TypeName())
|
|
}, nil
|
|
case types.ListType:
|
|
return formatList, nil
|
|
case types.MapType:
|
|
return formatMap, nil
|
|
case types.DoubleType:
|
|
// avoid formatFixed so we can output a period as the decimal separator in order
|
|
// to always be a valid CEL literal
|
|
return func(arg ref.Val, locale string) (string, error) {
|
|
argDouble, ok := arg.Value().(float64)
|
|
if !ok {
|
|
return "", fmt.Errorf("couldn't convert %s to float64", arg.Type().TypeName())
|
|
}
|
|
fmtStr := fmt.Sprintf("%%.%df", defaultPrecision)
|
|
return fmt.Sprintf(fmtStr, argDouble), nil
|
|
}, nil
|
|
case types.TypeType:
|
|
return func(arg ref.Val, locale string) (string, error) {
|
|
return fmt.Sprintf("type(%s)", arg.Value().(string)), nil
|
|
}, nil
|
|
default:
|
|
return nil, fmt.Errorf("no formatting function for %s", argType.TypeName())
|
|
}
|
|
}
|
|
|
|
func formatList(arg ref.Val, locale string) (string, error) {
|
|
argList := arg.(traits.Lister)
|
|
argIterator := argList.Iterator()
|
|
var listStrBuilder strings.Builder
|
|
_, err := listStrBuilder.WriteRune('[')
|
|
if err != nil {
|
|
return "", fmt.Errorf("error writing to list string: %w", err)
|
|
}
|
|
for argIterator.HasNext() == types.True {
|
|
member := argIterator.Next()
|
|
memberFormat, err := clauseForType(member.Type())
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
unquotedStr, err := memberFormat(member, locale)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
str := quoteForCEL(member, unquotedStr)
|
|
_, err = listStrBuilder.WriteString(str)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error writing to list string: %w", err)
|
|
}
|
|
if argIterator.HasNext() == types.True {
|
|
_, err = listStrBuilder.WriteString(", ")
|
|
if err != nil {
|
|
return "", fmt.Errorf("error writing to list string: %w", err)
|
|
}
|
|
}
|
|
}
|
|
_, err = listStrBuilder.WriteRune(']')
|
|
if err != nil {
|
|
return "", fmt.Errorf("error writing to list string: %w", err)
|
|
}
|
|
return listStrBuilder.String(), nil
|
|
}
|
|
|
|
func formatMap(arg ref.Val, locale string) (string, error) {
|
|
argMap := arg.(traits.Mapper)
|
|
argIterator := argMap.Iterator()
|
|
type mapPair struct {
|
|
key string
|
|
value string
|
|
}
|
|
argPairs := make([]mapPair, argMap.Size().Value().(int64))
|
|
i := 0
|
|
for argIterator.HasNext() == types.True {
|
|
key := argIterator.Next()
|
|
var keyFormat clauseImpl
|
|
switch key.Type() {
|
|
case types.StringType, types.BoolType:
|
|
keyFormat = FormatString
|
|
case types.IntType, types.UintType:
|
|
keyFormat = formatDecimal
|
|
default:
|
|
return "", fmt.Errorf("no formatting function for map key of type %s", key.Type().TypeName())
|
|
}
|
|
unquotedKeyStr, err := keyFormat(key, locale)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
keyStr := quoteForCEL(key, unquotedKeyStr)
|
|
value, found := argMap.Find(key)
|
|
if !found {
|
|
return "", fmt.Errorf("could not find key: %q", key)
|
|
}
|
|
valueFormat, err := clauseForType(value.Type())
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
unquotedValueStr, err := valueFormat(value, locale)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
valueStr := quoteForCEL(value, unquotedValueStr)
|
|
argPairs[i] = mapPair{keyStr, valueStr}
|
|
i++
|
|
}
|
|
sort.SliceStable(argPairs, func(x, y int) bool {
|
|
return argPairs[x].key < argPairs[y].key
|
|
})
|
|
var mapStrBuilder strings.Builder
|
|
_, err := mapStrBuilder.WriteRune('{')
|
|
if err != nil {
|
|
return "", fmt.Errorf("error writing to map string: %w", err)
|
|
}
|
|
for i, entry := range argPairs {
|
|
_, err = mapStrBuilder.WriteString(fmt.Sprintf("%s:%s", entry.key, entry.value))
|
|
if err != nil {
|
|
return "", fmt.Errorf("error writing to map string: %w", err)
|
|
}
|
|
if i < len(argPairs)-1 {
|
|
_, err = mapStrBuilder.WriteString(", ")
|
|
if err != nil {
|
|
return "", fmt.Errorf("error writing to map string: %w", err)
|
|
}
|
|
}
|
|
}
|
|
_, err = mapStrBuilder.WriteRune('}')
|
|
if err != nil {
|
|
return "", fmt.Errorf("error writing to map string: %w", err)
|
|
}
|
|
return mapStrBuilder.String(), nil
|
|
}
|
|
|
|
// quoteForCEL takes a formatted, unquoted value and quotes it in a manner suitable
|
|
// for embedding directly in CEL.
|
|
func quoteForCEL(refVal ref.Val, unquotedValue string) string {
|
|
switch refVal.Type() {
|
|
case types.StringType:
|
|
return fmt.Sprintf("%q", unquotedValue)
|
|
case types.BytesType:
|
|
return fmt.Sprintf("b%q", unquotedValue)
|
|
case types.DoubleType:
|
|
// special case to handle infinity/NaN
|
|
num := refVal.Value().(float64)
|
|
if math.IsInf(num, 1) || math.IsInf(num, -1) || math.IsNaN(num) {
|
|
return fmt.Sprintf("%q", unquotedValue)
|
|
}
|
|
return unquotedValue
|
|
default:
|
|
return unquotedValue
|
|
}
|
|
}
|
|
|
|
// FormatString returns the string representation of a CEL value.
|
|
//
|
|
// It is used to implement the %s specifier in the (string).format() extension function.
|
|
func FormatString(arg ref.Val, locale string) (string, error) {
|
|
switch arg.Type() {
|
|
case types.ListType:
|
|
return formatList(arg, locale)
|
|
case types.MapType:
|
|
return formatMap(arg, locale)
|
|
case types.IntType, types.UintType, types.DoubleType,
|
|
types.BoolType, types.StringType, types.TimestampType, types.BytesType, types.DurationType, types.TypeType:
|
|
argStrVal := arg.ConvertToType(types.StringType)
|
|
argStr, ok := argStrVal.Value().(string)
|
|
if !ok {
|
|
return "", fmt.Errorf("could not convert argument %q to string", argStrVal)
|
|
}
|
|
return argStr, nil
|
|
case types.NullType:
|
|
return "null", nil
|
|
default:
|
|
return "", stringFormatError(runtimeID, arg.Type().TypeName())
|
|
}
|
|
}
|
|
|
|
func formatDecimal(arg ref.Val, locale string) (string, error) {
|
|
switch arg.Type() {
|
|
case types.IntType:
|
|
argInt, ok := arg.ConvertToType(types.IntType).Value().(int64)
|
|
if !ok {
|
|
return "", fmt.Errorf("could not convert \"%s\" to int64", arg.Value())
|
|
}
|
|
return fmt.Sprintf("%d", argInt), nil
|
|
case types.UintType:
|
|
argInt, ok := arg.ConvertToType(types.UintType).Value().(uint64)
|
|
if !ok {
|
|
return "", fmt.Errorf("could not convert \"%s\" to uint64", arg.Value())
|
|
}
|
|
return fmt.Sprintf("%d", argInt), nil
|
|
default:
|
|
return "", decimalFormatError(runtimeID, arg.Type().TypeName())
|
|
}
|
|
}
|
|
|
|
func matchLanguage(locale string) (language.Tag, error) {
|
|
matcher, err := makeMatcher(locale)
|
|
if err != nil {
|
|
return language.Und, err
|
|
}
|
|
tag, _ := language.MatchStrings(matcher, locale)
|
|
return tag, nil
|
|
}
|
|
|
|
func makeMatcher(locale string) (language.Matcher, error) {
|
|
tags := make([]language.Tag, 0)
|
|
tag, err := language.Parse(locale)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
tags = append(tags, tag)
|
|
return language.NewMatcher(tags), nil
|
|
}
|
|
|
|
type stringFormatter struct{}
|
|
|
|
func (c *stringFormatter) String(arg ref.Val, locale string) (string, error) {
|
|
return FormatString(arg, locale)
|
|
}
|
|
|
|
func (c *stringFormatter) Decimal(arg ref.Val, locale string) (string, error) {
|
|
return formatDecimal(arg, locale)
|
|
}
|
|
|
|
func (c *stringFormatter) Fixed(precision *int) func(ref.Val, string) (string, error) {
|
|
if precision == nil {
|
|
precision = new(int)
|
|
*precision = defaultPrecision
|
|
}
|
|
return func(arg ref.Val, locale string) (string, error) {
|
|
strException := false
|
|
if arg.Type() == types.StringType {
|
|
argStr := arg.Value().(string)
|
|
if argStr == "NaN" || argStr == "Infinity" || argStr == "-Infinity" {
|
|
strException = true
|
|
}
|
|
}
|
|
if arg.Type() != types.DoubleType && !strException {
|
|
return "", fixedPointFormatError(runtimeID, arg.Type().TypeName())
|
|
}
|
|
argFloatVal := arg.ConvertToType(types.DoubleType)
|
|
argFloat, ok := argFloatVal.Value().(float64)
|
|
if !ok {
|
|
return "", fmt.Errorf("could not convert \"%s\" to float64", argFloatVal.Value())
|
|
}
|
|
fmtStr := fmt.Sprintf("%%.%df", *precision)
|
|
|
|
matchedLocale, err := matchLanguage(locale)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error matching locale: %w", err)
|
|
}
|
|
return message.NewPrinter(matchedLocale).Sprintf(fmtStr, argFloat), nil
|
|
}
|
|
}
|
|
|
|
func (c *stringFormatter) Scientific(precision *int) func(ref.Val, string) (string, error) {
|
|
if precision == nil {
|
|
precision = new(int)
|
|
*precision = defaultPrecision
|
|
}
|
|
return func(arg ref.Val, locale string) (string, error) {
|
|
strException := false
|
|
if arg.Type() == types.StringType {
|
|
argStr := arg.Value().(string)
|
|
if argStr == "NaN" || argStr == "Infinity" || argStr == "-Infinity" {
|
|
strException = true
|
|
}
|
|
}
|
|
if arg.Type() != types.DoubleType && !strException {
|
|
return "", scientificFormatError(runtimeID, arg.Type().TypeName())
|
|
}
|
|
argFloatVal := arg.ConvertToType(types.DoubleType)
|
|
argFloat, ok := argFloatVal.Value().(float64)
|
|
if !ok {
|
|
return "", fmt.Errorf("could not convert \"%v\" to float64", argFloatVal.Value())
|
|
}
|
|
matchedLocale, err := matchLanguage(locale)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error matching locale: %w", err)
|
|
}
|
|
fmtStr := fmt.Sprintf("%%%de", *precision)
|
|
return message.NewPrinter(matchedLocale).Sprintf(fmtStr, argFloat), nil
|
|
}
|
|
}
|
|
|
|
func (c *stringFormatter) Binary(arg ref.Val, locale string) (string, error) {
|
|
switch arg.Type() {
|
|
case types.IntType:
|
|
argInt := arg.Value().(int64)
|
|
// locale is intentionally unused as integers formatted as binary
|
|
// strings are locale-independent
|
|
return fmt.Sprintf("%b", argInt), nil
|
|
case types.UintType:
|
|
argInt := arg.Value().(uint64)
|
|
return fmt.Sprintf("%b", argInt), nil
|
|
case types.BoolType:
|
|
argBool := arg.Value().(bool)
|
|
if argBool {
|
|
return "1", nil
|
|
}
|
|
return "0", nil
|
|
default:
|
|
return "", binaryFormatError(runtimeID, arg.Type().TypeName())
|
|
}
|
|
}
|
|
|
|
func (c *stringFormatter) Hex(useUpper bool) func(ref.Val, string) (string, error) {
|
|
return func(arg ref.Val, locale string) (string, error) {
|
|
fmtStr := "%x"
|
|
if useUpper {
|
|
fmtStr = "%X"
|
|
}
|
|
switch arg.Type() {
|
|
case types.StringType, types.BytesType:
|
|
if arg.Type() == types.BytesType {
|
|
return fmt.Sprintf(fmtStr, arg.Value().([]byte)), nil
|
|
}
|
|
return fmt.Sprintf(fmtStr, arg.Value().(string)), nil
|
|
case types.IntType:
|
|
argInt, ok := arg.Value().(int64)
|
|
if !ok {
|
|
return "", fmt.Errorf("could not convert \"%s\" to int64", arg.Value())
|
|
}
|
|
return fmt.Sprintf(fmtStr, argInt), nil
|
|
case types.UintType:
|
|
argInt, ok := arg.Value().(uint64)
|
|
if !ok {
|
|
return "", fmt.Errorf("could not convert \"%s\" to uint64", arg.Value())
|
|
}
|
|
return fmt.Sprintf(fmtStr, argInt), nil
|
|
default:
|
|
return "", hexFormatError(runtimeID, arg.Type().TypeName())
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *stringFormatter) Octal(arg ref.Val, locale string) (string, error) {
|
|
switch arg.Type() {
|
|
case types.IntType:
|
|
argInt := arg.Value().(int64)
|
|
return fmt.Sprintf("%o", argInt), nil
|
|
case types.UintType:
|
|
argInt := arg.Value().(uint64)
|
|
return fmt.Sprintf("%o", argInt), nil
|
|
default:
|
|
return "", octalFormatError(runtimeID, arg.Type().TypeName())
|
|
}
|
|
}
|
|
|
|
// stringFormatValidator implements the cel.ASTValidator interface allowing for static validation
|
|
// of string.format calls.
|
|
type stringFormatValidator struct{}
|
|
|
|
// Name returns the name of the validator.
|
|
func (stringFormatValidator) Name() string {
|
|
return "cel.lib.ext.validate.functions.string.format"
|
|
}
|
|
|
|
// Configure implements the ASTValidatorConfigurer interface and augments the list of functions to skip
|
|
// during homogeneous aggregate literal type-checks.
|
|
func (stringFormatValidator) Configure(config cel.MutableValidatorConfig) error {
|
|
functions := config.GetOrDefault(cel.HomogeneousAggregateLiteralExemptFunctions, []string{}).([]string)
|
|
functions = append(functions, "format")
|
|
return config.Set(cel.HomogeneousAggregateLiteralExemptFunctions, functions)
|
|
}
|
|
|
|
// Validate parses all literal format strings and type checks the format clause against the argument
|
|
// at the corresponding ordinal within the list literal argument to the function, if one is specified.
|
|
func (stringFormatValidator) Validate(env *cel.Env, _ cel.ValidatorConfig, a *ast.AST, iss *cel.Issues) {
|
|
root := ast.NavigateAST(a)
|
|
formatCallExprs := ast.MatchDescendants(root, matchConstantFormatStringWithListLiteralArgs(a))
|
|
for _, e := range formatCallExprs {
|
|
call := e.AsCall()
|
|
formatStr := call.Target().AsLiteral().Value().(string)
|
|
args := call.Args()[0].AsList().Elements()
|
|
formatCheck := &stringFormatChecker{
|
|
args: args,
|
|
ast: a,
|
|
}
|
|
// use a placeholder locale, since locale doesn't affect syntax
|
|
_, err := parseFormatString(formatStr, formatCheck, formatCheck, "en_US")
|
|
if err != nil {
|
|
iss.ReportErrorAtID(getErrorExprID(e.ID(), err), err.Error())
|
|
continue
|
|
}
|
|
seenArgs := formatCheck.argsRequested
|
|
if len(args) > seenArgs {
|
|
iss.ReportErrorAtID(e.ID(),
|
|
"too many arguments supplied to string.format (expected %d, got %d)", seenArgs, len(args))
|
|
}
|
|
}
|
|
}
|
|
|
|
// getErrorExprID determines which list literal argument triggered a type-disagreement for the
|
|
// purposes of more accurate error message reports.
|
|
func getErrorExprID(id int64, err error) int64 {
|
|
fmtErr, ok := err.(formatError)
|
|
if ok {
|
|
return fmtErr.id
|
|
}
|
|
wrapped := errors.Unwrap(err)
|
|
if wrapped != nil {
|
|
return getErrorExprID(id, wrapped)
|
|
}
|
|
return id
|
|
}
|
|
|
|
// matchConstantFormatStringWithListLiteralArgs matches all valid expression nodes for string
|
|
// format checking.
|
|
func matchConstantFormatStringWithListLiteralArgs(a *ast.AST) ast.ExprMatcher {
|
|
return func(e ast.NavigableExpr) bool {
|
|
if e.Kind() != ast.CallKind {
|
|
return false
|
|
}
|
|
call := e.AsCall()
|
|
if !call.IsMemberFunction() || call.FunctionName() != "format" {
|
|
return false
|
|
}
|
|
overloadIDs := a.GetOverloadIDs(e.ID())
|
|
if len(overloadIDs) != 0 {
|
|
found := false
|
|
for _, overload := range overloadIDs {
|
|
if overload == overloads.ExtFormatString {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
return false
|
|
}
|
|
}
|
|
formatString := call.Target()
|
|
if formatString.Kind() != ast.LiteralKind && formatString.AsLiteral().Type() != cel.StringType {
|
|
return false
|
|
}
|
|
args := call.Args()
|
|
if len(args) != 1 {
|
|
return false
|
|
}
|
|
formatArgs := args[0]
|
|
return formatArgs.Kind() == ast.ListKind
|
|
}
|
|
}
|
|
|
|
// stringFormatChecker implements the formatStringInterpolater interface
|
|
type stringFormatChecker struct {
|
|
args []ast.Expr
|
|
argsRequested int
|
|
currArgIndex int64
|
|
ast *ast.AST
|
|
}
|
|
|
|
func (c *stringFormatChecker) String(arg ref.Val, locale string) (string, error) {
|
|
formatArg := c.args[c.currArgIndex]
|
|
valid, badID := c.verifyString(formatArg)
|
|
if !valid {
|
|
return "", stringFormatError(badID, c.typeOf(badID).TypeName())
|
|
}
|
|
return "", nil
|
|
}
|
|
|
|
func (c *stringFormatChecker) Decimal(arg ref.Val, locale string) (string, error) {
|
|
id := c.args[c.currArgIndex].ID()
|
|
valid := c.verifyTypeOneOf(id, types.IntType, types.UintType)
|
|
if !valid {
|
|
return "", decimalFormatError(id, c.typeOf(id).TypeName())
|
|
}
|
|
return "", nil
|
|
}
|
|
|
|
func (c *stringFormatChecker) Fixed(precision *int) func(ref.Val, string) (string, error) {
|
|
return func(arg ref.Val, locale string) (string, error) {
|
|
id := c.args[c.currArgIndex].ID()
|
|
// we allow StringType since "NaN", "Infinity", and "-Infinity" are also valid values
|
|
valid := c.verifyTypeOneOf(id, types.DoubleType, types.StringType)
|
|
if !valid {
|
|
return "", fixedPointFormatError(id, c.typeOf(id).TypeName())
|
|
}
|
|
return "", nil
|
|
}
|
|
}
|
|
|
|
func (c *stringFormatChecker) Scientific(precision *int) func(ref.Val, string) (string, error) {
|
|
return func(arg ref.Val, locale string) (string, error) {
|
|
id := c.args[c.currArgIndex].ID()
|
|
valid := c.verifyTypeOneOf(id, types.DoubleType, types.StringType)
|
|
if !valid {
|
|
return "", scientificFormatError(id, c.typeOf(id).TypeName())
|
|
}
|
|
return "", nil
|
|
}
|
|
}
|
|
|
|
func (c *stringFormatChecker) Binary(arg ref.Val, locale string) (string, error) {
|
|
id := c.args[c.currArgIndex].ID()
|
|
valid := c.verifyTypeOneOf(id, types.IntType, types.UintType, types.BoolType)
|
|
if !valid {
|
|
return "", binaryFormatError(id, c.typeOf(id).TypeName())
|
|
}
|
|
return "", nil
|
|
}
|
|
|
|
func (c *stringFormatChecker) Hex(useUpper bool) func(ref.Val, string) (string, error) {
|
|
return func(arg ref.Val, locale string) (string, error) {
|
|
id := c.args[c.currArgIndex].ID()
|
|
valid := c.verifyTypeOneOf(id, types.IntType, types.UintType, types.StringType, types.BytesType)
|
|
if !valid {
|
|
return "", hexFormatError(id, c.typeOf(id).TypeName())
|
|
}
|
|
return "", nil
|
|
}
|
|
}
|
|
|
|
func (c *stringFormatChecker) Octal(arg ref.Val, locale string) (string, error) {
|
|
id := c.args[c.currArgIndex].ID()
|
|
valid := c.verifyTypeOneOf(id, types.IntType, types.UintType)
|
|
if !valid {
|
|
return "", octalFormatError(id, c.typeOf(id).TypeName())
|
|
}
|
|
return "", nil
|
|
}
|
|
|
|
func (c *stringFormatChecker) Arg(index int64) (ref.Val, error) {
|
|
c.argsRequested++
|
|
c.currArgIndex = index
|
|
// return a dummy value - this is immediately passed to back to us
|
|
// through one of the FormatCallback functions, so anything will do
|
|
return types.Int(0), nil
|
|
}
|
|
|
|
func (c *stringFormatChecker) Size() int64 {
|
|
return int64(len(c.args))
|
|
}
|
|
|
|
func (c *stringFormatChecker) typeOf(id int64) *cel.Type {
|
|
return c.ast.GetType(id)
|
|
}
|
|
|
|
func (c *stringFormatChecker) verifyTypeOneOf(id int64, validTypes ...*cel.Type) bool {
|
|
t := c.typeOf(id)
|
|
if t == cel.DynType {
|
|
return true
|
|
}
|
|
for _, vt := range validTypes {
|
|
// Only check runtime type compatibility without delving deeper into parameterized types
|
|
if t.Kind() == vt.Kind() {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (c *stringFormatChecker) verifyString(sub ast.Expr) (bool, int64) {
|
|
paramA := cel.TypeParamType("A")
|
|
paramB := cel.TypeParamType("B")
|
|
subVerified := c.verifyTypeOneOf(sub.ID(),
|
|
cel.ListType(paramA), cel.MapType(paramA, paramB),
|
|
cel.IntType, cel.UintType, cel.DoubleType, cel.BoolType, cel.StringType,
|
|
cel.TimestampType, cel.BytesType, cel.DurationType, cel.TypeType, cel.NullType)
|
|
if !subVerified {
|
|
return false, sub.ID()
|
|
}
|
|
switch sub.Kind() {
|
|
case ast.ListKind:
|
|
for _, e := range sub.AsList().Elements() {
|
|
// recursively verify if we're dealing with a list/map
|
|
verified, id := c.verifyString(e)
|
|
if !verified {
|
|
return false, id
|
|
}
|
|
}
|
|
return true, sub.ID()
|
|
case ast.MapKind:
|
|
for _, e := range sub.AsMap().Entries() {
|
|
// recursively verify if we're dealing with a list/map
|
|
entry := e.AsMapEntry()
|
|
verified, id := c.verifyString(entry.Key())
|
|
if !verified {
|
|
return false, id
|
|
}
|
|
verified, id = c.verifyString(entry.Value())
|
|
if !verified {
|
|
return false, id
|
|
}
|
|
}
|
|
return true, sub.ID()
|
|
default:
|
|
return true, sub.ID()
|
|
}
|
|
}
|
|
|
|
// helper routines for reporting common errors during string formatting static validation and
|
|
// runtime execution.
|
|
|
|
func binaryFormatError(id int64, badType string) error {
|
|
return newFormatError(id, "only integers and bools can be formatted as binary, was given %s", badType)
|
|
}
|
|
|
|
func decimalFormatError(id int64, badType string) error {
|
|
return newFormatError(id, "decimal clause can only be used on integers, was given %s", badType)
|
|
}
|
|
|
|
func fixedPointFormatError(id int64, badType string) error {
|
|
return newFormatError(id, "fixed-point clause can only be used on doubles, was given %s", badType)
|
|
}
|
|
|
|
func hexFormatError(id int64, badType string) error {
|
|
return newFormatError(id, "only integers, byte buffers, and strings can be formatted as hex, was given %s", badType)
|
|
}
|
|
|
|
func octalFormatError(id int64, badType string) error {
|
|
return newFormatError(id, "octal clause can only be used on integers, was given %s", badType)
|
|
}
|
|
|
|
func scientificFormatError(id int64, badType string) error {
|
|
return newFormatError(id, "scientific clause can only be used on doubles, was given %s", badType)
|
|
}
|
|
|
|
func stringFormatError(id int64, badType string) error {
|
|
return newFormatError(id, "string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps, was given %s", badType)
|
|
}
|
|
|
|
type formatError struct {
|
|
id int64
|
|
msg string
|
|
}
|
|
|
|
func newFormatError(id int64, msg string, args ...any) error {
|
|
return formatError{
|
|
id: id,
|
|
msg: fmt.Sprintf(msg, args...),
|
|
}
|
|
}
|
|
|
|
func (e formatError) Error() string {
|
|
return e.msg
|
|
}
|
|
|
|
func (e formatError) Is(target error) bool {
|
|
return e.msg == target.Error()
|
|
}
|
|
|
|
// stringArgList implements the formatListArgs interface.
|
|
type stringArgList struct {
|
|
args traits.Lister
|
|
}
|
|
|
|
func (c *stringArgList) Arg(index int64) (ref.Val, error) {
|
|
if index >= c.args.Size().Value().(int64) {
|
|
return nil, fmt.Errorf("index %d out of range", index)
|
|
}
|
|
return c.args.Get(types.Int(index)), nil
|
|
}
|
|
|
|
func (c *stringArgList) Size() int64 {
|
|
return c.args.Size().Value().(int64)
|
|
}
|
|
|
|
// formatStringInterpolator is an interface that allows user-defined behavior
|
|
// for formatting clause implementations, as well as argument retrieval.
|
|
// Each function is expected to support the appropriate types as laid out in
|
|
// the string.format documentation, and to return an error if given an inappropriate type.
|
|
type formatStringInterpolator interface {
|
|
// String takes a ref.Val and a string representing the current locale identifier
|
|
// and returns the Val formatted as a string, or an error if one occurred.
|
|
String(ref.Val, string) (string, error)
|
|
|
|
// Decimal takes a ref.Val and a string representing the current locale identifier
|
|
// and returns the Val formatted as a decimal integer, or an error if one occurred.
|
|
Decimal(ref.Val, string) (string, error)
|
|
|
|
// Fixed takes an int pointer representing precision (or nil if none was given) and
|
|
// returns a function operating in a similar manner to String and Decimal, taking a
|
|
// ref.Val and locale and returning the appropriate string. A closure is returned
|
|
// so precision can be set without needing an additional function call/configuration.
|
|
Fixed(*int) func(ref.Val, string) (string, error)
|
|
|
|
// Scientific functions identically to Fixed, except the string returned from the closure
|
|
// is expected to be in scientific notation.
|
|
Scientific(*int) func(ref.Val, string) (string, error)
|
|
|
|
// Binary takes a ref.Val and a string representing the current locale identifier
|
|
// and returns the Val formatted as a binary integer, or an error if one occurred.
|
|
Binary(ref.Val, string) (string, error)
|
|
|
|
// Hex takes a boolean that, if true, indicates the hex string output by the returned
|
|
// closure should use uppercase letters for A-F.
|
|
Hex(bool) func(ref.Val, string) (string, error)
|
|
|
|
// Octal takes a ref.Val and a string representing the current locale identifier and
|
|
// returns the Val formatted in octal, or an error if one occurred.
|
|
Octal(ref.Val, string) (string, error)
|
|
}
|
|
|
|
// formatListArgs is an interface that allows user-defined list-like datatypes to be used
|
|
// for formatting clause implementations.
|
|
type formatListArgs interface {
|
|
// Arg returns the ref.Val at the given index, or an error if one occurred.
|
|
Arg(int64) (ref.Val, error)
|
|
|
|
// Size returns the length of the argument list.
|
|
Size() int64
|
|
}
|
|
|
|
// parseFormatString formats a string according to the string.format syntax, taking the clause implementations
|
|
// from the provided FormatCallback and the args from the given FormatList.
|
|
func parseFormatString(formatStr string, callback formatStringInterpolator, list formatListArgs, locale string) (string, error) {
|
|
i := 0
|
|
argIndex := 0
|
|
var builtStr strings.Builder
|
|
for i < len(formatStr) {
|
|
if formatStr[i] == '%' {
|
|
if i+1 < len(formatStr) && formatStr[i+1] == '%' {
|
|
err := builtStr.WriteByte('%')
|
|
if err != nil {
|
|
return "", fmt.Errorf("error writing format string: %w", err)
|
|
}
|
|
i += 2
|
|
continue
|
|
} else {
|
|
argAny, err := list.Arg(int64(argIndex))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if i+1 >= len(formatStr) {
|
|
return "", errors.New("unexpected end of string")
|
|
}
|
|
if int64(argIndex) >= list.Size() {
|
|
return "", fmt.Errorf("index %d out of range", argIndex)
|
|
}
|
|
numRead, val, refErr := parseAndFormatClause(formatStr[i:], argAny, callback, list, locale)
|
|
if refErr != nil {
|
|
return "", refErr
|
|
}
|
|
_, err = builtStr.WriteString(val)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error writing format string: %w", err)
|
|
}
|
|
i += numRead
|
|
argIndex++
|
|
}
|
|
} else {
|
|
err := builtStr.WriteByte(formatStr[i])
|
|
if err != nil {
|
|
return "", fmt.Errorf("error writing format string: %w", err)
|
|
}
|
|
i++
|
|
}
|
|
}
|
|
return builtStr.String(), nil
|
|
}
|
|
|
|
// parseAndFormatClause parses the format clause at the start of the given string with val, and returns
|
|
// how many characters were consumed and the substituted string form of val, or an error if one occurred.
|
|
func parseAndFormatClause(formatStr string, val ref.Val, callback formatStringInterpolator, list formatListArgs, locale string) (int, string, error) {
|
|
i := 1
|
|
read, formatter, err := parseFormattingClause(formatStr[i:], callback)
|
|
i += read
|
|
if err != nil {
|
|
return -1, "", newParseFormatError("could not parse formatting clause", err)
|
|
}
|
|
|
|
valStr, err := formatter(val, locale)
|
|
if err != nil {
|
|
return -1, "", newParseFormatError("error during formatting", err)
|
|
}
|
|
return i, valStr, nil
|
|
}
|
|
|
|
func parseFormattingClause(formatStr string, callback formatStringInterpolator) (int, clauseImpl, error) {
|
|
i := 0
|
|
read, precision, err := parsePrecision(formatStr[i:])
|
|
i += read
|
|
if err != nil {
|
|
return -1, nil, fmt.Errorf("error while parsing precision: %w", err)
|
|
}
|
|
r := rune(formatStr[i])
|
|
i++
|
|
switch r {
|
|
case 's':
|
|
return i, callback.String, nil
|
|
case 'd':
|
|
return i, callback.Decimal, nil
|
|
case 'f':
|
|
return i, callback.Fixed(precision), nil
|
|
case 'e':
|
|
return i, callback.Scientific(precision), nil
|
|
case 'b':
|
|
return i, callback.Binary, nil
|
|
case 'x', 'X':
|
|
return i, callback.Hex(unicode.IsUpper(r)), nil
|
|
case 'o':
|
|
return i, callback.Octal, nil
|
|
default:
|
|
return -1, nil, fmt.Errorf("unrecognized formatting clause \"%c\"", r)
|
|
}
|
|
}
|
|
|
|
func parsePrecision(formatStr string) (int, *int, error) {
|
|
i := 0
|
|
if formatStr[i] != '.' {
|
|
return i, nil, nil
|
|
}
|
|
i++
|
|
var buffer strings.Builder
|
|
for {
|
|
if i >= len(formatStr) {
|
|
return -1, nil, errors.New("could not find end of precision specifier")
|
|
}
|
|
if !isASCIIDigit(rune(formatStr[i])) {
|
|
break
|
|
}
|
|
buffer.WriteByte(formatStr[i])
|
|
i++
|
|
}
|
|
precision, err := strconv.Atoi(buffer.String())
|
|
if err != nil {
|
|
return -1, nil, fmt.Errorf("error while converting precision to integer: %w", err)
|
|
}
|
|
return i, &precision, nil
|
|
}
|
|
|
|
func isASCIIDigit(r rune) bool {
|
|
return r <= unicode.MaxASCII && unicode.IsDigit(r)
|
|
}
|
|
|
|
type parseFormatError struct {
|
|
msg string
|
|
wrapped error
|
|
}
|
|
|
|
func newParseFormatError(msg string, wrapped error) error {
|
|
return parseFormatError{msg: msg, wrapped: wrapped}
|
|
}
|
|
|
|
func (e parseFormatError) Error() string {
|
|
return fmt.Sprintf("%s: %s", e.msg, e.wrapped.Error())
|
|
}
|
|
|
|
func (e parseFormatError) Is(target error) bool {
|
|
return e.Error() == target.Error()
|
|
}
|
|
|
|
func (e parseFormatError) Unwrap() error {
|
|
return e.wrapped
|
|
}
|
|
|
|
const (
|
|
runtimeID = int64(-1)
|
|
)
|