rebase: update kubernetes to 1.28.0 in main

updating kubernetes to 1.28.0
in the main repo.

Signed-off-by: Madhu Rajanna <madhupr007@gmail.com>
This commit is contained in:
Madhu Rajanna
2023-08-17 07:15:28 +02:00
committed by mergify[bot]
parent b2fdc269c3
commit ff3e84ad67
706 changed files with 45252 additions and 16346 deletions

View File

@ -23,6 +23,7 @@ go_library(
"//checker/decls:go_default_library",
"//common:go_default_library",
"//common/containers:go_default_library",
"//common/operators:go_default_library",
"//common/overloads:go_default_library",
"//common/types:go_default_library",
"//common/types/pb:go_default_library",
@ -31,7 +32,7 @@ go_library(
"//interpreter:go_default_library",
"//interpreter/functions:go_default_library",
"//parser:go_default_library",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//reflect/protodesc:go_default_library",
"@org_golang_google_protobuf//reflect/protoreflect:go_default_library",
@ -69,7 +70,7 @@ go_test(
"//test/proto2pb:go_default_library",
"//test/proto3pb:go_default_library",
"@io_bazel_rules_go//proto/wkt:descriptor_go_proto",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//types/known/structpb:go_default_library",
],

View File

@ -139,7 +139,7 @@ var (
kind: TypeKind,
runtimeType: types.TypeType,
}
//UintType represents a uint type.
// UintType represents a uint type.
UintType = &Type{
kind: UintKind,
runtimeType: types.UintType,
@ -222,7 +222,8 @@ func (t *Type) equals(other *Type) bool {
// - The from types are the same instance
// - The target type is dynamic
// - The fromType has the same kind and type name as the target type, and all parameters of the target type
// are IsAssignableType() from the parameters of the fromType.
//
// are IsAssignableType() from the parameters of the fromType.
func (t *Type) defaultIsAssignableType(fromType *Type) bool {
if t == fromType || t.isDyn() {
return true
@ -312,6 +313,11 @@ func NullableType(wrapped *Type) *Type {
}
}
// OptionalType creates an abstract parameterized type instance corresponding to CEL's notion of optional.
func OptionalType(param *Type) *Type {
return OpaqueType("optional", param)
}
// OpaqueType creates an abstract parameterized type with a given name.
func OpaqueType(name string, params ...*Type) *Type {
return &Type{
@ -365,7 +371,9 @@ func Variable(name string, t *Type) EnvOption {
//
// - Overloads are searched in the order they are declared
// - Dynamic dispatch for lists and maps is limited by inspection of the list and map contents
// at runtime. Empty lists and maps will result in a 'default dispatch'
//
// at runtime. Empty lists and maps will result in a 'default dispatch'
//
// - In the event that a default dispatch occurs, the first overload provided is the one invoked
//
// If you intend to use overloads which differentiate based on the key or element type of a list or
@ -405,7 +413,7 @@ func Function(name string, opts ...FunctionOpt) EnvOption {
// FunctionOpt defines a functional option for configuring a function declaration.
type FunctionOpt func(*functionDecl) (*functionDecl, error)
// SingletonUnaryBinding creates a singleton function defintion to be used for all function overloads.
// SingletonUnaryBinding creates a singleton function definition to be used for all function overloads.
//
// Note, this approach works well if operand is expected to have a specific trait which it implements,
// e.g. traits.ContainerType. Otherwise, prefer per-overload function bindings.
@ -431,7 +439,17 @@ func SingletonUnaryBinding(fn functions.UnaryOp, traits ...int) FunctionOpt {
//
// Note, this approach works well if operand is expected to have a specific trait which it implements,
// e.g. traits.ContainerType. Otherwise, prefer per-overload function bindings.
//
// Deprecated: use SingletonBinaryBinding
func SingletonBinaryImpl(fn functions.BinaryOp, traits ...int) FunctionOpt {
return SingletonBinaryBinding(fn, traits...)
}
// SingletonBinaryBinding creates a singleton function definition to be used with all function overloads.
//
// Note, this approach works well if operand is expected to have a specific trait which it implements,
// e.g. traits.ContainerType. Otherwise, prefer per-overload function bindings.
func SingletonBinaryBinding(fn functions.BinaryOp, traits ...int) FunctionOpt {
trait := 0
for _, t := range traits {
trait = trait | t
@ -453,7 +471,17 @@ func SingletonBinaryImpl(fn functions.BinaryOp, traits ...int) FunctionOpt {
//
// Note, this approach works well if operand is expected to have a specific trait which it implements,
// e.g. traits.ContainerType. Otherwise, prefer per-overload function bindings.
//
// Deprecated: use SingletonFunctionBinding
func SingletonFunctionImpl(fn functions.FunctionOp, traits ...int) FunctionOpt {
return SingletonFunctionBinding(fn, traits...)
}
// SingletonFunctionBinding creates a singleton function definition to be used with all function overloads.
//
// Note, this approach works well if operand is expected to have a specific trait which it implements,
// e.g. traits.ContainerType. Otherwise, prefer per-overload function bindings.
func SingletonFunctionBinding(fn functions.FunctionOp, traits ...int) FunctionOpt {
trait := 0
for _, t := range traits {
trait = trait | t
@ -720,9 +748,8 @@ func (f *functionDecl) addOverload(overload *overloadDecl) error {
// Allow redefinition of an overload implementation so long as the signatures match.
f.overloads[index] = overload
return nil
} else {
return fmt.Errorf("overload redefinition in function. %s: %s has multiple definitions", f.name, o.id)
}
return fmt.Errorf("overload redefinition in function. %s: %s has multiple definitions", f.name, o.id)
}
}
f.overloads = append(f.overloads, overload)
@ -1177,3 +1204,43 @@ func collectParamNames(paramNames map[string]struct{}, arg *Type) {
collectParamNames(paramNames, param)
}
}
func typeValueToKind(tv *types.TypeValue) (Kind, error) {
switch tv {
case types.BoolType:
return BoolKind, nil
case types.DoubleType:
return DoubleKind, nil
case types.IntType:
return IntKind, nil
case types.UintType:
return UintKind, nil
case types.ListType:
return ListKind, nil
case types.MapType:
return MapKind, nil
case types.StringType:
return StringKind, nil
case types.BytesType:
return BytesKind, nil
case types.DurationType:
return DurationKind, nil
case types.TimestampType:
return TimestampKind, nil
case types.NullType:
return NullTypeKind, nil
case types.TypeType:
return TypeKind, nil
default:
switch tv.TypeName() {
case "dyn":
return DynKind, nil
case "google.protobuf.Any":
return AnyKind, nil
case "optional":
return OpaqueKind, nil
default:
return 0, fmt.Errorf("no known conversion for type of %s", tv.TypeName())
}
}
}

View File

@ -102,15 +102,18 @@ type Env struct {
provider ref.TypeProvider
features map[int]bool
appliedFeatures map[int]bool
libraries map[string]bool
// Internal parser representation
prsr *parser.Parser
prsr *parser.Parser
prsrOpts []parser.Option
// Internal checker representation
chk *checker.Env
chkErr error
chkOnce sync.Once
chkOpts []checker.Option
chkMutex sync.Mutex
chk *checker.Env
chkErr error
chkOnce sync.Once
chkOpts []checker.Option
// Program options tied to the environment
progOpts []ProgramOption
@ -159,6 +162,7 @@ func NewCustomEnv(opts ...EnvOption) (*Env, error) {
provider: registry,
features: map[int]bool{},
appliedFeatures: map[int]bool{},
libraries: map[string]bool{},
progOpts: []ProgramOption{},
}).configure(opts)
}
@ -175,14 +179,14 @@ func (e *Env) Check(ast *Ast) (*Ast, *Issues) {
pe, _ := AstToParsedExpr(ast)
// Construct the internal checker env, erroring if there is an issue adding the declarations.
err := e.initChecker()
chk, err := e.initChecker()
if err != nil {
errs := common.NewErrors(ast.Source())
errs.ReportError(common.NoLocation, e.chkErr.Error())
errs.ReportError(common.NoLocation, err.Error())
return nil, NewIssues(errs)
}
res, errs := checker.Check(pe, ast.Source(), e.chk)
res, errs := checker.Check(pe, ast.Source(), chk)
if len(errs.GetErrors()) > 0 {
return nil, NewIssues(errs)
}
@ -236,10 +240,14 @@ func (e *Env) CompileSource(src Source) (*Ast, *Issues) {
// TypeProvider are immutable, or that their underlying implementations are based on the
// ref.TypeRegistry which provides a Copy method which will be invoked by this method.
func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
if e.chkErr != nil {
return nil, e.chkErr
chk, chkErr := e.getCheckerOrError()
if chkErr != nil {
return nil, chkErr
}
prsrOptsCopy := make([]parser.Option, len(e.prsrOpts))
copy(prsrOptsCopy, e.prsrOpts)
// The type-checker is configured with Declarations. The declarations may either be provided
// as options which have not yet been validated, or may come from a previous checker instance
// whose types have already been validated.
@ -248,10 +256,10 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
// Copy the declarations if needed.
decsCopy := []*exprpb.Decl{}
if e.chk != nil {
if chk != nil {
// If the type-checker has already been instantiated, then the e.declarations have been
// valdiated within the chk instance.
chkOptsCopy = append(chkOptsCopy, checker.ValidatedDeclarations(e.chk))
// validated within the chk instance.
chkOptsCopy = append(chkOptsCopy, checker.ValidatedDeclarations(chk))
} else {
// If the type-checker has not been instantiated, ensure the unvalidated declarations are
// provided to the extended Env instance.
@ -304,8 +312,11 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
for k, v := range e.functions {
funcsCopy[k] = v
}
libsCopy := make(map[string]bool, len(e.libraries))
for k, v := range e.libraries {
libsCopy[k] = v
}
// TODO: functions copy needs to happen here.
ext := &Env{
Container: e.Container,
declarations: decsCopy,
@ -315,8 +326,10 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
adapter: adapter,
features: featuresCopy,
appliedFeatures: appliedFeaturesCopy,
libraries: libsCopy,
provider: provider,
chkOpts: chkOptsCopy,
prsrOpts: prsrOptsCopy,
}
return ext.configure(opts)
}
@ -328,6 +341,12 @@ func (e *Env) HasFeature(flag int) bool {
return has && enabled
}
// HasLibrary returns whether a specific SingletonLibrary has been configured in the environment.
func (e *Env) HasLibrary(libName string) bool {
configured, exists := e.libraries[libName]
return exists && configured
}
// Parse parses the input expression value `txt` to a Ast and/or a set of Issues.
//
// This form of Parse creates a Source value for the input `txt` and forwards to the
@ -422,8 +441,8 @@ func (e *Env) UnknownVars() interpreter.PartialActivation {
// TODO: Consider adding an option to generate a Program.Residual to avoid round-tripping to an
// Ast format and then Program again.
func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) {
pruned := interpreter.PruneAst(a.Expr(), details.State())
expr, err := AstToString(ParsedExprToAst(&exprpb.ParsedExpr{Expr: pruned}))
pruned := interpreter.PruneAst(a.Expr(), a.SourceInfo().GetMacroCalls(), details.State())
expr, err := AstToString(ParsedExprToAst(pruned))
if err != nil {
return nil, err
}
@ -443,12 +462,12 @@ func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) {
// EstimateCost estimates the cost of a type checked CEL expression using the length estimates of input data and
// extension functions provided by estimator.
func (e *Env) EstimateCost(ast *Ast, estimator checker.CostEstimator) (checker.CostEstimate, error) {
func (e *Env) EstimateCost(ast *Ast, estimator checker.CostEstimator, opts ...checker.CostOption) (checker.CostEstimate, error) {
checked, err := AstToCheckedExpr(ast)
if err != nil {
return checker.CostEstimate{}, fmt.Errorf("EsimateCost could not inspect Ast: %v", err)
}
return checker.Cost(checked, estimator), nil
return checker.Cost(checked, estimator, opts...)
}
// configure applies a series of EnvOptions to the current environment.
@ -464,17 +483,9 @@ func (e *Env) configure(opts []EnvOption) (*Env, error) {
}
// If the default UTC timezone fix has been enabled, make sure the library is configured
if e.HasFeature(featureDefaultUTCTimeZone) {
if _, found := e.appliedFeatures[featureDefaultUTCTimeZone]; !found {
e, err = Lib(timeUTCLibrary{})(e)
if err != nil {
return nil, err
}
// record that the feature has been applied since it will generate declarations
// and functions which will be propagated on Extend() calls and which should only
// be registered once.
e.appliedFeatures[featureDefaultUTCTimeZone] = true
}
e, err = e.maybeApplyFeature(featureDefaultUTCTimeZone, Lib(timeUTCLibrary{}))
if err != nil {
return nil, err
}
// Initialize all of the functions configured within the environment.
@ -486,7 +497,10 @@ func (e *Env) configure(opts []EnvOption) (*Env, error) {
}
// Configure the parser.
prsrOpts := []parser.Option{parser.Macros(e.macros...)}
prsrOpts := []parser.Option{}
prsrOpts = append(prsrOpts, e.prsrOpts...)
prsrOpts = append(prsrOpts, parser.Macros(e.macros...))
if e.HasFeature(featureEnableMacroCallTracking) {
prsrOpts = append(prsrOpts, parser.PopulateMacroCalls(true))
}
@ -497,7 +511,7 @@ func (e *Env) configure(opts []EnvOption) (*Env, error) {
// Ensure that the checker init happens eagerly rather than lazily.
if e.HasFeature(featureEagerlyValidateDeclarations) {
err := e.initChecker()
_, err := e.initChecker()
if err != nil {
return nil, err
}
@ -506,7 +520,7 @@ func (e *Env) configure(opts []EnvOption) (*Env, error) {
return e, nil
}
func (e *Env) initChecker() error {
func (e *Env) initChecker() (*checker.Env, error) {
e.chkOnce.Do(func() {
chkOpts := []checker.Option{}
chkOpts = append(chkOpts, e.chkOpts...)
@ -518,32 +532,68 @@ func (e *Env) initChecker() error {
ce, err := checker.NewEnv(e.Container, e.provider, chkOpts...)
if err != nil {
e.chkErr = err
e.setCheckerOrError(nil, err)
return
}
// Add the statically configured declarations.
err = ce.Add(e.declarations...)
if err != nil {
e.chkErr = err
e.setCheckerOrError(nil, err)
return
}
// Add the function declarations which are derived from the FunctionDecl instances.
for _, fn := range e.functions {
fnDecl, err := functionDeclToExprDecl(fn)
if err != nil {
e.chkErr = err
e.setCheckerOrError(nil, err)
return
}
err = ce.Add(fnDecl)
if err != nil {
e.chkErr = err
e.setCheckerOrError(nil, err)
return
}
}
// Add function declarations here separately.
e.chk = ce
e.setCheckerOrError(ce, nil)
})
return e.chkErr
return e.getCheckerOrError()
}
// setCheckerOrError sets the checker.Env or error state in a concurrency-safe manner
func (e *Env) setCheckerOrError(chk *checker.Env, chkErr error) {
e.chkMutex.Lock()
e.chk = chk
e.chkErr = chkErr
e.chkMutex.Unlock()
}
// getCheckerOrError gets the checker.Env or error state in a concurrency-safe manner
func (e *Env) getCheckerOrError() (*checker.Env, error) {
e.chkMutex.Lock()
defer e.chkMutex.Unlock()
return e.chk, e.chkErr
}
// maybeApplyFeature determines whether the feature-guarded option is enabled, and if so applies
// the feature if it has not already been enabled.
func (e *Env) maybeApplyFeature(feature int, option EnvOption) (*Env, error) {
if !e.HasFeature(feature) {
return e, nil
}
_, applied := e.appliedFeatures[feature]
if applied {
return e, nil
}
e, err := option(e)
if err != nil {
return nil, err
}
// record that the feature has been applied since it will generate declarations
// and functions which will be propagated on Extend() calls and which should only
// be registered once.
e.appliedFeatures[feature] = true
return e, nil
}
// Issues defines methods for inspecting the error details of parse and check calls.

View File

@ -19,14 +19,14 @@ import (
"fmt"
"reflect"
"google.golang.org/protobuf/proto"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/parser"
"google.golang.org/protobuf/proto"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
anypb "google.golang.org/protobuf/types/known/anypb"
)

View File

@ -20,10 +20,27 @@ import (
"time"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/interpreter"
"github.com/google/cel-go/interpreter/functions"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
const (
optMapMacro = "optMap"
hasValueFunc = "hasValue"
optionalNoneFunc = "optional.none"
optionalOfFunc = "optional.of"
optionalOfNonZeroValueFunc = "optional.ofNonZeroValue"
valueFunc = "value"
unusedIterVar = "#unused"
)
// Library provides a collection of EnvOption and ProgramOption values used to configure a CEL
@ -42,10 +59,27 @@ type Library interface {
ProgramOptions() []ProgramOption
}
// SingletonLibrary refines the Library interface to ensure that libraries in this format are only
// configured once within the environment.
type SingletonLibrary interface {
Library
// LibraryName provides a namespaced name which is used to check whether the library has already
// been configured in the environment.
LibraryName() string
}
// Lib creates an EnvOption out of a Library, allowing libraries to be provided as functional args,
// and to be linked to each other.
func Lib(l Library) EnvOption {
singleton, isSingleton := l.(SingletonLibrary)
return func(e *Env) (*Env, error) {
if isSingleton {
if e.HasLibrary(singleton.LibraryName()) {
return e, nil
}
e.libraries[singleton.LibraryName()] = true
}
var err error
for _, opt := range l.CompileOptions() {
e, err = opt(e)
@ -67,6 +101,11 @@ func StdLib() EnvOption {
// features documented in the specification.
type stdLibrary struct{}
// LibraryName implements the SingletonLibrary interface method.
func (stdLibrary) LibraryName() string {
return "cel.lib.std"
}
// EnvOptions returns options for the standard CEL function declarations and macros.
func (stdLibrary) CompileOptions() []EnvOption {
return []EnvOption{
@ -82,6 +121,225 @@ func (stdLibrary) ProgramOptions() []ProgramOption {
}
}
type optionalLibrary struct{}
// LibraryName implements the SingletonLibrary interface method.
func (optionalLibrary) LibraryName() string {
return "cel.lib.optional"
}
// CompileOptions implements the Library interface method.
func (optionalLibrary) CompileOptions() []EnvOption {
paramTypeK := TypeParamType("K")
paramTypeV := TypeParamType("V")
optionalTypeV := OptionalType(paramTypeV)
listTypeV := ListType(paramTypeV)
mapTypeKV := MapType(paramTypeK, paramTypeV)
return []EnvOption{
// Enable the optional syntax in the parser.
enableOptionalSyntax(),
// Introduce the optional type.
Types(types.OptionalType),
// Configure the optMap macro.
Macros(NewReceiverMacro(optMapMacro, 2, optMap)),
// Global and member functions for working with optional values.
Function(optionalOfFunc,
Overload("optional_of", []*Type{paramTypeV}, optionalTypeV,
UnaryBinding(func(value ref.Val) ref.Val {
return types.OptionalOf(value)
}))),
Function(optionalOfNonZeroValueFunc,
Overload("optional_ofNonZeroValue", []*Type{paramTypeV}, optionalTypeV,
UnaryBinding(func(value ref.Val) ref.Val {
v, isZeroer := value.(traits.Zeroer)
if !isZeroer || !v.IsZeroValue() {
return types.OptionalOf(value)
}
return types.OptionalNone
}))),
Function(optionalNoneFunc,
Overload("optional_none", []*Type{}, optionalTypeV,
FunctionBinding(func(values ...ref.Val) ref.Val {
return types.OptionalNone
}))),
Function(valueFunc,
MemberOverload("optional_value", []*Type{optionalTypeV}, paramTypeV,
UnaryBinding(func(value ref.Val) ref.Val {
opt := value.(*types.Optional)
return opt.GetValue()
}))),
Function(hasValueFunc,
MemberOverload("optional_hasValue", []*Type{optionalTypeV}, BoolType,
UnaryBinding(func(value ref.Val) ref.Val {
opt := value.(*types.Optional)
return types.Bool(opt.HasValue())
}))),
// Implementation of 'or' and 'orValue' are special-cased to support short-circuiting in the
// evaluation chain.
Function("or",
MemberOverload("optional_or_optional", []*Type{optionalTypeV, optionalTypeV}, optionalTypeV)),
Function("orValue",
MemberOverload("optional_orValue_value", []*Type{optionalTypeV, paramTypeV}, paramTypeV)),
// OptSelect is handled specially by the type-checker, so the receiver's field type is used to determine the
// optput type.
Function(operators.OptSelect,
Overload("select_optional_field", []*Type{DynType, StringType}, optionalTypeV)),
// OptIndex is handled mostly like any other indexing operation on a list or map, so the type-checker can use
// these signatures to determine type-agreement without any special handling.
Function(operators.OptIndex,
Overload("list_optindex_optional_int", []*Type{listTypeV, IntType}, optionalTypeV),
Overload("optional_list_optindex_optional_int", []*Type{OptionalType(listTypeV), IntType}, optionalTypeV),
Overload("map_optindex_optional_value", []*Type{mapTypeKV, paramTypeK}, optionalTypeV),
Overload("optional_map_optindex_optional_value", []*Type{OptionalType(mapTypeKV), paramTypeK}, optionalTypeV)),
// Index overloads to accommodate using an optional value as the operand.
Function(operators.Index,
Overload("optional_list_index_int", []*Type{OptionalType(listTypeV), IntType}, optionalTypeV),
Overload("optional_map_index_optional_value", []*Type{OptionalType(mapTypeKV), paramTypeK}, optionalTypeV)),
}
}
func optMap(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
varIdent := args[0]
varName := ""
switch varIdent.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
varName = varIdent.GetIdentExpr().GetName()
default:
return nil, &common.Error{
Message: "optMap() variable name must be a simple identifier",
Location: meh.OffsetLocation(varIdent.GetId()),
}
}
mapExpr := args[1]
return meh.GlobalCall(
operators.Conditional,
meh.ReceiverCall(hasValueFunc, target),
meh.GlobalCall(optionalOfFunc,
meh.Fold(
unusedIterVar,
meh.NewList(),
varName,
meh.ReceiverCall(valueFunc, target),
meh.LiteralBool(false),
meh.Ident(varName),
mapExpr,
),
),
meh.GlobalCall(optionalNoneFunc),
), nil
}
// ProgramOptions implements the Library interface method.
func (optionalLibrary) ProgramOptions() []ProgramOption {
return []ProgramOption{
CustomDecorator(decorateOptionalOr),
}
}
func enableOptionalSyntax() EnvOption {
return func(e *Env) (*Env, error) {
e.prsrOpts = append(e.prsrOpts, parser.EnableOptionalSyntax(true))
return e, nil
}
}
func decorateOptionalOr(i interpreter.Interpretable) (interpreter.Interpretable, error) {
call, ok := i.(interpreter.InterpretableCall)
if !ok {
return i, nil
}
args := call.Args()
if len(args) != 2 {
return i, nil
}
switch call.Function() {
case "or":
if call.OverloadID() != "" && call.OverloadID() != "optional_or_optional" {
return i, nil
}
return &evalOptionalOr{
id: call.ID(),
lhs: args[0],
rhs: args[1],
}, nil
case "orValue":
if call.OverloadID() != "" && call.OverloadID() != "optional_orValue_value" {
return i, nil
}
return &evalOptionalOrValue{
id: call.ID(),
lhs: args[0],
rhs: args[1],
}, nil
default:
return i, nil
}
}
// evalOptionalOr selects between two optional values, either the first if it has a value, or
// the second optional expression is evaluated and returned.
type evalOptionalOr struct {
id int64
lhs interpreter.Interpretable
rhs interpreter.Interpretable
}
// ID implements the Interpretable interface method.
func (opt *evalOptionalOr) ID() int64 {
return opt.id
}
// Eval evaluates the left-hand side optional to determine whether it contains a value, else
// proceeds with the right-hand side evaluation.
func (opt *evalOptionalOr) Eval(ctx interpreter.Activation) ref.Val {
// short-circuit lhs.
optLHS := opt.lhs.Eval(ctx)
optVal, ok := optLHS.(*types.Optional)
if !ok {
return optLHS
}
if optVal.HasValue() {
return optVal
}
return opt.rhs.Eval(ctx)
}
// evalOptionalOrValue selects between an optional or a concrete value. If the optional has a value,
// its value is returned, otherwise the alternative value expression is evaluated and returned.
type evalOptionalOrValue struct {
id int64
lhs interpreter.Interpretable
rhs interpreter.Interpretable
}
// ID implements the Interpretable interface method.
func (opt *evalOptionalOrValue) ID() int64 {
return opt.id
}
// Eval evaluates the left-hand side optional to determine whether it contains a value, else
// proceeds with the right-hand side evaluation.
func (opt *evalOptionalOrValue) Eval(ctx interpreter.Activation) ref.Val {
// short-circuit lhs.
optLHS := opt.lhs.Eval(ctx)
optVal, ok := optLHS.(*types.Optional)
if !ok {
return optLHS
}
if optVal.HasValue() {
return optVal.GetValue()
}
return opt.rhs.Eval(ctx)
}
type timeUTCLibrary struct{}
func (timeUTCLibrary) CompileOptions() []EnvOption {

View File

@ -17,6 +17,7 @@ package cel
import (
"github.com/google/cel-go/common"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
@ -26,8 +27,11 @@ import (
// a Macro should be created per arg-count or as a var arg macro.
type Macro = parser.Macro
// MacroExpander converts a call and its associated arguments into a new CEL abstract syntax tree, or an error
// if the input arguments are not suitable for the expansion requirements for the macro in question.
// MacroExpander converts a call and its associated arguments into a new CEL abstract syntax tree.
//
// If the MacroExpander determines within the implementation that an expansion is not needed it may return
// a nil Expr value to indicate a non-match. However, if an expansion is to be performed, but the arguments
// are not well-formed, the result of the expansion will be an error.
//
// The MacroExpander accepts as arguments a MacroExprHelper as well as the arguments used in the function call
// and produces as output an Expr ast node.
@ -81,8 +85,10 @@ func ExistsOneMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*ex
// input to produce an output list.
//
// There are two call patterns supported by map:
// <iterRange>.map(<iterVar>, <transform>)
// <iterRange>.map(<iterVar>, <predicate>, <transform>)
//
// <iterRange>.map(<iterVar>, <transform>)
// <iterRange>.map(<iterVar>, <predicate>, <transform>)
//
// In the second form only iterVar values which return true when provided to the predicate expression
// are transformed.
func MapMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {

View File

@ -29,6 +29,7 @@ import (
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter"
"github.com/google/cel-go/interpreter/functions"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
descpb "google.golang.org/protobuf/types/descriptorpb"
@ -61,6 +62,10 @@ const (
// on a CEL timestamp operation. This fixes the scenario where the input time
// is not already in UTC.
featureDefaultUTCTimeZone
// Enable the use of optional types in the syntax, type-system, type-checking,
// and runtime.
featureOptionalTypes
)
// EnvOption is a functional interface for configuring the environment.
@ -163,19 +168,19 @@ func Container(name string) EnvOption {
// Abbreviations can be useful when working with variables, functions, and especially types from
// multiple namespaces:
//
// // CEL object construction
// qual.pkg.version.ObjTypeName{
// field: alt.container.ver.FieldTypeName{value: ...}
// }
// // CEL object construction
// qual.pkg.version.ObjTypeName{
// field: alt.container.ver.FieldTypeName{value: ...}
// }
//
// Only one the qualified names above may be used as the CEL container, so at least one of these
// references must be a long qualified name within an otherwise short CEL program. Using the
// following abbreviations, the program becomes much simpler:
//
// // CEL Go option
// Abbrevs("qual.pkg.version.ObjTypeName", "alt.container.ver.FieldTypeName")
// // Simplified Object construction
// ObjTypeName{field: FieldTypeName{value: ...}}
// // CEL Go option
// Abbrevs("qual.pkg.version.ObjTypeName", "alt.container.ver.FieldTypeName")
// // Simplified Object construction
// ObjTypeName{field: FieldTypeName{value: ...}}
//
// There are a few rules for the qualified names and the simple abbreviations generated from them:
// - Qualified names must be dot-delimited, e.g. `package.subpkg.name`.
@ -188,9 +193,12 @@ func Container(name string) EnvOption {
// - Expanded abbreviations do not participate in namespace resolution.
// - Abbreviation expansion is done instead of the container search for a matching identifier.
// - Containers follow C++ namespace resolution rules with searches from the most qualified name
// to the least qualified name.
//
// to the least qualified name.
//
// - Container references within the CEL program may be relative, and are resolved to fully
// qualified names at either type-check time or program plan time, whichever comes first.
//
// qualified names at either type-check time or program plan time, whichever comes first.
//
// If there is ever a case where an identifier could be in both the container and as an
// abbreviation, the abbreviation wins as this will ensure that the meaning of a program is
@ -216,7 +224,7 @@ func Abbrevs(qualifiedNames ...string) EnvOption {
// environment by default.
//
// Note: This option must be specified after the CustomTypeProvider option when used together.
func Types(addTypes ...interface{}) EnvOption {
func Types(addTypes ...any) EnvOption {
return func(e *Env) (*Env, error) {
reg, isReg := e.provider.(ref.TypeRegistry)
if !isReg {
@ -253,7 +261,7 @@ func Types(addTypes ...interface{}) EnvOption {
//
// TypeDescs are hermetic to a single Env object, but may be copied to other Env values via
// extension or by re-using the same EnvOption with another NewEnv() call.
func TypeDescs(descs ...interface{}) EnvOption {
func TypeDescs(descs ...any) EnvOption {
return func(e *Env) (*Env, error) {
reg, isReg := e.provider.(ref.TypeRegistry)
if !isReg {
@ -350,8 +358,8 @@ func Functions(funcs ...*functions.Overload) ProgramOption {
// variables with the same name provided to the Eval() call. If Globals is used in a Library with
// a Lib EnvOption, vars may shadow variables provided by previously added libraries.
//
// The vars value may either be an `interpreter.Activation` instance or a `map[string]interface{}`.
func Globals(vars interface{}) ProgramOption {
// The vars value may either be an `interpreter.Activation` instance or a `map[string]any`.
func Globals(vars any) ProgramOption {
return func(p *prog) (*prog, error) {
defaultVars, err := interpreter.NewActivation(vars)
if err != nil {
@ -404,6 +412,9 @@ const (
// OptTrackCost enables the runtime cost calculation while validation and return cost within evalDetails
// cost calculation is available via func ActualCost()
OptTrackCost EvalOption = 1 << iota
// OptCheckStringFormat enables compile-time checking of string.format calls for syntax/cardinality.
OptCheckStringFormat EvalOption = 1 << iota
)
// EvalOptions sets one or more evaluation options which may affect the evaluation or Result.
@ -534,6 +545,13 @@ func DefaultUTCTimeZone(enabled bool) EnvOption {
return features(featureDefaultUTCTimeZone, enabled)
}
// OptionalTypes enable support for optional syntax and types in CEL. The optional value type makes
// it possible to express whether variables have been provided, whether a result has been computed,
// and in the future whether an object field path, map key value, or list index has a value.
func OptionalTypes() EnvOption {
return Lib(optionalLibrary{})
}
// features sets the given feature flags. See list of Feature constants above.
func features(flag int, enabled bool) EnvOption {
return func(e *Env) (*Env, error) {
@ -541,3 +559,21 @@ func features(flag int, enabled bool) EnvOption {
return e, nil
}
}
// ParserRecursionLimit adjusts the AST depth the parser will tolerate.
// Defaults defined in the parser package.
func ParserRecursionLimit(limit int) EnvOption {
return func(e *Env) (*Env, error) {
e.prsrOpts = append(e.prsrOpts, parser.MaxRecursionDepth(limit))
return e, nil
}
}
// ParserExpressionSizeLimit adjusts the number of code points the expression parser is allowed to parse.
// Defaults defined in the parser package.
func ParserExpressionSizeLimit(limit int) EnvOption {
return func(e *Env) (*Env, error) {
e.prsrOpts = append(e.prsrOpts, parser.ExpressionSizeCodePointLimit(limit))
return e, nil
}
}

View File

@ -17,21 +17,20 @@ package cel
import (
"context"
"fmt"
"math"
"sync"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// Program is an evaluable view of an Ast.
type Program interface {
// Eval returns the result of an evaluation of the Ast and environment against the input vars.
//
// The vars value may either be an `interpreter.Activation` or a `map[string]interface{}`.
// The vars value may either be an `interpreter.Activation` or a `map[string]any`.
//
// If the `OptTrackState`, `OptTrackCost` or `OptExhaustiveEval` flags are used, the `details` response will
// be non-nil. Given this caveat on `details`, the return state from evaluation will be:
@ -43,16 +42,16 @@ type Program interface {
// An unsuccessful evaluation is typically the result of a series of incompatible `EnvOption`
// or `ProgramOption` values used in the creation of the evaluation environment or executable
// program.
Eval(interface{}) (ref.Val, *EvalDetails, error)
Eval(any) (ref.Val, *EvalDetails, error)
// ContextEval evaluates the program with a set of input variables and a context object in order
// to support cancellation and timeouts. This method must be used in conjunction with the
// InterruptCheckFrequency() option for cancellation interrupts to be impact evaluation.
//
// The vars value may either be an `interpreter.Activation` or `map[string]interface{}`.
// The vars value may either be an `interpreter.Activation` or `map[string]any`.
//
// The output contract for `ContextEval` is otherwise identical to the `Eval` method.
ContextEval(context.Context, interface{}) (ref.Val, *EvalDetails, error)
ContextEval(context.Context, any) (ref.Val, *EvalDetails, error)
}
// NoVars returns an empty Activation.
@ -65,7 +64,7 @@ func NoVars() interpreter.Activation {
//
// The `vars` value may either be an interpreter.Activation or any valid input to the
// interpreter.NewActivation call.
func PartialVars(vars interface{},
func PartialVars(vars any,
unknowns ...*interpreter.AttributePattern) (interpreter.PartialActivation, error) {
return interpreter.NewPartialActivation(vars, unknowns...)
}
@ -207,6 +206,37 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
if len(p.regexOptimizations) > 0 {
decorators = append(decorators, interpreter.CompileRegexConstants(p.regexOptimizations...))
}
// Enable compile-time checking of syntax/cardinality for string.format calls.
if p.evalOpts&OptCheckStringFormat == OptCheckStringFormat {
var isValidType func(id int64, validTypes ...*types.TypeValue) (bool, error)
if ast.IsChecked() {
isValidType = func(id int64, validTypes ...*types.TypeValue) (bool, error) {
t, err := ExprTypeToType(ast.typeMap[id])
if err != nil {
return false, err
}
if t.kind == DynKind {
return true, nil
}
for _, vt := range validTypes {
k, err := typeValueToKind(vt)
if err != nil {
return false, err
}
if k == t.kind {
return true, nil
}
}
return false, nil
}
} else {
// if the AST isn't type-checked, short-circuit validation
isValidType = func(id int64, validTypes ...*types.TypeValue) (bool, error) {
return true, nil
}
}
decorators = append(decorators, interpreter.InterpolateFormattedString(isValidType))
}
// Enable exhaustive eval, state tracking and cost tracking last since they require a factory.
if p.evalOpts&(OptExhaustiveEval|OptTrackState|OptTrackCost) != 0 {
@ -268,7 +298,7 @@ func (p *prog) initInterpretable(ast *Ast, decs []interpreter.InterpretableDecor
}
// Eval implements the Program interface method.
func (p *prog) Eval(input interface{}) (v ref.Val, det *EvalDetails, err error) {
func (p *prog) Eval(input any) (v ref.Val, det *EvalDetails, err error) {
// Configure error recovery for unexpected panics during evaluation. Note, the use of named
// return values makes it possible to modify the error response during the recovery
// function.
@ -287,11 +317,11 @@ func (p *prog) Eval(input interface{}) (v ref.Val, det *EvalDetails, err error)
switch v := input.(type) {
case interpreter.Activation:
vars = v
case map[string]interface{}:
case map[string]any:
vars = activationPool.Setup(v)
defer activationPool.Put(vars)
default:
return nil, nil, fmt.Errorf("invalid input, wanted Activation or map[string]interface{}, got: (%T)%v", input, input)
return nil, nil, fmt.Errorf("invalid input, wanted Activation or map[string]any, got: (%T)%v", input, input)
}
if p.defaultVars != nil {
vars = interpreter.NewHierarchicalActivation(p.defaultVars, vars)
@ -307,7 +337,7 @@ func (p *prog) Eval(input interface{}) (v ref.Val, det *EvalDetails, err error)
}
// ContextEval implements the Program interface.
func (p *prog) ContextEval(ctx context.Context, input interface{}) (ref.Val, *EvalDetails, error) {
func (p *prog) ContextEval(ctx context.Context, input any) (ref.Val, *EvalDetails, error) {
if ctx == nil {
return nil, nil, fmt.Errorf("context can not be nil")
}
@ -318,22 +348,17 @@ func (p *prog) ContextEval(ctx context.Context, input interface{}) (ref.Val, *Ev
case interpreter.Activation:
vars = ctxActivationPool.Setup(v, ctx.Done(), p.interruptCheckFrequency)
defer ctxActivationPool.Put(vars)
case map[string]interface{}:
case map[string]any:
rawVars := activationPool.Setup(v)
defer activationPool.Put(rawVars)
vars = ctxActivationPool.Setup(rawVars, ctx.Done(), p.interruptCheckFrequency)
defer ctxActivationPool.Put(vars)
default:
return nil, nil, fmt.Errorf("invalid input, wanted Activation or map[string]interface{}, got: (%T)%v", input, input)
return nil, nil, fmt.Errorf("invalid input, wanted Activation or map[string]any, got: (%T)%v", input, input)
}
return p.Eval(vars)
}
// Cost implements the Coster interface method.
func (p *prog) Cost() (min, max int64) {
return estimateCost(p.interpretable)
}
// progFactory is a helper alias for marking a program creation factory function.
type progFactory func(interpreter.EvalState, *interpreter.CostTracker) (Program, error)
@ -354,7 +379,7 @@ func newProgGen(factory progFactory) (Program, error) {
}
// Eval implements the Program interface method.
func (gen *progGen) Eval(input interface{}) (ref.Val, *EvalDetails, error) {
func (gen *progGen) Eval(input any) (ref.Val, *EvalDetails, error) {
// The factory based Eval() differs from the standard evaluation model in that it generates a
// new EvalState instance for each call to ensure that unique evaluations yield unique stateful
// results.
@ -379,7 +404,7 @@ func (gen *progGen) Eval(input interface{}) (ref.Val, *EvalDetails, error) {
}
// ContextEval implements the Program interface method.
func (gen *progGen) ContextEval(ctx context.Context, input interface{}) (ref.Val, *EvalDetails, error) {
func (gen *progGen) ContextEval(ctx context.Context, input any) (ref.Val, *EvalDetails, error) {
if ctx == nil {
return nil, nil, fmt.Errorf("context can not be nil")
}
@ -406,29 +431,6 @@ func (gen *progGen) ContextEval(ctx context.Context, input interface{}) (ref.Val
return v, det, nil
}
// Cost implements the Coster interface method.
func (gen *progGen) Cost() (min, max int64) {
// Use an empty state value since no evaluation is performed.
p, err := gen.factory(emptyEvalState, nil)
if err != nil {
return 0, math.MaxInt64
}
return estimateCost(p)
}
// EstimateCost returns the heuristic cost interval for the program.
func EstimateCost(p Program) (min, max int64) {
return estimateCost(p)
}
func estimateCost(i interface{}) (min, max int64) {
c, ok := i.(interpreter.Coster)
if !ok {
return 0, math.MaxInt64
}
return c.Cost()
}
type ctxEvalActivation struct {
parent interpreter.Activation
interrupt <-chan struct{}
@ -438,7 +440,7 @@ type ctxEvalActivation struct {
// ResolveName implements the Activation interface method, but adds a special #interrupted variable
// which is capable of testing whether a 'done' signal is provided from a context.Context channel.
func (a *ctxEvalActivation) ResolveName(name string) (interface{}, bool) {
func (a *ctxEvalActivation) ResolveName(name string) (any, bool) {
if name == "#interrupted" {
a.interruptCheckCount++
if a.interruptCheckCount%a.interruptCheckFrequency == 0 {
@ -461,7 +463,7 @@ func (a *ctxEvalActivation) Parent() interpreter.Activation {
func newCtxEvalActivationPool() *ctxEvalActivationPool {
return &ctxEvalActivationPool{
Pool: sync.Pool{
New: func() interface{} {
New: func() any {
return &ctxEvalActivation{}
},
},
@ -483,21 +485,21 @@ func (p *ctxEvalActivationPool) Setup(vars interpreter.Activation, done <-chan s
}
type evalActivation struct {
vars map[string]interface{}
lazyVars map[string]interface{}
vars map[string]any
lazyVars map[string]any
}
// ResolveName looks up the value of the input variable name, if found.
//
// Lazy bindings may be supplied within the map-based input in either of the following forms:
// - func() interface{}
// - func() any
// - func() ref.Val
//
// The lazy binding will only be invoked once per evaluation.
//
// Values which are not represented as ref.Val types on input may be adapted to a ref.Val using
// the ref.TypeAdapter configured in the environment.
func (a *evalActivation) ResolveName(name string) (interface{}, bool) {
func (a *evalActivation) ResolveName(name string) (any, bool) {
v, found := a.vars[name]
if !found {
return nil, false
@ -510,7 +512,7 @@ func (a *evalActivation) ResolveName(name string) (interface{}, bool) {
lazy := obj()
a.lazyVars[name] = lazy
return lazy, true
case func() interface{}:
case func() any:
if resolved, found := a.lazyVars[name]; found {
return resolved, true
}
@ -530,8 +532,8 @@ func (a *evalActivation) Parent() interpreter.Activation {
func newEvalActivationPool() *evalActivationPool {
return &evalActivationPool{
Pool: sync.Pool{
New: func() interface{} {
return &evalActivation{lazyVars: make(map[string]interface{})}
New: func() any {
return &evalActivation{lazyVars: make(map[string]any)}
},
},
}
@ -542,13 +544,13 @@ type evalActivationPool struct {
}
// Setup initializes a pooled Activation object with the map input.
func (p *evalActivationPool) Setup(vars map[string]interface{}) *evalActivation {
func (p *evalActivationPool) Setup(vars map[string]any) *evalActivation {
a := p.Pool.Get().(*evalActivation)
a.vars = vars
return a
}
func (p *evalActivationPool) Put(value interface{}) {
func (p *evalActivationPool) Put(value any) {
a := value.(*evalActivation)
for k := range a.lazyVars {
delete(a.lazyVars, k)
@ -559,7 +561,7 @@ func (p *evalActivationPool) Put(value interface{}) {
var (
emptyEvalState = interpreter.NewEvalState()
// activationPool is an internally managed pool of Activation values that wrap map[string]interface{} inputs
// activationPool is an internally managed pool of Activation values that wrap map[string]any inputs
activationPool = newEvalActivationPool()
// ctxActivationPool is an internally managed pool of Activation values that expose a special #interrupted variable

View File

@ -30,7 +30,7 @@ go_library(
"//common/types/pb:go_default_library",
"//common/types/ref:go_default_library",
"//parser:go_default_library",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//types/known/emptypb:go_default_library",
"@org_golang_google_protobuf//types/known/structpb:go_default_library",
@ -54,7 +54,7 @@ go_test(
"//test:go_default_library",
"//test/proto2pb:go_default_library",
"//test/proto3pb:go_default_library",
"@com_github_antlr_antlr4_runtime_go_antlr//:go_default_library",
"@com_github_antlr_antlr4_runtime_go_antlr_v4//:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
],
)

View File

@ -23,6 +23,7 @@ import (
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/types/ref"
"google.golang.org/protobuf/proto"
@ -173,8 +174,8 @@ func (c *checker) checkSelect(e *exprpb.Expr) {
// Rewrite the node to be a variable reference to the resolved fully-qualified
// variable name.
c.setType(e, ident.GetIdent().Type)
c.setReference(e, newIdentReference(ident.GetName(), ident.GetIdent().Value))
c.setType(e, ident.GetIdent().GetType())
c.setReference(e, newIdentReference(ident.GetName(), ident.GetIdent().GetValue()))
identName := ident.GetName()
e.ExprKind = &exprpb.Expr_IdentExpr{
IdentExpr: &exprpb.Expr_Ident{
@ -185,9 +186,37 @@ func (c *checker) checkSelect(e *exprpb.Expr) {
}
}
resultType := c.checkSelectField(e, sel.GetOperand(), sel.GetField(), false)
if sel.TestOnly {
resultType = decls.Bool
}
c.setType(e, substitute(c.mappings, resultType, false))
}
func (c *checker) checkOptSelect(e *exprpb.Expr) {
// Collect metadata related to the opt select call packaged by the parser.
call := e.GetCallExpr()
operand := call.GetArgs()[0]
field := call.GetArgs()[1]
fieldName, isString := maybeUnwrapString(field)
if !isString {
c.errors.ReportError(c.location(field), "unsupported optional field selection: %v", field)
return
}
// Perform type-checking using the field selection logic.
resultType := c.checkSelectField(e, operand, fieldName, true)
c.setType(e, substitute(c.mappings, resultType, false))
}
func (c *checker) checkSelectField(e, operand *exprpb.Expr, field string, optional bool) *exprpb.Type {
// Interpret as field selection, first traversing down the operand.
c.check(sel.GetOperand())
targetType := substitute(c.mappings, c.getType(sel.GetOperand()), false)
c.check(operand)
operandType := substitute(c.mappings, c.getType(operand), false)
// If the target type is 'optional', unwrap it for the sake of this check.
targetType, isOpt := maybeUnwrapOptional(operandType)
// Assume error type by default as most types do not support field selection.
resultType := decls.Error
switch kindOf(targetType) {
@ -199,7 +228,7 @@ func (c *checker) checkSelect(e *exprpb.Expr) {
// Objects yield their field type declaration as the selection result type, but only if
// the field is defined.
messageType := targetType
if fieldType, found := c.lookupFieldType(c.location(e), messageType.GetMessageType(), sel.GetField()); found {
if fieldType, found := c.lookupFieldType(c.location(e), messageType.GetMessageType(), field); found {
resultType = fieldType.Type
}
case kindTypeParam:
@ -212,16 +241,17 @@ func (c *checker) checkSelect(e *exprpb.Expr) {
default:
// Dynamic / error values are treated as DYN type. Errors are handled this way as well
// in order to allow forward progress on the check.
if isDynOrError(targetType) {
resultType = decls.Dyn
} else {
if !isDynOrError(targetType) {
c.errors.typeDoesNotSupportFieldSelection(c.location(e), targetType)
}
resultType = decls.Dyn
}
if sel.TestOnly {
resultType = decls.Bool
// If the target type was optional coming in, then the result must be optional going out.
if isOpt || optional {
return decls.NewOptionalType(resultType)
}
c.setType(e, substitute(c.mappings, resultType, false))
return resultType
}
func (c *checker) checkCall(e *exprpb.Expr) {
@ -229,15 +259,19 @@ func (c *checker) checkCall(e *exprpb.Expr) {
// please consider the impact on planner.go and consolidate implementations or mirror code
// as appropriate.
call := e.GetCallExpr()
target := call.GetTarget()
args := call.GetArgs()
fnName := call.GetFunction()
if fnName == operators.OptSelect {
c.checkOptSelect(e)
return
}
args := call.GetArgs()
// Traverse arguments.
for _, arg := range args {
c.check(arg)
}
target := call.GetTarget()
// Regular static call with simple name.
if target == nil {
// Check for the existence of the function.
@ -359,6 +393,9 @@ func (c *checker) resolveOverload(
}
if resultType == nil {
for i, arg := range argTypes {
argTypes[i] = substitute(c.mappings, arg, true)
}
c.errors.noMatchingOverload(loc, fn.GetName(), argTypes, target != nil)
resultType = decls.Error
return nil
@ -369,16 +406,29 @@ func (c *checker) resolveOverload(
func (c *checker) checkCreateList(e *exprpb.Expr) {
create := e.GetListExpr()
var elemType *exprpb.Type
for _, e := range create.GetElements() {
var elemsType *exprpb.Type
optionalIndices := create.GetOptionalIndices()
optionals := make(map[int32]bool, len(optionalIndices))
for _, optInd := range optionalIndices {
optionals[optInd] = true
}
for i, e := range create.GetElements() {
c.check(e)
elemType = c.joinTypes(c.location(e), elemType, c.getType(e))
elemType := c.getType(e)
if optionals[int32(i)] {
var isOptional bool
elemType, isOptional = maybeUnwrapOptional(elemType)
if !isOptional && !isDyn(elemType) {
c.errors.typeMismatch(c.location(e), decls.NewOptionalType(elemType), elemType)
}
}
elemsType = c.joinTypes(c.location(e), elemsType, elemType)
}
if elemType == nil {
if elemsType == nil {
// If the list is empty, assign free type var to elem type.
elemType = c.newTypeVar()
elemsType = c.newTypeVar()
}
c.setType(e, decls.NewListType(elemType))
c.setType(e, decls.NewListType(elemsType))
}
func (c *checker) checkCreateStruct(e *exprpb.Expr) {
@ -392,22 +442,31 @@ func (c *checker) checkCreateStruct(e *exprpb.Expr) {
func (c *checker) checkCreateMap(e *exprpb.Expr) {
mapVal := e.GetStructExpr()
var keyType *exprpb.Type
var valueType *exprpb.Type
var mapKeyType *exprpb.Type
var mapValueType *exprpb.Type
for _, ent := range mapVal.GetEntries() {
key := ent.GetMapKey()
c.check(key)
keyType = c.joinTypes(c.location(key), keyType, c.getType(key))
mapKeyType = c.joinTypes(c.location(key), mapKeyType, c.getType(key))
c.check(ent.Value)
valueType = c.joinTypes(c.location(ent.Value), valueType, c.getType(ent.Value))
val := ent.GetValue()
c.check(val)
valType := c.getType(val)
if ent.GetOptionalEntry() {
var isOptional bool
valType, isOptional = maybeUnwrapOptional(valType)
if !isOptional && !isDyn(valType) {
c.errors.typeMismatch(c.location(val), decls.NewOptionalType(valType), valType)
}
}
mapValueType = c.joinTypes(c.location(val), mapValueType, valType)
}
if keyType == nil {
if mapKeyType == nil {
// If the map is empty, assign free type variables to typeKey and value type.
keyType = c.newTypeVar()
valueType = c.newTypeVar()
mapKeyType = c.newTypeVar()
mapValueType = c.newTypeVar()
}
c.setType(e, decls.NewMapType(keyType, valueType))
c.setType(e, decls.NewMapType(mapKeyType, mapValueType))
}
func (c *checker) checkCreateMessage(e *exprpb.Expr) {
@ -449,15 +508,21 @@ func (c *checker) checkCreateMessage(e *exprpb.Expr) {
c.check(value)
fieldType := decls.Error
if t, found := c.lookupFieldType(
c.locationByID(ent.GetId()),
messageType.GetMessageType(),
field); found {
fieldType = t.Type
ft, found := c.lookupFieldType(c.locationByID(ent.GetId()), messageType.GetMessageType(), field)
if found {
fieldType = ft.Type
}
if !c.isAssignable(fieldType, c.getType(value)) {
c.errors.fieldTypeMismatch(
c.locationByID(ent.Id), field, fieldType, c.getType(value))
valType := c.getType(value)
if ent.GetOptionalEntry() {
var isOptional bool
valType, isOptional = maybeUnwrapOptional(valType)
if !isOptional && !isDyn(valType) {
c.errors.typeMismatch(c.location(value), decls.NewOptionalType(valType), valType)
}
}
if !c.isAssignable(fieldType, valType) {
c.errors.fieldTypeMismatch(c.locationByID(ent.Id), field, fieldType, valType)
}
}
}

View File

@ -92,7 +92,10 @@ func (e astNode) ComputedSize() *SizeEstimate {
case *exprpb.Expr_ConstExpr:
switch ck := ek.ConstExpr.GetConstantKind().(type) {
case *exprpb.Constant_StringValue:
v = uint64(len(ck.StringValue))
// converting to runes here is an O(n) operation, but
// this is consistent with how size is computed at runtime,
// and how the language definition defines string size
v = uint64(len([]rune(ck.StringValue)))
case *exprpb.Constant_BytesValue:
v = uint64(len(ck.BytesValue))
case *exprpb.Constant_BoolValue, *exprpb.Constant_DoubleValue, *exprpb.Constant_DurationValue,
@ -258,6 +261,8 @@ type coster struct {
computedSizes map[int64]SizeEstimate
checkedExpr *exprpb.CheckedExpr
estimator CostEstimator
// presenceTestCost will either be a zero or one based on whether has() macros count against cost computations.
presenceTestCost CostEstimate
}
// Use a stack of iterVar -> iterRange Expr Ids to handle shadowed variable names.
@ -280,16 +285,39 @@ func (vs iterRangeScopes) peek(varName string) (int64, bool) {
return 0, false
}
// Cost estimates the cost of the parsed and type checked CEL expression.
func Cost(checker *exprpb.CheckedExpr, estimator CostEstimator) CostEstimate {
c := coster{
checkedExpr: checker,
estimator: estimator,
exprPath: map[int64][]string{},
iterRanges: map[string][]int64{},
computedSizes: map[int64]SizeEstimate{},
// CostOption configures flags which affect cost computations.
type CostOption func(*coster) error
// PresenceTestHasCost determines whether presence testing has a cost of one or zero.
// Defaults to presence test has a cost of one.
func PresenceTestHasCost(hasCost bool) CostOption {
return func(c *coster) error {
if hasCost {
c.presenceTestCost = selectAndIdentCost
return nil
}
c.presenceTestCost = CostEstimate{Min: 0, Max: 0}
return nil
}
return c.cost(checker.GetExpr())
}
// Cost estimates the cost of the parsed and type checked CEL expression.
func Cost(checker *exprpb.CheckedExpr, estimator CostEstimator, opts ...CostOption) (CostEstimate, error) {
c := &coster{
checkedExpr: checker,
estimator: estimator,
exprPath: map[int64][]string{},
iterRanges: map[string][]int64{},
computedSizes: map[int64]SizeEstimate{},
presenceTestCost: CostEstimate{Min: 1, Max: 1},
}
for _, opt := range opts {
err := opt(c)
if err != nil {
return CostEstimate{}, err
}
}
return c.cost(checker.GetExpr()), nil
}
func (c *coster) cost(e *exprpb.Expr) CostEstimate {
@ -340,6 +368,12 @@ func (c *coster) costSelect(e *exprpb.Expr) CostEstimate {
sel := e.GetSelectExpr()
var sum CostEstimate
if sel.GetTestOnly() {
// recurse, but do not add any cost
// this is equivalent to how evalTestOnly increments the runtime cost counter
// but does not add any additional cost for the qualifier, except here we do
// the reverse (ident adds cost)
sum = sum.Add(c.presenceTestCost)
sum = sum.Add(c.cost(sel.GetOperand()))
return sum
}
sum = sum.Add(c.cost(sel.GetOperand()))
@ -503,7 +537,10 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args
}
switch overloadID {
// O(n) functions
case overloads.StartsWithString, overloads.EndsWithString, overloads.StringToBytes, overloads.BytesToString:
case overloads.StartsWithString, overloads.EndsWithString, overloads.StringToBytes, overloads.BytesToString, overloads.ExtQuoteString, overloads.ExtFormatString:
if overloadID == overloads.ExtFormatString {
return CallEstimate{CostEstimate: c.sizeEstimate(*target).MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum())}
}
if len(args) == 1 {
return CallEstimate{CostEstimate: c.sizeEstimate(args[0]).MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum())}
}

View File

@ -13,7 +13,7 @@ go_library(
],
importpath = "github.com/google/cel-go/checker/decls",
deps = [
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//types/known/emptypb:go_default_library",
"@org_golang_google_protobuf//types/known/structpb:go_default_library",
],

View File

@ -16,9 +16,9 @@
package decls
import (
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
emptypb "google.golang.org/protobuf/types/known/emptypb"
structpb "google.golang.org/protobuf/types/known/structpb"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
var (
@ -64,6 +64,12 @@ func NewAbstractType(name string, paramTypes ...*exprpb.Type) *exprpb.Type {
ParameterTypes: paramTypes}}}
}
// NewOptionalType constructs an abstract type indicating that the parameterized type
// may be contained within the object.
func NewOptionalType(paramType *exprpb.Type) *exprpb.Type {
return NewAbstractType("optional", paramType)
}
// NewFunctionType creates a function invocation contract, typically only used
// by type-checking steps after overload resolution.
func NewFunctionType(resultType *exprpb.Type,

View File

@ -226,7 +226,7 @@ func (e *Env) setFunction(decl *exprpb.Decl) []errorMsg {
newOverloads := []*exprpb.Decl_FunctionDecl_Overload{}
for _, overload := range overloads {
existing, found := existingOverloads[overload.GetOverloadId()]
if !found || !proto.Equal(existing, overload) {
if !found || !overloadsEqual(existing, overload) {
newOverloads = append(newOverloads, overload)
}
}
@ -264,6 +264,31 @@ func (e *Env) isOverloadDisabled(overloadID string) bool {
return found
}
// overloadsEqual returns whether two overloads have identical signatures.
//
// type parameter names are ignored as they may be specified in any order and have no bearing on overload
// equivalence
func overloadsEqual(o1, o2 *exprpb.Decl_FunctionDecl_Overload) bool {
return o1.GetOverloadId() == o2.GetOverloadId() &&
o1.GetIsInstanceFunction() == o2.GetIsInstanceFunction() &&
paramsEqual(o1.GetParams(), o2.GetParams()) &&
proto.Equal(o1.GetResultType(), o2.GetResultType())
}
// paramsEqual returns whether two lists have equal length and all types are equal
func paramsEqual(p1, p2 []*exprpb.Type) bool {
if len(p1) != len(p2) {
return false
}
for i, a := range p1 {
b := p2[i]
if !proto.Equal(a, b) {
return false
}
}
return true
}
// sanitizeFunction replaces well-known types referenced by message name with their equivalent
// CEL built-in type instances.
func sanitizeFunction(decl *exprpb.Decl) *exprpb.Decl {

View File

@ -26,7 +26,7 @@ type semanticAdorner struct {
var _ debug.Adorner = &semanticAdorner{}
func (a *semanticAdorner) GetMetadata(elem interface{}) string {
func (a *semanticAdorner) GetMetadata(elem any) string {
result := ""
e, isExpr := elem.(*exprpb.Expr)
if !isExpr {

View File

@ -287,6 +287,8 @@ func init() {
decls.NewInstanceOverload(overloads.EndsWithString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool)),
decls.NewFunction(overloads.Matches,
decls.NewOverload(overloads.Matches,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool),
decls.NewInstanceOverload(overloads.MatchesString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool)),
decls.NewFunction(overloads.StartsWith,

View File

@ -90,6 +90,14 @@ func FormatCheckedType(t *exprpb.Type) string {
return "!error!"
case kindTypeParam:
return t.GetTypeParam()
case kindAbstract:
at := t.GetAbstractType()
params := at.GetParameterTypes()
paramStrs := make([]string, len(params))
for i, p := range params {
paramStrs[i] = FormatCheckedType(p)
}
return fmt.Sprintf("%s(%s)", at.GetName(), strings.Join(paramStrs, ", "))
}
return t.String()
}
@ -110,12 +118,39 @@ func isDyn(t *exprpb.Type) bool {
// isDynOrError returns true if the input is either an Error, DYN, or well-known ANY message.
func isDynOrError(t *exprpb.Type) bool {
switch kindOf(t) {
case kindError:
return true
default:
return isDyn(t)
return isError(t) || isDyn(t)
}
func isError(t *exprpb.Type) bool {
return kindOf(t) == kindError
}
func isOptional(t *exprpb.Type) bool {
if kindOf(t) == kindAbstract {
at := t.GetAbstractType()
return at.GetName() == "optional"
}
return false
}
func maybeUnwrapOptional(t *exprpb.Type) (*exprpb.Type, bool) {
if isOptional(t) {
at := t.GetAbstractType()
return at.GetParameterTypes()[0], true
}
return t, false
}
func maybeUnwrapString(e *exprpb.Expr) (string, bool) {
switch e.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
literal := e.GetConstExpr()
switch literal.GetConstantKind().(type) {
case *exprpb.Constant_StringValue:
return literal.GetStringValue(), true
}
}
return "", false
}
// isEqualOrLessSpecific checks whether one type is equal or less specific than the other one.
@ -236,7 +271,7 @@ func internalIsAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) bool {
// substitution for t1, and whether t2 has a type substitution in mapping m.
//
// The type t2 is a valid substitution for t1 if any of the following statements is true
// - t2 has a type substitition (t2sub) equal to t1
// - t2 has a type substitution (t2sub) equal to t1
// - t2 has a type substitution (t2sub) assignable to t1
// - t2 does not occur within t1.
func isValidTypeSubstitution(m *mapping, t1, t2 *exprpb.Type) (valid, hasSub bool) {

View File

@ -17,7 +17,7 @@ go_library(
importpath = "github.com/google/cel-go/common",
deps = [
"//common/runes:go_default_library",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_x_text//width:go_default_library",
],
)

View File

@ -12,7 +12,7 @@ go_library(
],
importpath = "github.com/google/cel-go/common/containers",
deps = [
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
],
)
@ -26,6 +26,6 @@ go_test(
":go_default_library",
],
deps = [
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
],
)

View File

@ -13,6 +13,6 @@ go_library(
importpath = "github.com/google/cel-go/common/debug",
deps = [
"//common:go_default_library",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
],
)

View File

@ -29,7 +29,7 @@ import (
// representation of an expression.
type Adorner interface {
// GetMetadata for the input context.
GetMetadata(ctx interface{}) string
GetMetadata(ctx any) string
}
// Writer manages writing expressions to an internal string.
@ -46,7 +46,7 @@ type emptyDebugAdorner struct {
var emptyAdorner Adorner = &emptyDebugAdorner{}
func (a *emptyDebugAdorner) GetMetadata(e interface{}) string {
func (a *emptyDebugAdorner) GetMetadata(e any) string {
return ""
}
@ -170,6 +170,9 @@ func (w *debugWriter) appendObject(obj *exprpb.Expr_CreateStruct) {
w.append(",")
w.appendLine()
}
if entry.GetOptionalEntry() {
w.append("?")
}
w.append(entry.GetFieldKey())
w.append(":")
w.Buffer(entry.GetValue())
@ -191,6 +194,9 @@ func (w *debugWriter) appendMap(obj *exprpb.Expr_CreateStruct) {
w.append(",")
w.appendLine()
}
if entry.GetOptionalEntry() {
w.append("?")
}
w.Buffer(entry.GetMapKey())
w.append(":")
w.Buffer(entry.GetValue())
@ -269,7 +275,7 @@ func (w *debugWriter) append(s string) {
w.buffer.WriteString(s)
}
func (w *debugWriter) appendFormat(f string, args ...interface{}) {
func (w *debugWriter) appendFormat(f string, args ...any) {
w.append(fmt.Sprintf(f, args...))
}
@ -280,7 +286,7 @@ func (w *debugWriter) doIndent() {
}
}
func (w *debugWriter) adorn(e interface{}) {
func (w *debugWriter) adorn(e any) {
w.append(w.adorner.GetMetadata(e))
}

View File

@ -38,7 +38,7 @@ func NewErrors(source Source) *Errors {
}
// ReportError records an error at a source location.
func (e *Errors) ReportError(l Location, format string, args ...interface{}) {
func (e *Errors) ReportError(l Location, format string, args ...any) {
e.numErrors++
if e.numErrors > e.maxErrorsToReport {
return

View File

@ -37,6 +37,8 @@ const (
Modulo = "_%_"
Negate = "-_"
Index = "_[_]"
OptIndex = "_[?_]"
OptSelect = "_?._"
// Macros, must have a valid identifier.
Has = "has"
@ -99,6 +101,8 @@ var (
LogicalNot: {displayName: "!", precedence: 2, arity: 1},
Negate: {displayName: "-", precedence: 2, arity: 1},
Index: {displayName: "", precedence: 1, arity: 2},
OptIndex: {displayName: "", precedence: 1, arity: 2},
OptSelect: {displayName: "", precedence: 1, arity: 2},
}
)

View File

@ -148,6 +148,11 @@ const (
StartsWith = "startsWith"
)
// Extension function overloads with complex behaviors that need to be referenced in runtime and static analysis cost computations.
const (
ExtQuoteString = "strings_quote"
)
// String function overload names.
const (
ContainsString = "contains_string"
@ -156,6 +161,11 @@ const (
StartsWithString = "starts_with_string"
)
// Extension function overloads with complex behaviors that need to be referenced in runtime and static analysis cost computations.
const (
ExtFormatString = "string_format"
)
// Time-based functions.
const (
TimeGetFullYear = "getFullYear"

View File

@ -22,6 +22,7 @@ go_library(
"map.go",
"null.go",
"object.go",
"optional.go",
"overflow.go",
"provider.go",
"string.go",
@ -38,10 +39,8 @@ go_library(
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"@com_github_stoewer_go_strcase//:go_default_library",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto//googleapis/rpc/status:go_default_library",
"@org_golang_google_grpc//codes:go_default_library",
"@org_golang_google_grpc//status:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_rpc//status:go_default_library",
"@org_golang_google_protobuf//encoding/protojson:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//reflect/protoreflect:go_default_library",
@ -68,6 +67,7 @@ go_test(
"map_test.go",
"null_test.go",
"object_test.go",
"optional_test.go",
"provider_test.go",
"string_test.go",
"timestamp_test.go",
@ -80,7 +80,7 @@ go_test(
"//common/types/ref:go_default_library",
"//test:go_default_library",
"//test/proto3pb:test_all_types_go_proto",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//encoding/protojson:go_default_library",
"@org_golang_google_protobuf//types/known/anypb:go_default_library",
"@org_golang_google_protobuf//types/known/durationpb:go_default_library",

View File

@ -62,7 +62,7 @@ func (b Bool) Compare(other ref.Val) ref.Val {
}
// ConvertToNative implements the ref.Val interface method.
func (b Bool) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (b Bool) ConvertToNative(typeDesc reflect.Type) (any, error) {
switch typeDesc.Kind() {
case reflect.Bool:
return reflect.ValueOf(b).Convert(typeDesc).Interface(), nil
@ -114,6 +114,11 @@ func (b Bool) Equal(other ref.Val) ref.Val {
return Bool(ok && b == otherBool)
}
// IsZeroValue returns true if the boolean value is false.
func (b Bool) IsZeroValue() bool {
return b == False
}
// Negate implements the traits.Negater interface method.
func (b Bool) Negate() ref.Val {
return !b
@ -125,7 +130,7 @@ func (b Bool) Type() ref.Type {
}
// Value implements the ref.Val interface method.
func (b Bool) Value() interface{} {
func (b Bool) Value() any {
return bool(b)
}

View File

@ -63,7 +63,7 @@ func (b Bytes) Compare(other ref.Val) ref.Val {
}
// ConvertToNative implements the ref.Val interface method.
func (b Bytes) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (b Bytes) ConvertToNative(typeDesc reflect.Type) (any, error) {
switch typeDesc.Kind() {
case reflect.Array, reflect.Slice:
return reflect.ValueOf(b).Convert(typeDesc).Interface(), nil
@ -116,6 +116,11 @@ func (b Bytes) Equal(other ref.Val) ref.Val {
return Bool(ok && bytes.Equal(b, otherBytes))
}
// IsZeroValue returns true if the byte array is empty.
func (b Bytes) IsZeroValue() bool {
return len(b) == 0
}
// Size implements the traits.Sizer interface method.
func (b Bytes) Size() ref.Val {
return Int(len(b))
@ -127,6 +132,6 @@ func (b Bytes) Type() ref.Type {
}
// Value implements the ref.Val interface method.
func (b Bytes) Value() interface{} {
func (b Bytes) Value() any {
return []byte(b)
}

View File

@ -78,7 +78,7 @@ func (d Double) Compare(other ref.Val) ref.Val {
}
// ConvertToNative implements ref.Val.ConvertToNative.
func (d Double) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (d Double) ConvertToNative(typeDesc reflect.Type) (any, error) {
switch typeDesc.Kind() {
case reflect.Float32:
v := float32(d)
@ -134,13 +134,13 @@ func (d Double) ConvertToType(typeVal ref.Type) ref.Val {
case IntType:
i, err := doubleToInt64Checked(float64(d))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Int(i)
case UintType:
i, err := doubleToUint64Checked(float64(d))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Uint(i)
case DoubleType:
@ -182,6 +182,11 @@ func (d Double) Equal(other ref.Val) ref.Val {
}
}
// IsZeroValue returns true if double value is 0.0
func (d Double) IsZeroValue() bool {
return float64(d) == 0.0
}
// Multiply implements traits.Multiplier.Multiply.
func (d Double) Multiply(other ref.Val) ref.Val {
otherDouble, ok := other.(Double)
@ -211,6 +216,6 @@ func (d Double) Type() ref.Type {
}
// Value implements ref.Val.Value.
func (d Double) Value() interface{} {
func (d Double) Value() any {
return float64(d)
}

View File

@ -57,14 +57,14 @@ func (d Duration) Add(other ref.Val) ref.Val {
dur2 := other.(Duration)
val, err := addDurationChecked(d.Duration, dur2.Duration)
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return durationOf(val)
case TimestampType:
ts := other.(Timestamp).Time
val, err := addTimeDurationChecked(ts, d.Duration)
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return timestampOf(val)
}
@ -90,7 +90,7 @@ func (d Duration) Compare(other ref.Val) ref.Val {
}
// ConvertToNative implements ref.Val.ConvertToNative.
func (d Duration) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (d Duration) ConvertToNative(typeDesc reflect.Type) (any, error) {
// If the duration is already assignable to the desired type return it.
if reflect.TypeOf(d.Duration).AssignableTo(typeDesc) {
return d.Duration, nil
@ -138,11 +138,16 @@ func (d Duration) Equal(other ref.Val) ref.Val {
return Bool(ok && d.Duration == otherDur.Duration)
}
// IsZeroValue returns true if the duration value is zero
func (d Duration) IsZeroValue() bool {
return d.Duration == 0
}
// Negate implements traits.Negater.Negate.
func (d Duration) Negate() ref.Val {
val, err := negateDurationChecked(d.Duration)
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return durationOf(val)
}
@ -165,7 +170,7 @@ func (d Duration) Subtract(subtrahend ref.Val) ref.Val {
}
val, err := subtractDurationChecked(d.Duration, subtraDur.Duration)
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return durationOf(val)
}
@ -176,7 +181,7 @@ func (d Duration) Type() ref.Type {
}
// Value implements ref.Val.Value.
func (d Duration) Value() interface{} {
func (d Duration) Value() any {
return d.Duration
}

View File

@ -22,6 +22,12 @@ import (
"github.com/google/cel-go/common/types/ref"
)
// Error interface which allows types types.Err values to be treated as error values.
type Error interface {
error
ref.Val
}
// Err type which extends the built-in go error and implements ref.Val.
type Err struct {
error
@ -51,7 +57,7 @@ var (
// NewErr creates a new Err described by the format string and args.
// TODO: Audit the use of this function and standardize the error messages and codes.
func NewErr(format string, args ...interface{}) ref.Val {
func NewErr(format string, args ...any) ref.Val {
return &Err{fmt.Errorf(format, args...)}
}
@ -62,7 +68,7 @@ func NoSuchOverloadErr() ref.Val {
// UnsupportedRefValConversionErr returns a types.NewErr instance with a no such conversion
// message that indicates that the native value could not be converted to a CEL ref.Val.
func UnsupportedRefValConversionErr(val interface{}) ref.Val {
func UnsupportedRefValConversionErr(val any) ref.Val {
return NewErr("unsupported conversion to ref.Val: (%T)%v", val, val)
}
@ -74,20 +80,20 @@ func MaybeNoSuchOverloadErr(val ref.Val) ref.Val {
// ValOrErr either returns the existing error or creates a new one.
// TODO: Audit the use of this function and standardize the error messages and codes.
func ValOrErr(val ref.Val, format string, args ...interface{}) ref.Val {
func ValOrErr(val ref.Val, format string, args ...any) ref.Val {
if val == nil || !IsUnknownOrError(val) {
return NewErr(format, args...)
}
return val
}
// wrapErr wraps an existing Go error value into a CEL Err value.
func wrapErr(err error) ref.Val {
// WrapErr wraps an existing Go error value into a CEL Err value.
func WrapErr(err error) ref.Val {
return &Err{error: err}
}
// ConvertToNative implements ref.Val.ConvertToNative.
func (e *Err) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (e *Err) ConvertToNative(typeDesc reflect.Type) (any, error) {
return nil, e.error
}
@ -114,10 +120,15 @@ func (e *Err) Type() ref.Type {
}
// Value implements ref.Val.Value.
func (e *Err) Value() interface{} {
func (e *Err) Value() any {
return e.error
}
// Is implements errors.Is.
func (e *Err) Is(target error) bool {
return e.error.Error() == target.Error()
}
// IsError returns whether the input element ref.Type or ref.Val is equal to
// the ErrType singleton.
func IsError(val ref.Val) bool {

View File

@ -66,7 +66,7 @@ func (i Int) Add(other ref.Val) ref.Val {
}
val, err := addInt64Checked(int64(i), int64(otherInt))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Int(val)
}
@ -89,7 +89,7 @@ func (i Int) Compare(other ref.Val) ref.Val {
}
// ConvertToNative implements ref.Val.ConvertToNative.
func (i Int) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (i Int) ConvertToNative(typeDesc reflect.Type) (any, error) {
switch typeDesc.Kind() {
case reflect.Int, reflect.Int32:
// Enums are also mapped as int32 derivations.
@ -176,7 +176,7 @@ func (i Int) ConvertToType(typeVal ref.Type) ref.Val {
case UintType:
u, err := int64ToUint64Checked(int64(i))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Uint(u)
case DoubleType:
@ -204,7 +204,7 @@ func (i Int) Divide(other ref.Val) ref.Val {
}
val, err := divideInt64Checked(int64(i), int64(otherInt))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Int(val)
}
@ -226,6 +226,11 @@ func (i Int) Equal(other ref.Val) ref.Val {
}
}
// IsZeroValue returns true if integer is equal to 0
func (i Int) IsZeroValue() bool {
return i == IntZero
}
// Modulo implements traits.Modder.Modulo.
func (i Int) Modulo(other ref.Val) ref.Val {
otherInt, ok := other.(Int)
@ -234,7 +239,7 @@ func (i Int) Modulo(other ref.Val) ref.Val {
}
val, err := moduloInt64Checked(int64(i), int64(otherInt))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Int(val)
}
@ -247,7 +252,7 @@ func (i Int) Multiply(other ref.Val) ref.Val {
}
val, err := multiplyInt64Checked(int64(i), int64(otherInt))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Int(val)
}
@ -256,7 +261,7 @@ func (i Int) Multiply(other ref.Val) ref.Val {
func (i Int) Negate() ref.Val {
val, err := negateInt64Checked(int64(i))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Int(val)
}
@ -269,7 +274,7 @@ func (i Int) Subtract(subtrahend ref.Val) ref.Val {
}
val, err := subtractInt64Checked(int64(i), int64(subtraInt))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Int(val)
}
@ -280,7 +285,7 @@ func (i Int) Type() ref.Type {
}
// Value implements ref.Val.Value.
func (i Int) Value() interface{} {
func (i Int) Value() any {
return int64(i)
}

View File

@ -34,7 +34,7 @@ var (
// interpreter.
type baseIterator struct{}
func (*baseIterator) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (*baseIterator) ConvertToNative(typeDesc reflect.Type) (any, error) {
return nil, fmt.Errorf("type conversion on iterators not supported")
}
@ -50,6 +50,6 @@ func (*baseIterator) Type() ref.Type {
return IteratorType
}
func (*baseIterator) Value() interface{} {
func (*baseIterator) Value() any {
return nil
}

View File

@ -25,4 +25,5 @@ var (
jsonValueType = reflect.TypeOf(&structpb.Value{})
jsonListValueType = reflect.TypeOf(&structpb.ListValue{})
jsonStructType = reflect.TypeOf(&structpb.Struct{})
jsonNullType = reflect.TypeOf(structpb.NullValue_NULL_VALUE)
)

View File

@ -17,11 +17,13 @@ package types
import (
"fmt"
"reflect"
"strings"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb"
@ -40,13 +42,13 @@ var (
// NewDynamicList returns a traits.Lister with heterogenous elements.
// value should be an array of "native" types, i.e. any type that
// NativeToValue() can convert to a ref.Val.
func NewDynamicList(adapter ref.TypeAdapter, value interface{}) traits.Lister {
func NewDynamicList(adapter ref.TypeAdapter, value any) traits.Lister {
refValue := reflect.ValueOf(value)
return &baseList{
TypeAdapter: adapter,
value: value,
size: refValue.Len(),
get: func(i int) interface{} {
get: func(i int) any {
return refValue.Index(i).Interface()
},
}
@ -58,7 +60,7 @@ func NewStringList(adapter ref.TypeAdapter, elems []string) traits.Lister {
TypeAdapter: adapter,
value: elems,
size: len(elems),
get: func(i int) interface{} { return elems[i] },
get: func(i int) any { return elems[i] },
}
}
@ -70,7 +72,7 @@ func NewRefValList(adapter ref.TypeAdapter, elems []ref.Val) traits.Lister {
TypeAdapter: adapter,
value: elems,
size: len(elems),
get: func(i int) interface{} { return elems[i] },
get: func(i int) any { return elems[i] },
}
}
@ -80,7 +82,7 @@ func NewProtoList(adapter ref.TypeAdapter, list protoreflect.List) traits.Lister
TypeAdapter: adapter,
value: list,
size: list.Len(),
get: func(i int) interface{} { return list.Get(i).Interface() },
get: func(i int) any { return list.Get(i).Interface() },
}
}
@ -91,22 +93,25 @@ func NewJSONList(adapter ref.TypeAdapter, l *structpb.ListValue) traits.Lister {
TypeAdapter: adapter,
value: l,
size: len(vals),
get: func(i int) interface{} { return vals[i] },
get: func(i int) any { return vals[i] },
}
}
// NewMutableList creates a new mutable list whose internal state can be modified.
func NewMutableList(adapter ref.TypeAdapter) traits.MutableLister {
var mutableValues []ref.Val
return &mutableList{
l := &mutableList{
baseList: &baseList{
TypeAdapter: adapter,
value: mutableValues,
size: 0,
get: func(i int) interface{} { return mutableValues[i] },
},
mutableValues: mutableValues,
}
l.get = func(i int) any {
return l.mutableValues[i]
}
return l
}
// baseList points to a list containing elements of any type.
@ -114,7 +119,7 @@ func NewMutableList(adapter ref.TypeAdapter) traits.MutableLister {
// The `ref.TypeAdapter` enables native type to CEL type conversions.
type baseList struct {
ref.TypeAdapter
value interface{}
value any
// size indicates the number of elements within the list.
// Since objects are immutable the size of a list is static.
@ -122,7 +127,7 @@ type baseList struct {
// get returns a value at the specified integer index.
// The index is guaranteed to be checked against the list index range.
get func(int) interface{}
get func(int) any
}
// Add implements the traits.Adder interface method.
@ -157,7 +162,7 @@ func (l *baseList) Contains(elem ref.Val) ref.Val {
}
// ConvertToNative implements the ref.Val interface method.
func (l *baseList) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (l *baseList) ConvertToNative(typeDesc reflect.Type) (any, error) {
// If the underlying list value is assignable to the reflected type return it.
if reflect.TypeOf(l.value).AssignableTo(typeDesc) {
return l.value, nil
@ -240,7 +245,7 @@ func (l *baseList) Equal(other ref.Val) ref.Val {
// Get implements the traits.Indexer interface method.
func (l *baseList) Get(index ref.Val) ref.Val {
ind, err := indexOrError(index)
ind, err := IndexOrError(index)
if err != nil {
return ValOrErr(index, err.Error())
}
@ -250,6 +255,11 @@ func (l *baseList) Get(index ref.Val) ref.Val {
return l.NativeToValue(l.get(ind))
}
// IsZeroValue returns true if the list is empty.
func (l *baseList) IsZeroValue() bool {
return l.size == 0
}
// Iterator implements the traits.Iterable interface method.
func (l *baseList) Iterator() traits.Iterator {
return newListIterator(l)
@ -266,10 +276,24 @@ func (l *baseList) Type() ref.Type {
}
// Value implements the ref.Val interface method.
func (l *baseList) Value() interface{} {
func (l *baseList) Value() any {
return l.value
}
// String converts the list to a human readable string form.
func (l *baseList) String() string {
var sb strings.Builder
sb.WriteString("[")
for i := 0; i < l.size; i++ {
sb.WriteString(fmt.Sprintf("%v", l.get(i)))
if i != l.size-1 {
sb.WriteString(", ")
}
}
sb.WriteString("]")
return sb.String()
}
// mutableList aggregates values into its internal storage. For use with internal CEL variables only.
type mutableList struct {
*baseList
@ -305,7 +329,7 @@ func (l *mutableList) ToImmutableList() traits.Lister {
// The `ref.TypeAdapter` enables native type to CEL type conversions.
type concatList struct {
ref.TypeAdapter
value interface{}
value any
prevList traits.Lister
nextList traits.Lister
}
@ -351,8 +375,8 @@ func (l *concatList) Contains(elem ref.Val) ref.Val {
}
// ConvertToNative implements the ref.Val interface method.
func (l *concatList) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
combined := NewDynamicList(l.TypeAdapter, l.Value().([]interface{}))
func (l *concatList) ConvertToNative(typeDesc reflect.Type) (any, error) {
combined := NewDynamicList(l.TypeAdapter, l.Value().([]any))
return combined.ConvertToNative(typeDesc)
}
@ -396,7 +420,7 @@ func (l *concatList) Equal(other ref.Val) ref.Val {
// Get implements the traits.Indexer interface method.
func (l *concatList) Get(index ref.Val) ref.Val {
ind, err := indexOrError(index)
ind, err := IndexOrError(index)
if err != nil {
return ValOrErr(index, err.Error())
}
@ -408,6 +432,11 @@ func (l *concatList) Get(index ref.Val) ref.Val {
return l.nextList.Get(offset)
}
// IsZeroValue returns true if the list is empty.
func (l *concatList) IsZeroValue() bool {
return l.Size().(Int) == 0
}
// Iterator implements the traits.Iterable interface method.
func (l *concatList) Iterator() traits.Iterator {
return newListIterator(l)
@ -418,15 +447,29 @@ func (l *concatList) Size() ref.Val {
return l.prevList.Size().(Int).Add(l.nextList.Size())
}
// String converts the concatenated list to a human-readable string.
func (l *concatList) String() string {
var sb strings.Builder
sb.WriteString("[")
for i := Int(0); i < l.Size().(Int); i++ {
sb.WriteString(fmt.Sprintf("%v", l.Get(i)))
if i != l.Size().(Int)-1 {
sb.WriteString(", ")
}
}
sb.WriteString("]")
return sb.String()
}
// Type implements the ref.Val interface method.
func (l *concatList) Type() ref.Type {
return ListType
}
// Value implements the ref.Val interface method.
func (l *concatList) Value() interface{} {
func (l *concatList) Value() any {
if l.value == nil {
merged := make([]interface{}, l.Size().(Int))
merged := make([]any, l.Size().(Int))
prevLen := l.prevList.Size().(Int)
for i := Int(0); i < prevLen; i++ {
merged[i] = l.prevList.Get(i).Value()
@ -469,7 +512,8 @@ func (it *listIterator) Next() ref.Val {
return nil
}
func indexOrError(index ref.Val) (int, error) {
// IndexOrError converts an input index value into either a lossless integer index or an error.
func IndexOrError(index ref.Val) (int, error) {
switch iv := index.(type) {
case Int:
return int(iv), nil

View File

@ -17,20 +17,22 @@ package types
import (
"fmt"
"reflect"
"strings"
"github.com/stoewer/go-strcase"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"github.com/google/cel-go/common/types/pb"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/stoewer/go-strcase"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb"
)
// NewDynamicMap returns a traits.Mapper value with dynamic key, value pairs.
func NewDynamicMap(adapter ref.TypeAdapter, value interface{}) traits.Mapper {
func NewDynamicMap(adapter ref.TypeAdapter, value any) traits.Mapper {
refValue := reflect.ValueOf(value)
return &baseMap{
TypeAdapter: adapter,
@ -65,7 +67,7 @@ func NewRefValMap(adapter ref.TypeAdapter, value map[ref.Val]ref.Val) traits.Map
}
// NewStringInterfaceMap returns a specialized traits.Mapper with string keys and interface values.
func NewStringInterfaceMap(adapter ref.TypeAdapter, value map[string]interface{}) traits.Mapper {
func NewStringInterfaceMap(adapter ref.TypeAdapter, value map[string]any) traits.Mapper {
return &baseMap{
TypeAdapter: adapter,
mapAccessor: newStringIfaceMapAccessor(adapter, value),
@ -125,7 +127,7 @@ type baseMap struct {
mapAccessor
// value is the native Go value upon which the map type operators.
value interface{}
value any
// size is the number of entries in the map.
size int
@ -138,7 +140,7 @@ func (m *baseMap) Contains(index ref.Val) ref.Val {
}
// ConvertToNative implements the ref.Val interface method.
func (m *baseMap) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (m *baseMap) ConvertToNative(typeDesc reflect.Type) (any, error) {
// If the map is already assignable to the desired type return it, e.g. interfaces and
// maps with the same key value types.
if reflect.TypeOf(m.value).AssignableTo(typeDesc) {
@ -275,18 +277,42 @@ func (m *baseMap) Get(key ref.Val) ref.Val {
return v
}
// IsZeroValue returns true if the map is empty.
func (m *baseMap) IsZeroValue() bool {
return m.size == 0
}
// Size implements the traits.Sizer interface method.
func (m *baseMap) Size() ref.Val {
return Int(m.size)
}
// String converts the map into a human-readable string.
func (m *baseMap) String() string {
var sb strings.Builder
sb.WriteString("{")
it := m.Iterator()
i := 0
for it.HasNext() == True {
k := it.Next()
v, _ := m.Find(k)
sb.WriteString(fmt.Sprintf("%v: %v", k, v))
if i != m.size-1 {
sb.WriteString(", ")
}
i++
}
sb.WriteString("}")
return sb.String()
}
// Type implements the ref.Val interface method.
func (m *baseMap) Type() ref.Type {
return MapType
}
// Value implements the ref.Val interface method.
func (m *baseMap) Value() interface{} {
func (m *baseMap) Value() any {
return m.value
}
@ -498,7 +524,7 @@ func (a *stringMapAccessor) Iterator() traits.Iterator {
}
}
func newStringIfaceMapAccessor(adapter ref.TypeAdapter, mapVal map[string]interface{}) mapAccessor {
func newStringIfaceMapAccessor(adapter ref.TypeAdapter, mapVal map[string]any) mapAccessor {
return &stringIfaceMapAccessor{
TypeAdapter: adapter,
mapVal: mapVal,
@ -507,7 +533,7 @@ func newStringIfaceMapAccessor(adapter ref.TypeAdapter, mapVal map[string]interf
type stringIfaceMapAccessor struct {
ref.TypeAdapter
mapVal map[string]interface{}
mapVal map[string]any
}
// Find uses native map accesses to find the key, returning (value, true) if present.
@ -556,7 +582,7 @@ func (m *protoMap) Contains(key ref.Val) ref.Val {
// ConvertToNative implements the ref.Val interface method.
//
// Note, assignment to Golang struct types is not yet supported.
func (m *protoMap) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (m *protoMap) ConvertToNative(typeDesc reflect.Type) (any, error) {
// If the map is already assignable to the desired type return it, e.g. interfaces and
// maps with the same key value types.
switch typeDesc {
@ -601,9 +627,9 @@ func (m *protoMap) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
m.value.Range(func(key protoreflect.MapKey, val protoreflect.Value) bool {
ntvKey := key.Interface()
ntvVal := val.Interface()
switch ntvVal.(type) {
switch pv := ntvVal.(type) {
case protoreflect.Message:
ntvVal = ntvVal.(protoreflect.Message).Interface()
ntvVal = pv.Interface()
}
if keyType == otherKeyType && valType == otherValType {
mapVal.SetMapIndex(reflect.ValueOf(ntvKey), reflect.ValueOf(ntvVal))
@ -732,6 +758,11 @@ func (m *protoMap) Get(key ref.Val) ref.Val {
return v
}
// IsZeroValue returns true if the map is empty.
func (m *protoMap) IsZeroValue() bool {
return m.value.Len() == 0
}
// Iterator implements the traits.Iterable interface method.
func (m *protoMap) Iterator() traits.Iterator {
// Copy the keys to make their order stable.
@ -758,7 +789,7 @@ func (m *protoMap) Type() ref.Type {
}
// Value implements the ref.Val interface method.
func (m *protoMap) Value() interface{} {
func (m *protoMap) Value() any {
return m.value
}

View File

@ -18,9 +18,10 @@ import (
"fmt"
"reflect"
"github.com/google/cel-go/common/types/ref"
"google.golang.org/protobuf/proto"
"github.com/google/cel-go/common/types/ref"
anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb"
)
@ -34,14 +35,20 @@ var (
// NullValue singleton.
NullValue = Null(structpb.NullValue_NULL_VALUE)
jsonNullType = reflect.TypeOf(structpb.NullValue_NULL_VALUE)
// golang reflect type for Null values.
nullReflectType = reflect.TypeOf(NullValue)
)
// ConvertToNative implements ref.Val.ConvertToNative.
func (n Null) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (n Null) ConvertToNative(typeDesc reflect.Type) (any, error) {
switch typeDesc.Kind() {
case reflect.Int32:
return reflect.ValueOf(n).Convert(typeDesc).Interface(), nil
switch typeDesc {
case jsonNullType:
return structpb.NullValue_NULL_VALUE, nil
case nullReflectType:
return n, nil
}
case reflect.Ptr:
switch typeDesc {
case anyValueType:
@ -54,6 +61,10 @@ func (n Null) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
return anypb.New(pb.(proto.Message))
case jsonValueType:
return structpb.NewNullValue(), nil
case boolWrapperType, byteWrapperType, doubleWrapperType, floatWrapperType,
int32WrapperType, int64WrapperType, stringWrapperType, uint32WrapperType,
uint64WrapperType:
return nil, nil
}
case reflect.Interface:
nv := n.Value()
@ -86,12 +97,17 @@ func (n Null) Equal(other ref.Val) ref.Val {
return Bool(NullType == other.Type())
}
// IsZeroValue returns true as null always represents an absent value.
func (n Null) IsZeroValue() bool {
return true
}
// Type implements ref.Val.Type.
func (n Null) Type() ref.Type {
return NullType
}
// Value implements ref.Val.Value.
func (n Null) Value() interface{} {
func (n Null) Value() any {
return structpb.NullValue_NULL_VALUE
}

View File

@ -18,11 +18,12 @@ import (
"fmt"
"reflect"
"github.com/google/cel-go/common/types/pb"
"github.com/google/cel-go/common/types/ref"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"github.com/google/cel-go/common/types/pb"
"github.com/google/cel-go/common/types/ref"
anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb"
)
@ -52,7 +53,7 @@ func NewObject(adapter ref.TypeAdapter,
typeValue: typeValue}
}
func (o *protoObj) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (o *protoObj) ConvertToNative(typeDesc reflect.Type) (any, error) {
srcPB := o.value
if reflect.TypeOf(srcPB).AssignableTo(typeDesc) {
return srcPB, nil
@ -133,6 +134,11 @@ func (o *protoObj) IsSet(field ref.Val) ref.Val {
return False
}
// IsZeroValue returns true if the protobuf object is empty.
func (o *protoObj) IsZeroValue() bool {
return proto.Equal(o.value, o.typeDesc.Zero())
}
func (o *protoObj) Get(index ref.Val) ref.Val {
protoFieldName, ok := index.(String)
if !ok {
@ -154,6 +160,6 @@ func (o *protoObj) Type() ref.Type {
return o.typeValue
}
func (o *protoObj) Value() interface{} {
func (o *protoObj) Value() any {
return o.value
}

View File

@ -0,0 +1,108 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package types
import (
"errors"
"fmt"
"reflect"
"github.com/google/cel-go/common/types/ref"
)
var (
// OptionalType indicates the runtime type of an optional value.
OptionalType = NewTypeValue("optional")
// OptionalNone is a sentinel value which is used to indicate an empty optional value.
OptionalNone = &Optional{}
)
// OptionalOf returns an optional value which wraps a concrete CEL value.
func OptionalOf(value ref.Val) *Optional {
return &Optional{value: value}
}
// Optional value which points to a value if non-empty.
type Optional struct {
value ref.Val
}
// HasValue returns true if the optional has a value.
func (o *Optional) HasValue() bool {
return o.value != nil
}
// GetValue returns the wrapped value contained in the optional.
func (o *Optional) GetValue() ref.Val {
if !o.HasValue() {
return NewErr("optional.none() dereference")
}
return o.value
}
// ConvertToNative implements the ref.Val interface method.
func (o *Optional) ConvertToNative(typeDesc reflect.Type) (any, error) {
if !o.HasValue() {
return nil, errors.New("optional.none() dereference")
}
return o.value.ConvertToNative(typeDesc)
}
// ConvertToType implements the ref.Val interface method.
func (o *Optional) ConvertToType(typeVal ref.Type) ref.Val {
switch typeVal {
case OptionalType:
return o
case TypeType:
return OptionalType
}
return NewErr("type conversion error from '%s' to '%s'", OptionalType, typeVal)
}
// Equal determines whether the values contained by two optional values are equal.
func (o *Optional) Equal(other ref.Val) ref.Val {
otherOpt, isOpt := other.(*Optional)
if !isOpt {
return False
}
if !o.HasValue() {
return Bool(!otherOpt.HasValue())
}
if !otherOpt.HasValue() {
return False
}
return o.value.Equal(otherOpt.value)
}
func (o *Optional) String() string {
if o.HasValue() {
return fmt.Sprintf("optional(%v)", o.GetValue())
}
return "optional.none()"
}
// Type implements the ref.Val interface method.
func (o *Optional) Type() ref.Type {
return OptionalType
}
// Value returns the underlying 'Value()' of the wrapped value, if present.
func (o *Optional) Value() any {
if o.value == nil {
return nil
}
return o.value.Value()
}

View File

@ -17,7 +17,7 @@ go_library(
],
importpath = "github.com/google/cel-go/common/types/pb",
deps = [
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//encoding/protowire:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//reflect/protoreflect:go_default_library",

View File

@ -18,9 +18,9 @@ import (
"google.golang.org/protobuf/reflect/protoreflect"
)
// NewEnumValueDescription produces an enum value description with the fully qualified enum value
// newEnumValueDescription produces an enum value description with the fully qualified enum value
// name and the enum value descriptor.
func NewEnumValueDescription(name string, desc protoreflect.EnumValueDescriptor) *EnumValueDescription {
func newEnumValueDescription(name string, desc protoreflect.EnumValueDescriptor) *EnumValueDescription {
return &EnumValueDescription{
enumValueName: name,
desc: desc,

View File

@ -18,32 +18,66 @@ import (
"fmt"
"google.golang.org/protobuf/reflect/protoreflect"
dynamicpb "google.golang.org/protobuf/types/dynamicpb"
)
// NewFileDescription returns a FileDescription instance with a complete listing of all the message
// types and enum values declared within any scope in the file.
func NewFileDescription(fileDesc protoreflect.FileDescriptor, pbdb *Db) *FileDescription {
// newFileDescription returns a FileDescription instance with a complete listing of all the message
// types and enum values, as well as a map of extensions declared within any scope in the file.
func newFileDescription(fileDesc protoreflect.FileDescriptor, pbdb *Db) (*FileDescription, extensionMap) {
metadata := collectFileMetadata(fileDesc)
enums := make(map[string]*EnumValueDescription)
for name, enumVal := range metadata.enumValues {
enums[name] = NewEnumValueDescription(name, enumVal)
enums[name] = newEnumValueDescription(name, enumVal)
}
types := make(map[string]*TypeDescription)
for name, msgType := range metadata.msgTypes {
types[name] = NewTypeDescription(name, msgType)
types[name] = newTypeDescription(name, msgType, pbdb.extensions)
}
fileExtMap := make(extensionMap)
for typeName, extensions := range metadata.msgExtensionMap {
messageExtMap, found := fileExtMap[typeName]
if !found {
messageExtMap = make(map[string]*FieldDescription)
}
for _, ext := range extensions {
extDesc := dynamicpb.NewExtensionType(ext).TypeDescriptor()
messageExtMap[string(ext.FullName())] = newFieldDescription(extDesc)
}
fileExtMap[typeName] = messageExtMap
}
return &FileDescription{
name: fileDesc.Path(),
types: types,
enums: enums,
}
}, fileExtMap
}
// FileDescription holds a map of all types and enum values declared within a proto file.
type FileDescription struct {
name string
types map[string]*TypeDescription
enums map[string]*EnumValueDescription
}
// Copy creates a copy of the FileDescription with updated Db references within its types.
func (fd *FileDescription) Copy(pbdb *Db) *FileDescription {
typesCopy := make(map[string]*TypeDescription, len(fd.types))
for k, v := range fd.types {
typesCopy[k] = v.Copy(pbdb)
}
return &FileDescription{
name: fd.name,
types: typesCopy,
enums: fd.enums,
}
}
// GetName returns the fully qualified file path for the file.
func (fd *FileDescription) GetName() string {
return fd.name
}
// GetEnumDescription returns an EnumDescription for a qualified enum value
// name declared within the .proto file.
func (fd *FileDescription) GetEnumDescription(enumName string) (*EnumValueDescription, bool) {
@ -94,6 +128,10 @@ type fileMetadata struct {
msgTypes map[string]protoreflect.MessageDescriptor
// enumValues maps from fully-qualified enum value to enum value descriptor.
enumValues map[string]protoreflect.EnumValueDescriptor
// msgExtensionMap maps from the protobuf message name being extended to a set of extensions
// for the type.
msgExtensionMap map[string][]protoreflect.ExtensionDescriptor
// TODO: support enum type definitions for use in future type-check enhancements.
}
@ -102,28 +140,38 @@ type fileMetadata struct {
func collectFileMetadata(fileDesc protoreflect.FileDescriptor) *fileMetadata {
msgTypes := make(map[string]protoreflect.MessageDescriptor)
enumValues := make(map[string]protoreflect.EnumValueDescriptor)
collectMsgTypes(fileDesc.Messages(), msgTypes, enumValues)
msgExtensionMap := make(map[string][]protoreflect.ExtensionDescriptor)
collectMsgTypes(fileDesc.Messages(), msgTypes, enumValues, msgExtensionMap)
collectEnumValues(fileDesc.Enums(), enumValues)
collectExtensions(fileDesc.Extensions(), msgExtensionMap)
return &fileMetadata{
msgTypes: msgTypes,
enumValues: enumValues,
msgTypes: msgTypes,
enumValues: enumValues,
msgExtensionMap: msgExtensionMap,
}
}
// collectMsgTypes recursively collects messages, nested messages, and nested enums into a map of
// fully qualified protobuf names to descriptors.
func collectMsgTypes(msgTypes protoreflect.MessageDescriptors, msgTypeMap map[string]protoreflect.MessageDescriptor, enumValueMap map[string]protoreflect.EnumValueDescriptor) {
func collectMsgTypes(msgTypes protoreflect.MessageDescriptors,
msgTypeMap map[string]protoreflect.MessageDescriptor,
enumValueMap map[string]protoreflect.EnumValueDescriptor,
msgExtensionMap map[string][]protoreflect.ExtensionDescriptor) {
for i := 0; i < msgTypes.Len(); i++ {
msgType := msgTypes.Get(i)
msgTypeMap[string(msgType.FullName())] = msgType
nestedMsgTypes := msgType.Messages()
if nestedMsgTypes.Len() != 0 {
collectMsgTypes(nestedMsgTypes, msgTypeMap, enumValueMap)
collectMsgTypes(nestedMsgTypes, msgTypeMap, enumValueMap, msgExtensionMap)
}
nestedEnumTypes := msgType.Enums()
if nestedEnumTypes.Len() != 0 {
collectEnumValues(nestedEnumTypes, enumValueMap)
}
nestedExtensions := msgType.Extensions()
if nestedExtensions.Len() != 0 {
collectExtensions(nestedExtensions, msgExtensionMap)
}
}
}
@ -139,3 +187,16 @@ func collectEnumValues(enumTypes protoreflect.EnumDescriptors, enumValueMap map[
}
}
}
func collectExtensions(extensions protoreflect.ExtensionDescriptors, msgExtensionMap map[string][]protoreflect.ExtensionDescriptor) {
for i := 0; i < extensions.Len(); i++ {
ext := extensions.Get(i)
extendsMsg := string(ext.ContainingMessage().FullName())
msgExts, found := msgExtensionMap[extendsMsg]
if !found {
msgExts = []protoreflect.ExtensionDescriptor{}
}
msgExts = append(msgExts, ext)
msgExtensionMap[extendsMsg] = msgExts
}
}

View File

@ -40,13 +40,19 @@ type Db struct {
revFileDescriptorMap map[string]*FileDescription
// files contains the deduped set of FileDescriptions whose types are contained in the pb.Db.
files []*FileDescription
// extensions contains the mapping between a given type name, extension name and its FieldDescription
extensions map[string]map[string]*FieldDescription
}
// extensionsMap is a type alias to a map[typeName]map[extensionName]*FieldDescription
type extensionMap = map[string]map[string]*FieldDescription
var (
// DefaultDb used at evaluation time or unless overridden at check time.
DefaultDb = &Db{
revFileDescriptorMap: make(map[string]*FileDescription),
files: []*FileDescription{},
extensions: make(extensionMap),
}
)
@ -80,6 +86,7 @@ func NewDb() *Db {
pbdb := &Db{
revFileDescriptorMap: make(map[string]*FileDescription),
files: []*FileDescription{},
extensions: make(extensionMap),
}
// The FileDescription objects in the default db contain lazily initialized TypeDescription
// values which may point to the state contained in the DefaultDb irrespective of this shallow
@ -96,19 +103,34 @@ func NewDb() *Db {
// Copy creates a copy of the current database with its own internal descriptor mapping.
func (pbdb *Db) Copy() *Db {
copy := NewDb()
for k, v := range pbdb.revFileDescriptorMap {
copy.revFileDescriptorMap[k] = v
}
for _, f := range pbdb.files {
for _, fd := range pbdb.files {
hasFile := false
for _, f2 := range copy.files {
if f2 == f {
for _, fd2 := range copy.files {
if fd2 == fd {
hasFile = true
}
}
if !hasFile {
copy.files = append(copy.files, f)
fd = fd.Copy(copy)
copy.files = append(copy.files, fd)
}
for _, enumValName := range fd.GetEnumNames() {
copy.revFileDescriptorMap[enumValName] = fd
}
for _, msgTypeName := range fd.GetTypeNames() {
copy.revFileDescriptorMap[msgTypeName] = fd
}
copy.revFileDescriptorMap[fd.GetName()] = fd
}
for typeName, extFieldMap := range pbdb.extensions {
copyExtFieldMap, found := copy.extensions[typeName]
if !found {
copyExtFieldMap = make(map[string]*FieldDescription, len(extFieldMap))
}
for extFieldName, fd := range extFieldMap {
copyExtFieldMap[extFieldName] = fd
}
copy.extensions[typeName] = copyExtFieldMap
}
return copy
}
@ -137,17 +159,30 @@ func (pbdb *Db) RegisterDescriptor(fileDesc protoreflect.FileDescriptor) (*FileD
if err == nil {
fileDesc = globalFD
}
fd = NewFileDescription(fileDesc, pbdb)
var fileExtMap extensionMap
fd, fileExtMap = newFileDescription(fileDesc, pbdb)
for _, enumValName := range fd.GetEnumNames() {
pbdb.revFileDescriptorMap[enumValName] = fd
}
for _, msgTypeName := range fd.GetTypeNames() {
pbdb.revFileDescriptorMap[msgTypeName] = fd
}
pbdb.revFileDescriptorMap[fileDesc.Path()] = fd
pbdb.revFileDescriptorMap[fd.GetName()] = fd
// Return the specific file descriptor registered.
pbdb.files = append(pbdb.files, fd)
// Index the protobuf message extensions from the file into the pbdb
for typeName, extMap := range fileExtMap {
typeExtMap, found := pbdb.extensions[typeName]
if !found {
pbdb.extensions[typeName] = extMap
continue
}
for extName, field := range extMap {
typeExtMap[extName] = field
}
}
return fd, nil
}

View File

@ -38,22 +38,23 @@ type description interface {
Zero() proto.Message
}
// NewTypeDescription produces a TypeDescription value for the fully-qualified proto type name
// newTypeDescription produces a TypeDescription value for the fully-qualified proto type name
// with a given descriptor.
func NewTypeDescription(typeName string, desc protoreflect.MessageDescriptor) *TypeDescription {
func newTypeDescription(typeName string, desc protoreflect.MessageDescriptor, extensions extensionMap) *TypeDescription {
msgType := dynamicpb.NewMessageType(desc)
msgZero := dynamicpb.NewMessage(desc)
fieldMap := map[string]*FieldDescription{}
fields := desc.Fields()
for i := 0; i < fields.Len(); i++ {
f := fields.Get(i)
fieldMap[string(f.Name())] = NewFieldDescription(f)
fieldMap[string(f.Name())] = newFieldDescription(f)
}
return &TypeDescription{
typeName: typeName,
desc: desc,
msgType: msgType,
fieldMap: fieldMap,
extensions: extensions,
reflectType: reflectTypeOf(msgZero),
zeroMsg: zeroValueOf(msgZero),
}
@ -66,10 +67,24 @@ type TypeDescription struct {
desc protoreflect.MessageDescriptor
msgType protoreflect.MessageType
fieldMap map[string]*FieldDescription
extensions extensionMap
reflectType reflect.Type
zeroMsg proto.Message
}
// Copy copies the type description with updated references to the Db.
func (td *TypeDescription) Copy(pbdb *Db) *TypeDescription {
return &TypeDescription{
typeName: td.typeName,
desc: td.desc,
msgType: td.msgType,
fieldMap: td.fieldMap,
extensions: pbdb.extensions,
reflectType: td.reflectType,
zeroMsg: td.zeroMsg,
}
}
// FieldMap returns a string field name to FieldDescription map.
func (td *TypeDescription) FieldMap() map[string]*FieldDescription {
return td.fieldMap
@ -78,16 +93,21 @@ func (td *TypeDescription) FieldMap() map[string]*FieldDescription {
// FieldByName returns (FieldDescription, true) if the field name is declared within the type.
func (td *TypeDescription) FieldByName(name string) (*FieldDescription, bool) {
fd, found := td.fieldMap[name]
if found {
return fd, true
}
extFieldMap, found := td.extensions[td.typeName]
if !found {
return nil, false
}
return fd, true
fd, found = extFieldMap[name]
return fd, found
}
// MaybeUnwrap accepts a proto message as input and unwraps it to a primitive CEL type if possible.
//
// This method returns the unwrapped value and 'true', else the original value and 'false'.
func (td *TypeDescription) MaybeUnwrap(msg proto.Message) (interface{}, bool, error) {
func (td *TypeDescription) MaybeUnwrap(msg proto.Message) (any, bool, error) {
return unwrap(td, msg)
}
@ -111,8 +131,8 @@ func (td *TypeDescription) Zero() proto.Message {
return td.zeroMsg
}
// NewFieldDescription creates a new field description from a protoreflect.FieldDescriptor.
func NewFieldDescription(fieldDesc protoreflect.FieldDescriptor) *FieldDescription {
// newFieldDescription creates a new field description from a protoreflect.FieldDescriptor.
func newFieldDescription(fieldDesc protoreflect.FieldDescriptor) *FieldDescription {
var reflectType reflect.Type
var zeroMsg proto.Message
switch fieldDesc.Kind() {
@ -124,9 +144,17 @@ func NewFieldDescription(fieldDesc protoreflect.FieldDescriptor) *FieldDescripti
default:
reflectType = reflectTypeOf(fieldDesc.Default().Interface())
if fieldDesc.IsList() {
parentMsg := dynamicpb.NewMessage(fieldDesc.ContainingMessage())
listField := parentMsg.NewField(fieldDesc).List()
elem := listField.NewElement().Interface()
var elemValue protoreflect.Value
if fieldDesc.IsExtension() {
et := dynamicpb.NewExtensionType(fieldDesc)
elemValue = et.New().List().NewElement()
} else {
parentMsgType := fieldDesc.ContainingMessage()
parentMsg := dynamicpb.NewMessage(parentMsgType)
listField := parentMsg.NewField(fieldDesc).List()
elemValue = listField.NewElement()
}
elem := elemValue.Interface()
switch elemType := elem.(type) {
case protoreflect.Message:
elem = elemType.Interface()
@ -140,8 +168,8 @@ func NewFieldDescription(fieldDesc protoreflect.FieldDescriptor) *FieldDescripti
}
var keyType, valType *FieldDescription
if fieldDesc.IsMap() {
keyType = NewFieldDescription(fieldDesc.MapKey())
valType = NewFieldDescription(fieldDesc.MapValue())
keyType = newFieldDescription(fieldDesc.MapKey())
valType = newFieldDescription(fieldDesc.MapValue())
}
return &FieldDescription{
desc: fieldDesc,
@ -195,7 +223,7 @@ func (fd *FieldDescription) Descriptor() protoreflect.FieldDescriptor {
//
// This function implements the FieldType.IsSet function contract which can be used to operate on
// more than just protobuf field accesses; however, the target here must be a protobuf.Message.
func (fd *FieldDescription) IsSet(target interface{}) bool {
func (fd *FieldDescription) IsSet(target any) bool {
switch v := target.(type) {
case proto.Message:
pbRef := v.ProtoReflect()
@ -219,14 +247,14 @@ func (fd *FieldDescription) IsSet(target interface{}) bool {
//
// This function implements the FieldType.GetFrom function contract which can be used to operate
// on more than just protobuf field accesses; however, the target here must be a protobuf.Message.
func (fd *FieldDescription) GetFrom(target interface{}) (interface{}, error) {
func (fd *FieldDescription) GetFrom(target any) (any, error) {
v, ok := target.(proto.Message)
if !ok {
return nil, fmt.Errorf("unsupported field selection target: (%T)%v", target, target)
}
pbRef := v.ProtoReflect()
pbDesc := pbRef.Descriptor()
var fieldVal interface{}
var fieldVal any
if pbDesc == fd.desc.ContainingMessage() {
// When the target protobuf shares the same message descriptor instance as the field
// descriptor, use the cached field descriptor value.
@ -289,7 +317,7 @@ func (fd *FieldDescription) IsList() bool {
//
// This function returns the unwrapped value and 'true' on success, or the original value
// and 'false' otherwise.
func (fd *FieldDescription) MaybeUnwrapDynamic(msg protoreflect.Message) (interface{}, bool, error) {
func (fd *FieldDescription) MaybeUnwrapDynamic(msg protoreflect.Message) (any, bool, error) {
return unwrapDynamic(fd, msg)
}
@ -362,7 +390,7 @@ func checkedWrap(t *exprpb.Type) *exprpb.Type {
// input message is a *dynamicpb.Message which obscures the typing information from Go.
//
// Returns the unwrapped value and 'true' if unwrapped, otherwise the input value and 'false'.
func unwrap(desc description, msg proto.Message) (interface{}, bool, error) {
func unwrap(desc description, msg proto.Message) (any, bool, error) {
switch v := msg.(type) {
case *anypb.Any:
dynMsg, err := v.UnmarshalNew()
@ -418,7 +446,7 @@ func unwrap(desc description, msg proto.Message) (interface{}, bool, error) {
// unwrapDynamic unwraps a reflected protobuf Message value.
//
// Returns the unwrapped value and 'true' if unwrapped, otherwise the input value and 'false'.
func unwrapDynamic(desc description, refMsg protoreflect.Message) (interface{}, bool, error) {
func unwrapDynamic(desc description, refMsg protoreflect.Message) (any, bool, error) {
msg := refMsg.Interface()
if !refMsg.IsValid() {
msg = desc.Zero()
@ -508,7 +536,7 @@ func unwrapDynamic(desc description, refMsg protoreflect.Message) (interface{},
// reflectTypeOf intercepts the reflect.Type call to ensure that dynamicpb.Message types preserve
// well-known protobuf reflected types expected by the CEL type system.
func reflectTypeOf(val interface{}) reflect.Type {
func reflectTypeOf(val any) reflect.Type {
switch v := val.(type) {
case proto.Message:
return reflect.TypeOf(zeroValueOf(v))

View File

@ -19,11 +19,12 @@ import (
"reflect"
"time"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"github.com/google/cel-go/common/types/pb"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
anypb "google.golang.org/protobuf/types/known/anypb"
@ -195,7 +196,7 @@ func (p *protoTypeRegistry) RegisterType(types ...ref.Type) error {
// providing support for custom proto-based types.
//
// This method should be the inverse of ref.Val.ConvertToNative.
func (p *protoTypeRegistry) NativeToValue(value interface{}) ref.Val {
func (p *protoTypeRegistry) NativeToValue(value any) ref.Val {
if val, found := nativeToValue(p, value); found {
return val
}
@ -249,7 +250,7 @@ var (
)
// NativeToValue implements the ref.TypeAdapter interface.
func (a *defaultTypeAdapter) NativeToValue(value interface{}) ref.Val {
func (a *defaultTypeAdapter) NativeToValue(value any) ref.Val {
if val, found := nativeToValue(a, value); found {
return val
}
@ -258,7 +259,7 @@ func (a *defaultTypeAdapter) NativeToValue(value interface{}) ref.Val {
// nativeToValue returns the converted (ref.Val, true) of a conversion is found,
// otherwise (nil, false)
func nativeToValue(a ref.TypeAdapter, value interface{}) (ref.Val, bool) {
func nativeToValue(a ref.TypeAdapter, value any) (ref.Val, bool) {
switch v := value.(type) {
case nil:
return NullValue, true
@ -364,7 +365,7 @@ func nativeToValue(a ref.TypeAdapter, value interface{}) (ref.Val, bool) {
// specializations for common map types.
case map[string]string:
return NewStringStringMap(a, v), true
case map[string]interface{}:
case map[string]any:
return NewStringInterfaceMap(a, v), true
case map[ref.Val]ref.Val:
return NewRefValMap(a, v), true
@ -479,9 +480,12 @@ func msgSetField(target protoreflect.Message, field *pb.FieldDescription, val re
if err != nil {
return fieldTypeConversionError(field, err)
}
switch v.(type) {
if v == nil {
return nil
}
switch pv := v.(type) {
case proto.Message:
v = v.(proto.Message).ProtoReflect()
v = pv.ProtoReflect()
}
target.Set(field.Descriptor(), protoreflect.ValueOf(v))
return nil
@ -495,6 +499,9 @@ func msgSetListField(target protoreflect.List, listField *pb.FieldDescription, l
if err != nil {
return fieldTypeConversionError(listField, err)
}
if elemVal == nil {
continue
}
switch ev := elemVal.(type) {
case proto.Message:
elemVal = ev.ProtoReflect()
@ -519,9 +526,12 @@ func msgSetMapField(target protoreflect.Map, mapField *pb.FieldDescription, mapV
if err != nil {
return fieldTypeConversionError(mapField, err)
}
switch v.(type) {
if v == nil {
continue
}
switch pv := v.(type) {
case proto.Message:
v = v.(proto.Message).ProtoReflect()
v = pv.ProtoReflect()
}
target.Set(protoreflect.ValueOf(k).MapKey(), protoreflect.ValueOf(v))
}

View File

@ -13,7 +13,7 @@ go_library(
],
importpath = "github.com/google/cel-go/common/types/ref",
deps = [
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//reflect/protoreflect:go_default_library",
],

View File

@ -39,8 +39,6 @@ type TypeProvider interface {
// FieldFieldType returns the field type for a checked type value. Returns
// false if the field could not be found.
//
// Used during type-checking only.
FindFieldType(messageType string, fieldName string) (*FieldType, bool)
// NewValue creates a new type value from a qualified name and map of field
@ -55,7 +53,7 @@ type TypeProvider interface {
// TypeAdapter converts native Go values of varying type and complexity to equivalent CEL values.
type TypeAdapter interface {
// NativeToValue converts the input `value` to a CEL `ref.Val`.
NativeToValue(value interface{}) Val
NativeToValue(value any) Val
}
// TypeRegistry allows third-parties to add custom types to CEL. Not all `TypeProvider`
@ -97,7 +95,7 @@ type FieldType struct {
}
// FieldTester is used to test field presence on an input object.
type FieldTester func(target interface{}) bool
type FieldTester func(target any) bool
// FieldGetter is used to get the field value from an input object, if set.
type FieldGetter func(target interface{}) (interface{}, error)
type FieldGetter func(target any) (any, error)

View File

@ -37,9 +37,18 @@ type Type interface {
type Val interface {
// ConvertToNative converts the Value to a native Go struct according to the
// reflected type description, or error if the conversion is not feasible.
ConvertToNative(typeDesc reflect.Type) (interface{}, error)
//
// The ConvertToNative method is intended to be used to support conversion between CEL types
// and native types during object creation expressions or by clients who need to adapt the,
// returned CEL value into an equivalent Go value instance.
//
// When implementing or using ConvertToNative, the following guidelines apply:
// - Use ConvertToNative when marshalling CEL evaluation results to native types.
// - Do not use ConvertToNative within CEL extension functions.
// - Document whether your implementation supports non-CEL field types, such as Go or Protobuf.
ConvertToNative(typeDesc reflect.Type) (any, error)
// ConvertToType supports type conversions between value types supported by the expression language.
// ConvertToType supports type conversions between CEL value types supported by the expression language.
ConvertToType(typeValue Type) Val
// Equal returns true if the `other` value has the same type and content as the implementing struct.
@ -50,5 +59,5 @@ type Val interface {
// Value returns the raw value of the instance which may not be directly compatible with the expression
// language types.
Value() interface{}
Value() any
}

View File

@ -72,7 +72,7 @@ func (s String) Compare(other ref.Val) ref.Val {
}
// ConvertToNative implements ref.Val.ConvertToNative.
func (s String) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (s String) ConvertToNative(typeDesc reflect.Type) (any, error) {
switch typeDesc.Kind() {
case reflect.String:
if reflect.TypeOf(s).AssignableTo(typeDesc) {
@ -154,6 +154,11 @@ func (s String) Equal(other ref.Val) ref.Val {
return Bool(ok && s == otherString)
}
// IsZeroValue returns true if the string is empty.
func (s String) IsZeroValue() bool {
return len(s) == 0
}
// Match implements traits.Matcher.Match.
func (s String) Match(pattern ref.Val) ref.Val {
pat, ok := pattern.(String)
@ -189,7 +194,7 @@ func (s String) Type() ref.Type {
}
// Value implements ref.Val.Value.
func (s String) Value() interface{} {
func (s String) Value() any {
return string(s)
}

View File

@ -89,7 +89,7 @@ func (t Timestamp) Compare(other ref.Val) ref.Val {
}
// ConvertToNative implements ref.Val.ConvertToNative.
func (t Timestamp) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (t Timestamp) ConvertToNative(typeDesc reflect.Type) (any, error) {
// If the timestamp is already assignable to the desired type return it.
if reflect.TypeOf(t.Time).AssignableTo(typeDesc) {
return t.Time, nil
@ -138,6 +138,11 @@ func (t Timestamp) Equal(other ref.Val) ref.Val {
return Bool(ok && t.Time.Equal(otherTime.Time))
}
// IsZeroValue returns true if the timestamp is epoch 0.
func (t Timestamp) IsZeroValue() bool {
return t.IsZero()
}
// Receive implements traits.Receiver.Receive.
func (t Timestamp) Receive(function string, overload string, args []ref.Val) ref.Val {
switch len(args) {
@ -160,14 +165,14 @@ func (t Timestamp) Subtract(subtrahend ref.Val) ref.Val {
dur := subtrahend.(Duration)
val, err := subtractTimeDurationChecked(t.Time, dur.Duration)
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return timestampOf(val)
case TimestampType:
t2 := subtrahend.(Timestamp).Time
val, err := subtractTimeChecked(t.Time, t2)
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return durationOf(val)
}
@ -180,7 +185,7 @@ func (t Timestamp) Type() ref.Type {
}
// Value implements ref.Val.Value.
func (t Timestamp) Value() interface{} {
func (t Timestamp) Value() any {
return t.Time
}
@ -288,7 +293,7 @@ func timeZone(tz ref.Val, visitor timestampVisitor) timestampVisitor {
if ind == -1 {
loc, err := time.LoadLocation(val)
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return visitor(t.In(loc))
}
@ -297,11 +302,11 @@ func timeZone(tz ref.Val, visitor timestampVisitor) timestampVisitor {
// in the format ^(+|-)(0[0-9]|1[0-4]):[0-5][0-9]$. The numerical input is parsed in terms of hours and minutes.
hr, err := strconv.Atoi(string(val[0:ind]))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
min, err := strconv.Atoi(string(val[ind+1:]))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
var offset int
if string(val[0]) == "-" {

View File

@ -20,6 +20,7 @@ go_library(
"receiver.go",
"sizer.go",
"traits.go",
"zeroer.go",
],
importpath = "github.com/google/cel-go/common/types/traits",
deps = [

View File

@ -1,4 +1,4 @@
// Copyright 2020 Google LLC
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -12,24 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package interpreter
package traits
import "math"
// TODO: remove Coster.
// Coster calculates the heuristic cost incurred during evaluation.
// Deprecated: Please migrate cel.EstimateCost, it supports length estimates for input data and cost estimates for
// extension functions.
type Coster interface {
Cost() (min, max int64)
}
// estimateCost returns the heuristic cost interval for the program.
func estimateCost(i interface{}) (min, max int64) {
c, ok := i.(Coster)
if !ok {
return 0, math.MaxInt64
}
return c.Cost()
// Zeroer interface for testing whether a CEL value is a zero value for its type.
type Zeroer interface {
// IsZeroValue indicates whether the object is the zero value for the type.
IsZeroValue() bool
}

View File

@ -53,7 +53,7 @@ func NewObjectTypeValue(name string) *TypeValue {
}
// ConvertToNative implements ref.Val.ConvertToNative.
func (t *TypeValue) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (t *TypeValue) ConvertToNative(typeDesc reflect.Type) (any, error) {
// TODO: replace the internal type representation with a proto-value.
return nil, fmt.Errorf("type conversion not supported for 'type'")
}
@ -97,6 +97,6 @@ func (t *TypeValue) TypeName() string {
}
// Value implements ref.Val.Value.
func (t *TypeValue) Value() interface{} {
func (t *TypeValue) Value() any {
return t.name
}

View File

@ -59,7 +59,7 @@ func (i Uint) Add(other ref.Val) ref.Val {
}
val, err := addUint64Checked(uint64(i), uint64(otherUint))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Uint(val)
}
@ -82,7 +82,7 @@ func (i Uint) Compare(other ref.Val) ref.Val {
}
// ConvertToNative implements ref.Val.ConvertToNative.
func (i Uint) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (i Uint) ConvertToNative(typeDesc reflect.Type) (any, error) {
switch typeDesc.Kind() {
case reflect.Uint, reflect.Uint32:
v, err := uint64ToUint32Checked(uint64(i))
@ -149,7 +149,7 @@ func (i Uint) ConvertToType(typeVal ref.Type) ref.Val {
case IntType:
v, err := uint64ToInt64Checked(uint64(i))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Int(v)
case UintType:
@ -172,7 +172,7 @@ func (i Uint) Divide(other ref.Val) ref.Val {
}
div, err := divideUint64Checked(uint64(i), uint64(otherUint))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Uint(div)
}
@ -194,6 +194,11 @@ func (i Uint) Equal(other ref.Val) ref.Val {
}
}
// IsZeroValue returns true if the uint is zero.
func (i Uint) IsZeroValue() bool {
return i == 0
}
// Modulo implements traits.Modder.Modulo.
func (i Uint) Modulo(other ref.Val) ref.Val {
otherUint, ok := other.(Uint)
@ -202,7 +207,7 @@ func (i Uint) Modulo(other ref.Val) ref.Val {
}
mod, err := moduloUint64Checked(uint64(i), uint64(otherUint))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Uint(mod)
}
@ -215,7 +220,7 @@ func (i Uint) Multiply(other ref.Val) ref.Val {
}
val, err := multiplyUint64Checked(uint64(i), uint64(otherUint))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Uint(val)
}
@ -228,7 +233,7 @@ func (i Uint) Subtract(subtrahend ref.Val) ref.Val {
}
val, err := subtractUint64Checked(uint64(i), uint64(subtraUint))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Uint(val)
}
@ -239,7 +244,7 @@ func (i Uint) Type() ref.Type {
}
// Value implements ref.Val.Value.
func (i Uint) Value() interface{} {
func (i Uint) Value() any {
return uint64(i)
}

View File

@ -30,7 +30,7 @@ var (
)
// ConvertToNative implements ref.Val.ConvertToNative.
func (u Unknown) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (u Unknown) ConvertToNative(typeDesc reflect.Type) (any, error) {
return u.Value(), nil
}
@ -50,7 +50,7 @@ func (u Unknown) Type() ref.Type {
}
// Value implements ref.Val.Value.
func (u Unknown) Value() interface{} {
func (u Unknown) Value() any {
return []int64(u)
}

View File

@ -9,14 +9,30 @@ go_library(
srcs = [
"encoders.go",
"guards.go",
"math.go",
"native.go",
"protos.go",
"sets.go",
"strings.go",
],
importpath = "github.com/google/cel-go/ext",
visibility = ["//visibility:public"],
deps = [
"//cel:go_default_library",
"//checker/decls:go_default_library",
"//common:go_default_library",
"//common/overloads:go_default_library",
"//common/types:go_default_library",
"//common/types/pb:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"//interpreter:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//reflect/protoreflect:go_default_library",
"@org_golang_google_protobuf//types/known/structpb",
"@org_golang_x_text//language:go_default_library",
"@org_golang_x_text//message:go_default_library",
],
)
@ -25,6 +41,10 @@ go_test(
size = "small",
srcs = [
"encoders_test.go",
"math_test.go",
"native_test.go",
"protos_test.go",
"sets_test.go",
"strings_test.go",
],
embed = [
@ -32,5 +52,17 @@ go_test(
],
deps = [
"//cel:go_default_library",
"//checker:go_default_library",
"//common:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"//test:go_default_library",
"//test/proto2pb:go_default_library",
"//test/proto3pb:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//types/known/wrapperspb:go_default_library",
"@org_golang_google_protobuf//encoding/protojson:go_default_library",
],
)

View File

@ -3,6 +3,30 @@
CEL extensions are a related set of constants, functions, macros, or other
features which may not be covered by the core CEL spec.
## Bindings
Returns a cel.EnvOption to configure support for local variable bindings
in expressions.
# Cel.Bind
Binds a simple identifier to an initialization expression which may be used
in a subsequenct result expression. Bindings may also be nested within each
other.
cel.bind(<varName>, <initExpr>, <resultExpr>)
Examples:
cel.bind(a, 'hello',
cel.bind(b, 'world', a + b + b + a)) // "helloworldworldhello"
// Avoid a list allocation within the exists comprehension.
cel.bind(valid_values, [a, b, c],
[d, e, f].exists(elem, elem in valid_values))
Local bindings are not guaranteed to be evaluated before use.
## Encoders
Encoding utilies for marshalling data into standardized representations.
@ -31,6 +55,156 @@ Example:
base64.encode(b'hello') // return 'aGVsbG8='
## Math
Math helper macros and functions.
Note, all macros use the 'math' namespace; however, at the time of macro
expansion the namespace looks just like any other identifier. If you are
currently using a variable named 'math', the macro will likely work just as
intended; however, there is some chance for collision.
### Math.Greatest
Returns the greatest valued number present in the arguments to the macro.
Greatest is a variable argument count macro which must take at least one
argument. Simple numeric and list literals are supported as valid argument
types; however, other literals will be flagged as errors during macro
expansion. If the argument expression does not resolve to a numeric or
list(numeric) type during type-checking, or during runtime then an error
will be produced. If a list argument is empty, this too will produce an
error.
math.greatest(<arg>, ...) -> <double|int|uint>
Examples:
math.greatest(1) // 1
math.greatest(1u, 2u) // 2u
math.greatest(-42.0, -21.5, -100.0) // -21.5
math.greatest([-42.0, -21.5, -100.0]) // -21.5
math.greatest(numbers) // numbers must be list(numeric)
math.greatest() // parse error
math.greatest('string') // parse error
math.greatest(a, b) // check-time error if a or b is non-numeric
math.greatest(dyn('string')) // runtime error
### Math.Least
Returns the least valued number present in the arguments to the macro.
Least is a variable argument count macro which must take at least one
argument. Simple numeric and list literals are supported as valid argument
types; however, other literals will be flagged as errors during macro
expansion. If the argument expression does not resolve to a numeric or
list(numeric) type during type-checking, or during runtime then an error
will be produced. If a list argument is empty, this too will produce an error.
math.least(<arg>, ...) -> <double|int|uint>
Examples:
math.least(1) // 1
math.least(1u, 2u) // 1u
math.least(-42.0, -21.5, -100.0) // -100.0
math.least([-42.0, -21.5, -100.0]) // -100.0
math.least(numbers) // numbers must be list(numeric)
math.least() // parse error
math.least('string') // parse error
math.least(a, b) // check-time error if a or b is non-numeric
math.least(dyn('string')) // runtime error
## Protos
Protos configure extended macros and functions for proto manipulation.
Note, all macros use the 'proto' namespace; however, at the time of macro
expansion the namespace looks just like any other identifier. If you are
currently using a variable named 'proto', the macro will likely work just as
you intend; however, there is some chance for collision.
### Protos.GetExt
Macro which generates a select expression that retrieves an extension field
from the input proto2 syntax message. If the field is not set, the default
value forthe extension field is returned according to safe-traversal semantics.
proto.getExt(<msg>, <fully.qualified.extension.name>) -> <field-type>
Example:
proto.getExt(msg, google.expr.proto2.test.int32_ext) // returns int value
### Protos.HasExt
Macro which generates a test-only select expression that determines whether
an extension field is set on a proto2 syntax message.
proto.hasExt(<msg>, <fully.qualified.extension.name>) -> <bool>
Example:
proto.hasExt(msg, google.expr.proto2.test.int32_ext) // returns true || false
## Sets
Sets provides set relationship tests.
There is no set type within CEL, and while one may be introduced in the
future, there are cases where a `list` type is known to behave like a set.
For such cases, this library provides some basic functionality for
determining set containment, equivalence, and intersection.
### Sets.Contains
Returns whether the first list argument contains all elements in the second
list argument. The list may contain elements of any type and standard CEL
equality is used to determine whether a value exists in both lists. If the
second list is empty, the result will always return true.
sets.contains(list(T), list(T)) -> bool
Examples:
sets.contains([], []) // true
sets.contains([], [1]) // false
sets.contains([1, 2, 3, 4], [2, 3]) // true
sets.contains([1, 2.0, 3u], [1.0, 2u, 3]) // true
### Sets.Equivalent
Returns whether the first and second list are set equivalent. Lists are set
equivalent if for every item in the first list, there is an element in the
second which is equal. The lists may not be of the same size as they do not
guarantee the elements within them are unique, so size does not factor into
the computation.
sets.equivalent(list(T), list(T)) -> bool
Examples:
sets.equivalent([], []) // true
sets.equivalent([1], [1, 1]) // true
sets.equivalent([1], [1u, 1.0]) // true
sets.equivalent([1, 2, 3], [3u, 2.0, 1]) // true
### Sets.Intersects
Returns whether the first list has at least one element whose value is equal
to an element in the second list. If either list is empty, the result will
be false.
sets.intersects(list(T), list(T)) -> bool
Examples:
sets.intersects([1], []) // false
sets.intersects([1], [1, 2]) // true
sets.intersects([[1], [2, 3]], [[1, 2], [2, 3.0]]) // true
## Strings
Extended functions for string manipulation. As a general note, all indices are
@ -70,6 +244,23 @@ Examples:
'hello mellow'.indexOf('ello', 2) // returns 7
'hello mellow'.indexOf('ello', 20) // error
### Join
Returns a new string where the elements of string list are concatenated.
The function also accepts an optional separator which is placed between
elements in the resulting string.
<list<string>>.join() -> <string>
<list<string>>.join(<string>) -> <string>
Examples:
['hello', 'mellow'].join() // returns 'hellomellow'
['hello', 'mellow'].join(' ') // returns 'hello mellow'
[].join() // returns ''
[].join('/') // returns ''
### LastIndexOf
Returns the integer index of the last occurrence of the search string. If the
@ -105,6 +296,20 @@ Examples:
'TacoCat'.lowerAscii() // returns 'tacocat'
'TacoCÆt Xii'.lowerAscii() // returns 'tacocÆt xii'
### Quote
**Introduced in version 1**
Takes the given string and makes it safe to print (without any formatting due to escape sequences).
If any invalid UTF-8 characters are encountered, they are replaced with \uFFFD.
strings.quote(<string>)
Examples:
strings.quote('single-quote with "double quote"') // returns '"single-quote with \"double quote\""'
strings.quote("two escape sequences \a\n") // returns '"two escape sequences \\a\\n"'
### Replace
Returns a new string based on the target, which replaces the occurrences of a

100
vendor/github.com/google/cel-go/ext/bindings.go generated vendored Normal file
View File

@ -0,0 +1,100 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ext
import (
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// Bindings returns a cel.EnvOption to configure support for local variable
// bindings in expressions.
//
// # Cel.Bind
//
// Binds a simple identifier to an initialization expression which may be used
// in a subsequenct result expression. Bindings may also be nested within each
// other.
//
// cel.bind(<varName>, <initExpr>, <resultExpr>)
//
// Examples:
//
// cel.bind(a, 'hello',
// cel.bind(b, 'world', a + b + b + a)) // "helloworldworldhello"
//
// // Avoid a list allocation within the exists comprehension.
// cel.bind(valid_values, [a, b, c],
// [d, e, f].exists(elem, elem in valid_values))
//
// Local bindings are not guaranteed to be evaluated before use.
func Bindings() cel.EnvOption {
return cel.Lib(celBindings{})
}
const (
celNamespace = "cel"
bindMacro = "bind"
unusedIterVar = "#unused"
)
type celBindings struct{}
func (celBindings) LibraryName() string {
return "cel.lib.ext.cel.bindings"
}
func (celBindings) CompileOptions() []cel.EnvOption {
return []cel.EnvOption{
cel.Macros(
// cel.bind(var, <init>, <expr>)
cel.NewReceiverMacro(bindMacro, 3, celBind),
),
}
}
func (celBindings) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
func celBind(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
if !macroTargetMatchesNamespace(celNamespace, target) {
return nil, nil
}
varIdent := args[0]
varName := ""
switch varIdent.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
varName = varIdent.GetIdentExpr().GetName()
default:
return nil, &common.Error{
Message: "cel.bind() variable names must be simple identifers",
Location: meh.OffsetLocation(varIdent.GetId()),
}
}
varInit := args[1]
resultExpr := args[2]
return meh.Fold(
unusedIterVar,
meh.NewList(),
varName,
varInit,
meh.LiteralBool(false),
meh.Ident(varName),
resultExpr,
), nil
}

View File

@ -26,34 +26,38 @@ import (
// Encoders returns a cel.EnvOption to configure extended functions for string, byte, and object
// encodings.
//
// Base64.Decode
// # Base64.Decode
//
// Decodes base64-encoded string to bytes.
//
// This function will return an error if the string input is not base64-encoded.
//
// base64.decode(<string>) -> <bytes>
// base64.decode(<string>) -> <bytes>
//
// Examples:
//
// base64.decode('aGVsbG8=') // return b'hello'
// base64.decode('aGVsbG8') // error
// base64.decode('aGVsbG8=') // return b'hello'
// base64.decode('aGVsbG8') // error
//
// Base64.Encode
// # Base64.Encode
//
// Encodes bytes to a base64-encoded string.
//
// base64.encode(<bytes>) -> <string>
// base64.encode(<bytes>) -> <string>
//
// Examples:
//
// base64.encode(b'hello') // return b'aGVsbG8='
// base64.encode(b'hello') // return b'aGVsbG8='
func Encoders() cel.EnvOption {
return cel.Lib(encoderLib{})
}
type encoderLib struct{}
func (encoderLib) LibraryName() string {
return "cel.lib.ext.encoders"
}
func (encoderLib) CompileOptions() []cel.EnvOption {
return []cel.EnvOption{
cel.Function("base64.decode",

View File

@ -17,6 +17,7 @@ package ext
import (
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// function invocation guards for common call signatures within extension functions.
@ -48,3 +49,15 @@ func listStringOrError(strs []string, err error) ref.Val {
}
return types.DefaultTypeAdapter.NativeToValue(strs)
}
func macroTargetMatchesNamespace(ns string, target *exprpb.Expr) bool {
switch target.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
if target.GetIdentExpr().GetName() != ns {
return false
}
return true
default:
return false
}
}

388
vendor/github.com/google/cel-go/ext/math.go generated vendored Normal file
View File

@ -0,0 +1,388 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ext
import (
"fmt"
"strings"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// Math returns a cel.EnvOption to configure namespaced math helper macros and
// functions.
//
// Note, all macros use the 'math' namespace; however, at the time of macro
// expansion the namespace looks just like any other identifier. If you are
// currently using a variable named 'math', the macro will likely work just as
// intended; however, there is some chance for collision.
//
// # Math.Greatest
//
// Returns the greatest valued number present in the arguments to the macro.
//
// Greatest is a variable argument count macro which must take at least one
// argument. Simple numeric and list literals are supported as valid argument
// types; however, other literals will be flagged as errors during macro
// expansion. If the argument expression does not resolve to a numeric or
// list(numeric) type during type-checking, or during runtime then an error
// will be produced. If a list argument is empty, this too will produce an
// error.
//
// math.greatest(<arg>, ...) -> <double|int|uint>
//
// Examples:
//
// math.greatest(1) // 1
// math.greatest(1u, 2u) // 2u
// math.greatest(-42.0, -21.5, -100.0) // -21.5
// math.greatest([-42.0, -21.5, -100.0]) // -21.5
// math.greatest(numbers) // numbers must be list(numeric)
//
// math.greatest() // parse error
// math.greatest('string') // parse error
// math.greatest(a, b) // check-time error if a or b is non-numeric
// math.greatest(dyn('string')) // runtime error
//
// # Math.Least
//
// Returns the least valued number present in the arguments to the macro.
//
// Least is a variable argument count macro which must take at least one
// argument. Simple numeric and list literals are supported as valid argument
// types; however, other literals will be flagged as errors during macro
// expansion. If the argument expression does not resolve to a numeric or
// list(numeric) type during type-checking, or during runtime then an error
// will be produced. If a list argument is empty, this too will produce an
// error.
//
// math.least(<arg>, ...) -> <double|int|uint>
//
// Examples:
//
// math.least(1) // 1
// math.least(1u, 2u) // 1u
// math.least(-42.0, -21.5, -100.0) // -100.0
// math.least([-42.0, -21.5, -100.0]) // -100.0
// math.least(numbers) // numbers must be list(numeric)
//
// math.least() // parse error
// math.least('string') // parse error
// math.least(a, b) // check-time error if a or b is non-numeric
// math.least(dyn('string')) // runtime error
func Math() cel.EnvOption {
return cel.Lib(mathLib{})
}
const (
mathNamespace = "math"
leastMacro = "least"
greatestMacro = "greatest"
minFunc = "math.@min"
maxFunc = "math.@max"
)
type mathLib struct{}
// LibraryName implements the SingletonLibrary interface method.
func (mathLib) LibraryName() string {
return "cel.lib.ext.math"
}
// CompileOptions implements the Library interface method.
func (mathLib) CompileOptions() []cel.EnvOption {
return []cel.EnvOption{
cel.Macros(
// math.least(num, ...)
cel.NewReceiverVarArgMacro(leastMacro, mathLeast),
// math.greatest(num, ...)
cel.NewReceiverVarArgMacro(greatestMacro, mathGreatest),
),
cel.Function(minFunc,
cel.Overload("math_@min_double", []*cel.Type{cel.DoubleType}, cel.DoubleType,
cel.UnaryBinding(identity)),
cel.Overload("math_@min_int", []*cel.Type{cel.IntType}, cel.IntType,
cel.UnaryBinding(identity)),
cel.Overload("math_@min_uint", []*cel.Type{cel.UintType}, cel.UintType,
cel.UnaryBinding(identity)),
cel.Overload("math_@min_double_double", []*cel.Type{cel.DoubleType, cel.DoubleType}, cel.DoubleType,
cel.BinaryBinding(minPair)),
cel.Overload("math_@min_int_int", []*cel.Type{cel.IntType, cel.IntType}, cel.IntType,
cel.BinaryBinding(minPair)),
cel.Overload("math_@min_uint_uint", []*cel.Type{cel.UintType, cel.UintType}, cel.UintType,
cel.BinaryBinding(minPair)),
cel.Overload("math_@min_int_uint", []*cel.Type{cel.IntType, cel.UintType}, cel.DynType,
cel.BinaryBinding(minPair)),
cel.Overload("math_@min_int_double", []*cel.Type{cel.IntType, cel.DoubleType}, cel.DynType,
cel.BinaryBinding(minPair)),
cel.Overload("math_@min_double_int", []*cel.Type{cel.DoubleType, cel.IntType}, cel.DynType,
cel.BinaryBinding(minPair)),
cel.Overload("math_@min_double_uint", []*cel.Type{cel.DoubleType, cel.UintType}, cel.DynType,
cel.BinaryBinding(minPair)),
cel.Overload("math_@min_uint_int", []*cel.Type{cel.UintType, cel.IntType}, cel.DynType,
cel.BinaryBinding(minPair)),
cel.Overload("math_@min_uint_double", []*cel.Type{cel.UintType, cel.DoubleType}, cel.DynType,
cel.BinaryBinding(minPair)),
cel.Overload("math_@min_list_double", []*cel.Type{cel.ListType(cel.DoubleType)}, cel.DoubleType,
cel.UnaryBinding(minList)),
cel.Overload("math_@min_list_int", []*cel.Type{cel.ListType(cel.IntType)}, cel.IntType,
cel.UnaryBinding(minList)),
cel.Overload("math_@min_list_uint", []*cel.Type{cel.ListType(cel.UintType)}, cel.UintType,
cel.UnaryBinding(minList)),
),
cel.Function(maxFunc,
cel.Overload("math_@max_double", []*cel.Type{cel.DoubleType}, cel.DoubleType,
cel.UnaryBinding(identity)),
cel.Overload("math_@max_int", []*cel.Type{cel.IntType}, cel.IntType,
cel.UnaryBinding(identity)),
cel.Overload("math_@max_uint", []*cel.Type{cel.UintType}, cel.UintType,
cel.UnaryBinding(identity)),
cel.Overload("math_@max_double_double", []*cel.Type{cel.DoubleType, cel.DoubleType}, cel.DoubleType,
cel.BinaryBinding(maxPair)),
cel.Overload("math_@max_int_int", []*cel.Type{cel.IntType, cel.IntType}, cel.IntType,
cel.BinaryBinding(maxPair)),
cel.Overload("math_@max_uint_uint", []*cel.Type{cel.UintType, cel.UintType}, cel.UintType,
cel.BinaryBinding(maxPair)),
cel.Overload("math_@max_int_uint", []*cel.Type{cel.IntType, cel.UintType}, cel.DynType,
cel.BinaryBinding(maxPair)),
cel.Overload("math_@max_int_double", []*cel.Type{cel.IntType, cel.DoubleType}, cel.DynType,
cel.BinaryBinding(maxPair)),
cel.Overload("math_@max_double_int", []*cel.Type{cel.DoubleType, cel.IntType}, cel.DynType,
cel.BinaryBinding(maxPair)),
cel.Overload("math_@max_double_uint", []*cel.Type{cel.DoubleType, cel.UintType}, cel.DynType,
cel.BinaryBinding(maxPair)),
cel.Overload("math_@max_uint_int", []*cel.Type{cel.UintType, cel.IntType}, cel.DynType,
cel.BinaryBinding(maxPair)),
cel.Overload("math_@max_uint_double", []*cel.Type{cel.UintType, cel.DoubleType}, cel.DynType,
cel.BinaryBinding(maxPair)),
cel.Overload("math_@max_list_double", []*cel.Type{cel.ListType(cel.DoubleType)}, cel.DoubleType,
cel.UnaryBinding(maxList)),
cel.Overload("math_@max_list_int", []*cel.Type{cel.ListType(cel.IntType)}, cel.IntType,
cel.UnaryBinding(maxList)),
cel.Overload("math_@max_list_uint", []*cel.Type{cel.ListType(cel.UintType)}, cel.UintType,
cel.UnaryBinding(maxList)),
),
}
}
// ProgramOptions implements the Library interface method.
func (mathLib) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
func mathLeast(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
if !macroTargetMatchesNamespace(mathNamespace, target) {
return nil, nil
}
switch len(args) {
case 0:
return nil, &common.Error{
Message: "math.least() requires at least one argument",
Location: meh.OffsetLocation(target.GetId()),
}
case 1:
if isListLiteralWithValidArgs(args[0]) || isValidArgType(args[0]) {
return meh.GlobalCall(minFunc, args[0]), nil
}
return nil, &common.Error{
Message: "math.least() invalid single argument value",
Location: meh.OffsetLocation(args[0].GetId()),
}
case 2:
err := checkInvalidArgs(meh, "math.least()", args)
if err != nil {
return nil, err
}
return meh.GlobalCall(minFunc, args...), nil
default:
err := checkInvalidArgs(meh, "math.least()", args)
if err != nil {
return nil, err
}
return meh.GlobalCall(minFunc, meh.NewList(args...)), nil
}
}
func mathGreatest(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
if !macroTargetMatchesNamespace(mathNamespace, target) {
return nil, nil
}
switch len(args) {
case 0:
return nil, &common.Error{
Message: "math.greatest() requires at least one argument",
Location: meh.OffsetLocation(target.GetId()),
}
case 1:
if isListLiteralWithValidArgs(args[0]) || isValidArgType(args[0]) {
return meh.GlobalCall(maxFunc, args[0]), nil
}
return nil, &common.Error{
Message: "math.greatest() invalid single argument value",
Location: meh.OffsetLocation(args[0].GetId()),
}
case 2:
err := checkInvalidArgs(meh, "math.greatest()", args)
if err != nil {
return nil, err
}
return meh.GlobalCall(maxFunc, args...), nil
default:
err := checkInvalidArgs(meh, "math.greatest()", args)
if err != nil {
return nil, err
}
return meh.GlobalCall(maxFunc, meh.NewList(args...)), nil
}
}
func identity(val ref.Val) ref.Val {
return val
}
func minPair(first, second ref.Val) ref.Val {
cmp, ok := first.(traits.Comparer)
if !ok {
return types.MaybeNoSuchOverloadErr(first)
}
out := cmp.Compare(second)
if types.IsUnknownOrError(out) {
return maybeSuffixError(out, "math.@min")
}
if out == types.IntOne {
return second
}
return first
}
func minList(numList ref.Val) ref.Val {
l := numList.(traits.Lister)
size := l.Size().(types.Int)
if size == types.IntZero {
return types.NewErr("math.@min(list) argument must not be empty")
}
min := l.Get(types.IntZero)
for i := types.IntOne; i < size; i++ {
min = minPair(min, l.Get(i))
}
switch min.Type() {
case types.IntType, types.DoubleType, types.UintType, types.UnknownType:
return min
default:
return types.NewErr("no such overload: math.@min")
}
}
func maxPair(first, second ref.Val) ref.Val {
cmp, ok := first.(traits.Comparer)
if !ok {
return types.MaybeNoSuchOverloadErr(first)
}
out := cmp.Compare(second)
if types.IsUnknownOrError(out) {
return maybeSuffixError(out, "math.@max")
}
if out == types.IntNegOne {
return second
}
return first
}
func maxList(numList ref.Val) ref.Val {
l := numList.(traits.Lister)
size := l.Size().(types.Int)
if size == types.IntZero {
return types.NewErr("math.@max(list) argument must not be empty")
}
max := l.Get(types.IntZero)
for i := types.IntOne; i < size; i++ {
max = maxPair(max, l.Get(i))
}
switch max.Type() {
case types.IntType, types.DoubleType, types.UintType, types.UnknownType:
return max
default:
return types.NewErr("no such overload: math.@max")
}
}
func checkInvalidArgs(meh cel.MacroExprHelper, funcName string, args []*exprpb.Expr) *common.Error {
for _, arg := range args {
err := checkInvalidArgLiteral(funcName, arg)
if err != nil {
return &common.Error{
Message: err.Error(),
Location: meh.OffsetLocation(arg.GetId()),
}
}
}
return nil
}
func checkInvalidArgLiteral(funcName string, arg *exprpb.Expr) error {
if !isValidArgType(arg) {
return fmt.Errorf("%s simple literal arguments must be numeric", funcName)
}
return nil
}
func isValidArgType(arg *exprpb.Expr) bool {
switch arg.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
c := arg.GetConstExpr()
switch c.GetConstantKind().(type) {
case *exprpb.Constant_DoubleValue, *exprpb.Constant_Int64Value, *exprpb.Constant_Uint64Value:
return true
default:
return false
}
case *exprpb.Expr_ListExpr, *exprpb.Expr_StructExpr:
return false
default:
return true
}
}
func isListLiteralWithValidArgs(arg *exprpb.Expr) bool {
switch arg.GetExprKind().(type) {
case *exprpb.Expr_ListExpr:
list := arg.GetListExpr()
if len(list.GetElements()) == 0 {
return false
}
for _, e := range list.GetElements() {
if !isValidArgType(e) {
return false
}
}
return true
}
return false
}
func maybeSuffixError(val ref.Val, suffix string) ref.Val {
if types.IsError(val) {
msg := val.(*types.Err).String()
if !strings.Contains(msg, suffix) {
return types.NewErr("%s: %s", msg, suffix)
}
}
return val
}

574
vendor/github.com/google/cel-go/ext/native.go generated vendored Normal file
View File

@ -0,0 +1,574 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ext
import (
"fmt"
"reflect"
"strings"
"time"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/pb"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
structpb "google.golang.org/protobuf/types/known/structpb"
)
var (
nativeObjTraitMask = traits.FieldTesterType | traits.IndexerType
jsonValueType = reflect.TypeOf(&structpb.Value{})
jsonStructType = reflect.TypeOf(&structpb.Struct{})
)
// NativeTypes creates a type provider which uses reflect.Type and reflect.Value instances
// to produce type definitions that can be used within CEL.
//
// All struct types in Go are exposed to CEL via their simple package name and struct type name:
//
// ```go
// package identity
//
// type Account struct {
// ID int
// }
//
// ```
//
// The type `identity.Account` would be exported to CEL using the same qualified name, e.g.
// `identity.Account{ID: 1234}` would create a new `Account` instance with the `ID` field
// populated.
//
// Only exported fields are exposed via NativeTypes, and the type-mapping between Go and CEL
// is as follows:
//
// | Go type | CEL type |
// |-------------------------------------|-----------|
// | bool | bool |
// | []byte | bytes |
// | float32, float64 | double |
// | int, int8, int16, int32, int64 | int |
// | string | string |
// | uint, uint8, uint16, uint32, uint64 | uint |
// | time.Duration | duration |
// | time.Time | timestamp |
// | array, slice | list |
// | map | map |
//
// Please note, if you intend to configure support for proto messages in addition to native
// types, you will need to provide the protobuf types before the golang native types. The
// same advice holds if you are using custom type adapters and type providers. The native type
// provider composes over whichever type adapter and provider is configured in the cel.Env at
// the time that it is invoked.
func NativeTypes(refTypes ...any) cel.EnvOption {
return func(env *cel.Env) (*cel.Env, error) {
tp, err := newNativeTypeProvider(env.TypeAdapter(), env.TypeProvider(), refTypes...)
if err != nil {
return nil, err
}
env, err = cel.CustomTypeAdapter(tp)(env)
if err != nil {
return nil, err
}
return cel.CustomTypeProvider(tp)(env)
}
}
func newNativeTypeProvider(adapter ref.TypeAdapter, provider ref.TypeProvider, refTypes ...any) (*nativeTypeProvider, error) {
nativeTypes := make(map[string]*nativeType, len(refTypes))
for _, refType := range refTypes {
switch rt := refType.(type) {
case reflect.Type:
t, err := newNativeType(rt)
if err != nil {
return nil, err
}
nativeTypes[t.TypeName()] = t
case reflect.Value:
t, err := newNativeType(rt.Type())
if err != nil {
return nil, err
}
nativeTypes[t.TypeName()] = t
default:
return nil, fmt.Errorf("unsupported native type: %v (%T) must be reflect.Type or reflect.Value", rt, rt)
}
}
return &nativeTypeProvider{
nativeTypes: nativeTypes,
baseAdapter: adapter,
baseProvider: provider,
}, nil
}
type nativeTypeProvider struct {
nativeTypes map[string]*nativeType
baseAdapter ref.TypeAdapter
baseProvider ref.TypeProvider
}
// EnumValue proxies to the ref.TypeProvider configured at the times the NativeTypes
// option was configured.
func (tp *nativeTypeProvider) EnumValue(enumName string) ref.Val {
return tp.baseProvider.EnumValue(enumName)
}
// FindIdent looks up natives type instances by qualified identifier, and if not found
// proxies to the composed ref.TypeProvider.
func (tp *nativeTypeProvider) FindIdent(typeName string) (ref.Val, bool) {
if t, found := tp.nativeTypes[typeName]; found {
return t, true
}
return tp.baseProvider.FindIdent(typeName)
}
// FindType looks up CEL type-checker type definition by qualified identifier, and if not found
// proxies to the composed ref.TypeProvider.
func (tp *nativeTypeProvider) FindType(typeName string) (*exprpb.Type, bool) {
if _, found := tp.nativeTypes[typeName]; found {
return decls.NewTypeType(decls.NewObjectType(typeName)), true
}
return tp.baseProvider.FindType(typeName)
}
// FindFieldType looks up a native type's field definition, and if the type name is not a native
// type then proxies to the composed ref.TypeProvider
func (tp *nativeTypeProvider) FindFieldType(typeName, fieldName string) (*ref.FieldType, bool) {
t, found := tp.nativeTypes[typeName]
if !found {
return tp.baseProvider.FindFieldType(typeName, fieldName)
}
refField, isDefined := t.hasField(fieldName)
if !found || !isDefined {
return nil, false
}
exprType, ok := convertToExprType(refField.Type)
if !ok {
return nil, false
}
return &ref.FieldType{
Type: exprType,
IsSet: func(obj any) bool {
refVal := reflect.Indirect(reflect.ValueOf(obj))
refField := refVal.FieldByName(fieldName)
return !refField.IsZero()
},
GetFrom: func(obj any) (any, error) {
refVal := reflect.Indirect(reflect.ValueOf(obj))
refField := refVal.FieldByName(fieldName)
return getFieldValue(tp, refField), nil
},
}, true
}
// NewValue implements the ref.TypeProvider interface method.
func (tp *nativeTypeProvider) NewValue(typeName string, fields map[string]ref.Val) ref.Val {
t, found := tp.nativeTypes[typeName]
if !found {
return tp.baseProvider.NewValue(typeName, fields)
}
refPtr := reflect.New(t.refType)
refVal := refPtr.Elem()
for fieldName, val := range fields {
refFieldDef, isDefined := t.hasField(fieldName)
if !isDefined {
return types.NewErr("no such field: %s", fieldName)
}
fieldVal, err := val.ConvertToNative(refFieldDef.Type)
if err != nil {
return types.NewErr(err.Error())
}
refField := refVal.FieldByIndex(refFieldDef.Index)
refFieldVal := reflect.ValueOf(fieldVal)
refField.Set(refFieldVal)
}
return tp.NativeToValue(refPtr.Interface())
}
// NewValue adapts native values to CEL values and will proxy to the composed type adapter
// for non-native types.
func (tp *nativeTypeProvider) NativeToValue(val any) ref.Val {
if val == nil {
return types.NullValue
}
if v, ok := val.(ref.Val); ok {
return v
}
rawVal := reflect.ValueOf(val)
refVal := rawVal
if refVal.Kind() == reflect.Ptr {
refVal = reflect.Indirect(refVal)
}
// This isn't quite right if you're also supporting proto,
// but maybe an acceptable limitation.
switch refVal.Kind() {
case reflect.Array, reflect.Slice:
switch val := val.(type) {
case []byte:
return tp.baseAdapter.NativeToValue(val)
default:
return types.NewDynamicList(tp, val)
}
case reflect.Map:
return types.NewDynamicMap(tp, val)
case reflect.Struct:
switch val := val.(type) {
case proto.Message, *pb.Map, protoreflect.List, protoreflect.Message, protoreflect.Value,
time.Time:
return tp.baseAdapter.NativeToValue(val)
default:
return newNativeObject(tp, val, rawVal)
}
default:
return tp.baseAdapter.NativeToValue(val)
}
}
// convertToExprType converts the Golang reflect.Type to a protobuf exprpb.Type.
func convertToExprType(refType reflect.Type) (*exprpb.Type, bool) {
switch refType.Kind() {
case reflect.Bool:
return decls.Bool, true
case reflect.Float32, reflect.Float64:
return decls.Double, true
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if refType == durationType {
return decls.Duration, true
}
return decls.Int, true
case reflect.String:
return decls.String, true
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return decls.Uint, true
case reflect.Array, reflect.Slice:
refElem := refType.Elem()
if refElem == reflect.TypeOf(byte(0)) {
return decls.Bytes, true
}
elemType, ok := convertToExprType(refElem)
if !ok {
return nil, false
}
return decls.NewListType(elemType), true
case reflect.Map:
keyType, ok := convertToExprType(refType.Key())
if !ok {
return nil, false
}
// Ensure the key type is a int, bool, uint, string
elemType, ok := convertToExprType(refType.Elem())
if !ok {
return nil, false
}
return decls.NewMapType(keyType, elemType), true
case reflect.Struct:
if refType == timestampType {
return decls.Timestamp, true
}
return decls.NewObjectType(
fmt.Sprintf("%s.%s", simplePkgAlias(refType.PkgPath()), refType.Name()),
), true
case reflect.Pointer:
if refType.Implements(pbMsgInterfaceType) {
pbMsg := reflect.New(refType.Elem()).Interface().(protoreflect.ProtoMessage)
return decls.NewObjectType(string(pbMsg.ProtoReflect().Descriptor().FullName())), true
}
return convertToExprType(refType.Elem())
}
return nil, false
}
func newNativeObject(adapter ref.TypeAdapter, val any, refValue reflect.Value) ref.Val {
valType, err := newNativeType(refValue.Type())
if err != nil {
return types.NewErr(err.Error())
}
return &nativeObj{
TypeAdapter: adapter,
val: val,
valType: valType,
refValue: refValue,
}
}
type nativeObj struct {
ref.TypeAdapter
val any
valType *nativeType
refValue reflect.Value
}
// ConvertToNative implements the ref.Val interface method.
//
// CEL does not have a notion of pointers, so whether a field is a pointer or value
// is handled as part of this conversion step.
func (o *nativeObj) ConvertToNative(typeDesc reflect.Type) (any, error) {
if o.refValue.Type() == typeDesc {
return o.val, nil
}
if o.refValue.Kind() == reflect.Pointer && o.refValue.Type().Elem() == typeDesc {
return o.refValue.Elem().Interface(), nil
}
if typeDesc.Kind() == reflect.Pointer && o.refValue.Type() == typeDesc.Elem() {
ptr := reflect.New(typeDesc.Elem())
ptr.Elem().Set(o.refValue)
return ptr.Interface(), nil
}
switch typeDesc {
case jsonValueType:
jsonStruct, err := o.ConvertToNative(jsonStructType)
if err != nil {
return nil, err
}
return structpb.NewStructValue(jsonStruct.(*structpb.Struct)), nil
case jsonStructType:
refVal := reflect.Indirect(o.refValue)
refType := refVal.Type()
fields := make(map[string]*structpb.Value, refVal.NumField())
for i := 0; i < refVal.NumField(); i++ {
fieldType := refType.Field(i)
fieldValue := refVal.Field(i)
if !fieldValue.IsValid() || fieldValue.IsZero() {
continue
}
fieldCELVal := o.NativeToValue(fieldValue.Interface())
fieldJSONVal, err := fieldCELVal.ConvertToNative(jsonValueType)
if err != nil {
return nil, err
}
fields[fieldType.Name] = fieldJSONVal.(*structpb.Value)
}
return &structpb.Struct{Fields: fields}, nil
}
return nil, fmt.Errorf("type conversion error from '%v' to '%v'", o.Type(), typeDesc)
}
// ConvertToType implements the ref.Val interface method.
func (o *nativeObj) ConvertToType(typeVal ref.Type) ref.Val {
switch typeVal {
case types.TypeType:
return o.valType
default:
if typeVal.TypeName() == o.valType.typeName {
return o
}
}
return types.NewErr("type conversion error from '%s' to '%s'", o.Type(), typeVal)
}
// Equal implements the ref.Val interface method.
//
// Note, that in Golang a pointer to a value is not equal to the value it contains.
// In CEL pointers and values to which they point are equal.
func (o *nativeObj) Equal(other ref.Val) ref.Val {
otherNtv, ok := other.(*nativeObj)
if !ok {
return types.False
}
val := o.val
otherVal := otherNtv.val
refVal := o.refValue
otherRefVal := otherNtv.refValue
if refVal.Kind() != otherRefVal.Kind() {
if refVal.Kind() == reflect.Pointer {
val = refVal.Elem().Interface()
} else if otherRefVal.Kind() == reflect.Pointer {
otherVal = otherRefVal.Elem().Interface()
}
}
return types.Bool(reflect.DeepEqual(val, otherVal))
}
// IsZeroValue indicates whether the contained Golang value is a zero value.
//
// Golang largely follows proto3 semantics for zero values.
func (o *nativeObj) IsZeroValue() bool {
return reflect.Indirect(o.refValue).IsZero()
}
// IsSet tests whether a field which is defined is set to a non-default value.
func (o *nativeObj) IsSet(field ref.Val) ref.Val {
refField, refErr := o.getReflectedField(field)
if refErr != nil {
return refErr
}
return types.Bool(!refField.IsZero())
}
// Get returns the value fo a field name.
func (o *nativeObj) Get(field ref.Val) ref.Val {
refField, refErr := o.getReflectedField(field)
if refErr != nil {
return refErr
}
return adaptFieldValue(o, refField)
}
func (o *nativeObj) getReflectedField(field ref.Val) (reflect.Value, ref.Val) {
fieldName, ok := field.(types.String)
if !ok {
return reflect.Value{}, types.MaybeNoSuchOverloadErr(field)
}
fieldNameStr := string(fieldName)
refField, isDefined := o.valType.hasField(fieldNameStr)
if !isDefined {
return reflect.Value{}, types.NewErr("no such field: %s", fieldName)
}
refVal := reflect.Indirect(o.refValue)
return refVal.FieldByIndex(refField.Index), nil
}
// Type implements the ref.Val interface method.
func (o *nativeObj) Type() ref.Type {
return o.valType
}
// Value implements the ref.Val interface method.
func (o *nativeObj) Value() any {
return o.val
}
func newNativeType(rawType reflect.Type) (*nativeType, error) {
refType := rawType
if refType.Kind() == reflect.Pointer {
refType = refType.Elem()
}
if !isValidObjectType(refType) {
return nil, fmt.Errorf("unsupported reflect.Type %v, must be reflect.Struct", rawType)
}
return &nativeType{
typeName: fmt.Sprintf("%s.%s", simplePkgAlias(refType.PkgPath()), refType.Name()),
refType: refType,
}, nil
}
type nativeType struct {
typeName string
refType reflect.Type
}
// ConvertToNative implements ref.Val.ConvertToNative.
func (t *nativeType) ConvertToNative(typeDesc reflect.Type) (any, error) {
return nil, fmt.Errorf("type conversion error for type to '%v'", typeDesc)
}
// ConvertToType implements ref.Val.ConvertToType.
func (t *nativeType) ConvertToType(typeVal ref.Type) ref.Val {
switch typeVal {
case types.TypeType:
return types.TypeType
}
return types.NewErr("type conversion error from '%s' to '%s'", types.TypeType, typeVal)
}
// Equal returns true of both type names are equal to each other.
func (t *nativeType) Equal(other ref.Val) ref.Val {
otherType, ok := other.(ref.Type)
return types.Bool(ok && t.TypeName() == otherType.TypeName())
}
// HasTrait implements the ref.Type interface method.
func (t *nativeType) HasTrait(trait int) bool {
return nativeObjTraitMask&trait == trait
}
// String implements the strings.Stringer interface method.
func (t *nativeType) String() string {
return t.typeName
}
// Type implements the ref.Val interface method.
func (t *nativeType) Type() ref.Type {
return types.TypeType
}
// TypeName implements the ref.Type interface method.
func (t *nativeType) TypeName() string {
return t.typeName
}
// Value implements the ref.Val interface method.
func (t *nativeType) Value() any {
return t.typeName
}
// hasField returns whether a field name has a corresponding Golang reflect.StructField
func (t *nativeType) hasField(fieldName string) (reflect.StructField, bool) {
f, found := t.refType.FieldByName(fieldName)
if !found || !f.IsExported() || !isSupportedType(f.Type) {
return reflect.StructField{}, false
}
return f, true
}
func adaptFieldValue(adapter ref.TypeAdapter, refField reflect.Value) ref.Val {
return adapter.NativeToValue(getFieldValue(adapter, refField))
}
func getFieldValue(adapter ref.TypeAdapter, refField reflect.Value) any {
if refField.IsZero() {
switch refField.Kind() {
case reflect.Array, reflect.Slice:
return types.NewDynamicList(adapter, []ref.Val{})
case reflect.Map:
return types.NewDynamicMap(adapter, map[ref.Val]ref.Val{})
case reflect.Struct:
if refField.Type() == timestampType {
return types.Timestamp{Time: time.Unix(0, 0)}
}
return reflect.New(refField.Type()).Elem().Interface()
case reflect.Pointer:
return reflect.New(refField.Type().Elem()).Interface()
}
}
return refField.Interface()
}
func simplePkgAlias(pkgPath string) string {
paths := strings.Split(pkgPath, "/")
if len(paths) == 0 {
return ""
}
return paths[len(paths)-1]
}
func isValidObjectType(refType reflect.Type) bool {
return refType.Kind() == reflect.Struct
}
func isSupportedType(refType reflect.Type) bool {
switch refType.Kind() {
case reflect.Chan, reflect.Complex64, reflect.Complex128, reflect.Func, reflect.UnsafePointer, reflect.Uintptr:
return false
case reflect.Array, reflect.Slice:
return isSupportedType(refType.Elem())
case reflect.Map:
return isSupportedType(refType.Key()) && isSupportedType(refType.Elem())
}
return true
}
var (
pbMsgInterfaceType = reflect.TypeOf((*protoreflect.ProtoMessage)(nil)).Elem()
timestampType = reflect.TypeOf(time.Now())
durationType = reflect.TypeOf(time.Nanosecond)
)

145
vendor/github.com/google/cel-go/ext/protos.go generated vendored Normal file
View File

@ -0,0 +1,145 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ext
import (
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// Protos returns a cel.EnvOption to configure extended macros and functions for
// proto manipulation.
//
// Note, all macros use the 'proto' namespace; however, at the time of macro
// expansion the namespace looks just like any other identifier. If you are
// currently using a variable named 'proto', the macro will likely work just as
// intended; however, there is some chance for collision.
//
// # Protos.GetExt
//
// Macro which generates a select expression that retrieves an extension field
// from the input proto2 syntax message. If the field is not set, the default
// value forthe extension field is returned according to safe-traversal semantics.
//
// proto.getExt(<msg>, <fully.qualified.extension.name>) -> <field-type>
//
// Examples:
//
// proto.getExt(msg, google.expr.proto2.test.int32_ext) // returns int value
//
// # Protos.HasExt
//
// Macro which generates a test-only select expression that determines whether
// an extension field is set on a proto2 syntax message.
//
// proto.hasExt(<msg>, <fully.qualified.extension.name>) -> <bool>
//
// Examples:
//
// proto.hasExt(msg, google.expr.proto2.test.int32_ext) // returns true || false
func Protos() cel.EnvOption {
return cel.Lib(protoLib{})
}
var (
protoNamespace = "proto"
hasExtension = "hasExt"
getExtension = "getExt"
)
type protoLib struct{}
// LibraryName implements the SingletonLibrary interface method.
func (protoLib) LibraryName() string {
return "cel.lib.ext.protos"
}
// CompileOptions implements the Library interface method.
func (protoLib) CompileOptions() []cel.EnvOption {
return []cel.EnvOption{
cel.Macros(
// proto.getExt(msg, select_expression)
cel.NewReceiverMacro(getExtension, 2, getProtoExt),
// proto.hasExt(msg, select_expression)
cel.NewReceiverMacro(hasExtension, 2, hasProtoExt),
),
}
}
// ProgramOptions implements the Library interface method.
func (protoLib) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
// hasProtoExt generates a test-only select expression for a fully-qualified extension name on a protobuf message.
func hasProtoExt(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
if !macroTargetMatchesNamespace(protoNamespace, target) {
return nil, nil
}
extensionField, err := getExtFieldName(meh, args[1])
if err != nil {
return nil, err
}
return meh.PresenceTest(args[0], extensionField), nil
}
// getProtoExt generates a select expression for a fully-qualified extension name on a protobuf message.
func getProtoExt(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
if !macroTargetMatchesNamespace(protoNamespace, target) {
return nil, nil
}
extFieldName, err := getExtFieldName(meh, args[1])
if err != nil {
return nil, err
}
return meh.Select(args[0], extFieldName), nil
}
func getExtFieldName(meh cel.MacroExprHelper, expr *exprpb.Expr) (string, *common.Error) {
isValid := false
extensionField := ""
switch expr.GetExprKind().(type) {
case *exprpb.Expr_SelectExpr:
extensionField, isValid = validateIdentifier(expr)
}
if !isValid {
return "", &common.Error{
Message: "invalid extension field",
Location: meh.OffsetLocation(expr.GetId()),
}
}
return extensionField, nil
}
func validateIdentifier(expr *exprpb.Expr) (string, bool) {
switch expr.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
return expr.GetIdentExpr().GetName(), true
case *exprpb.Expr_SelectExpr:
sel := expr.GetSelectExpr()
if sel.GetTestOnly() {
return "", false
}
opStr, isIdent := validateIdentifier(sel.GetOperand())
if !isIdent {
return "", false
}
return opStr + "." + sel.GetField(), true
default:
return "", false
}
}

138
vendor/github.com/google/cel-go/ext/sets.go generated vendored Normal file
View File

@ -0,0 +1,138 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ext
import (
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
)
// Sets returns a cel.EnvOption to configure namespaced set relationship
// functions.
//
// There is no set type within CEL, and while one may be introduced in the
// future, there are cases where a `list` type is known to behave like a set.
// For such cases, this library provides some basic functionality for
// determining set containment, equivalence, and intersection.
//
// # Sets.Contains
//
// Returns whether the first list argument contains all elements in the second
// list argument. The list may contain elements of any type and standard CEL
// equality is used to determine whether a value exists in both lists. If the
// second list is empty, the result will always return true.
//
// sets.contains(list(T), list(T)) -> bool
//
// Examples:
//
// sets.contains([], []) // true
// sets.contains([], [1]) // false
// sets.contains([1, 2, 3, 4], [2, 3]) // true
// sets.contains([1, 2.0, 3u], [1.0, 2u, 3]) // true
//
// # Sets.Equivalent
//
// Returns whether the first and second list are set equivalent. Lists are set
// equivalent if for every item in the first list, there is an element in the
// second which is equal. The lists may not be of the same size as they do not
// guarantee the elements within them are unique, so size does not factor into
// the computation.
//
// Examples:
//
// sets.equivalent([], []) // true
// sets.equivalent([1], [1, 1]) // true
// sets.equivalent([1], [1u, 1.0]) // true
// sets.equivalent([1, 2, 3], [3u, 2.0, 1]) // true
//
// # Sets.Intersects
//
// Returns whether the first list has at least one element whose value is equal
// to an element in the second list. If either list is empty, the result will
// be false.
//
// Examples:
//
// sets.intersects([1], []) // false
// sets.intersects([1], [1, 2]) // true
// sets.intersects([[1], [2, 3]], [[1, 2], [2, 3.0]]) // true
func Sets() cel.EnvOption {
return cel.Lib(setsLib{})
}
type setsLib struct{}
// LibraryName implements the SingletonLibrary interface method.
func (setsLib) LibraryName() string {
return "cel.lib.ext.sets"
}
// CompileOptions implements the Library interface method.
func (setsLib) CompileOptions() []cel.EnvOption {
listType := cel.ListType(cel.TypeParamType("T"))
return []cel.EnvOption{
cel.Function("sets.contains",
cel.Overload("list_sets_contains_list", []*cel.Type{listType, listType}, cel.BoolType,
cel.BinaryBinding(setsContains))),
cel.Function("sets.equivalent",
cel.Overload("list_sets_equivalent_list", []*cel.Type{listType, listType}, cel.BoolType,
cel.BinaryBinding(setsEquivalent))),
cel.Function("sets.intersects",
cel.Overload("list_sets_intersects_list", []*cel.Type{listType, listType}, cel.BoolType,
cel.BinaryBinding(setsIntersects))),
}
}
// ProgramOptions implements the Library interface method.
func (setsLib) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
func setsIntersects(listA, listB ref.Val) ref.Val {
lA := listA.(traits.Lister)
lB := listB.(traits.Lister)
it := lA.Iterator()
for it.HasNext() == types.True {
exists := lB.Contains(it.Next())
if exists == types.True {
return types.True
}
}
return types.False
}
func setsContains(list, sublist ref.Val) ref.Val {
l := list.(traits.Lister)
sub := sublist.(traits.Lister)
it := sub.Iterator()
for it.HasNext() == types.True {
exists := l.Contains(it.Next())
if exists != types.True {
return exists
}
}
return types.True
}
func setsEquivalent(listA, listB ref.Val) ref.Val {
aContainsB := setsContains(listA, listB)
if aContainsB != types.True {
return aContainsB
}
return setsContains(listB, listA)
}

View File

@ -19,32 +19,92 @@ package ext
import (
"fmt"
"math"
"reflect"
"sort"
"strings"
"unicode"
"unicode/utf8"
"golang.org/x/text/language"
"golang.org/x/text/message"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/interpreter"
)
const (
defaultLocale = "en-US"
defaultPrecision = 6
)
// Strings returns a cel.EnvOption to configure extended functions for string manipulation.
// As a general note, all indices are zero-based.
//
// CharAt
// # CharAt
//
// Returns the character at the given position. If the position is negative, or greater than
// the length of the string, the function will produce an error:
//
// <string>.charAt(<int>) -> <string>
// <string>.charAt(<int>) -> <string>
//
// Examples:
//
// 'hello'.charAt(4) // return 'o'
// 'hello'.charAt(5) // return ''
// 'hello'.charAt(-1) // error
// 'hello'.charAt(4) // return 'o'
// 'hello'.charAt(5) // return ''
// 'hello'.charAt(-1) // error
//
// IndexOf
// # Format
//
// Introduced at version: 1
//
// Returns a new string with substitutions being performed, printf-style.
// The valid formatting clauses are:
//
// `%s` - substitutes a string. This can also be used on bools, lists, maps, bytes,
// Duration and Timestamp, in addition to all numerical types (int, uint, and double).
// Note that the dot/period decimal separator will always be used when printing a list
// or map that contains a double, and that null can be passed (which results in the
// string "null") in addition to types.
// `%d` - substitutes an integer.
// `%f` - substitutes a double with fixed-point precision. The default precision is 6, but
// this can be adjusted. The strings `Infinity`, `-Infinity`, and `NaN` are also valid input
// for this clause.
// `%e` - substitutes a double in scientific notation. The default precision is 6, but this
// can be adjusted.
// `%b` - substitutes an integer with its equivalent binary string. Can also be used on bools.
// `%x` - substitutes an integer with its equivalent in hexadecimal, or if given a string or
// bytes, will output each character's equivalent in hexadecimal.
// `%X` - same as above, but with A-F capitalized.
// `%o` - substitutes an integer with its equivalent in octal.
//
// <string>.format(<list>) -> <string>
//
// Examples:
//
// "this is a string: %s\nand an integer: %d".format(["str", 42]) // returns "this is a string: str\nand an integer: 42"
// "a double substituted with %%s: %s".format([64.2]) // returns "a double substituted with %s: 64.2"
// "string type: %s".format([type(string)]) // returns "string type: string"
// "timestamp: %s".format([timestamp("2023-02-03T23:31:20+00:00")]) // returns "timestamp: 2023-02-03T23:31:20Z"
// "duration: %s".format([duration("1h45m47s")]) // returns "duration: 6347s"
// "%f".format([3.14]) // returns "3.140000"
// "scientific notation: %e".format([2.71828]) // returns "scientific notation: 2.718280\u202f\u00d7\u202f10\u2070\u2070"
// "5 in binary: %b".format([5]), // returns "5 in binary; 101"
// "26 in hex: %x".format([26]), // returns "26 in hex: 1a"
// "26 in hex (uppercase): %X".format([26]) // returns "26 in hex (uppercase): 1A"
// "30 in octal: %o".format([30]) // returns "30 in octal: 36"
// "a map inside a list: %s".format([[1, 2, 3, {"a": "x", "b": "y", "c": "z"}]]) // returns "a map inside a list: [1, 2, 3, {"a":"x", "b":"y", "c":"d"}]"
// "true bool: %s - false bool: %s\nbinary bool: %b".format([true, false, true]) // returns "true bool: true - false bool: false\nbinary bool: 1"
//
// Passing an incorrect type (an integer to `%s`) is considered an error, as well as attempting
// to use more formatting clauses than there are arguments (`%d %d %d` while passing two ints, for instance).
// If compile-time checking is enabled, and the formatting string is a constant, and the argument list is a literal,
// then letting any arguments go unused/unformatted is also considered an error.
//
// # IndexOf
//
// Returns the integer index of the first occurrence of the search string. If the search string is
// not found the function returns -1.
@ -52,19 +112,19 @@ import (
// The function also accepts an optional position from which to begin the substring search. If the
// substring is the empty string, the index where the search starts is returned (zero or custom).
//
// <string>.indexOf(<string>) -> <int>
// <string>.indexOf(<string>, <int>) -> <int>
// <string>.indexOf(<string>) -> <int>
// <string>.indexOf(<string>, <int>) -> <int>
//
// Examples:
//
// 'hello mellow'.indexOf('') // returns 0
// 'hello mellow'.indexOf('ello') // returns 1
// 'hello mellow'.indexOf('jello') // returns -1
// 'hello mellow'.indexOf('', 2) // returns 2
// 'hello mellow'.indexOf('ello', 2) // returns 7
// 'hello mellow'.indexOf('ello', 20) // error
// 'hello mellow'.indexOf('') // returns 0
// 'hello mellow'.indexOf('ello') // returns 1
// 'hello mellow'.indexOf('jello') // returns -1
// 'hello mellow'.indexOf('', 2) // returns 2
// 'hello mellow'.indexOf('ello', 2) // returns 7
// 'hello mellow'.indexOf('ello', 20) // error
//
// Join
// # Join
//
// Returns a new string where the elements of string list are concatenated.
//
@ -75,12 +135,12 @@ import (
//
// Examples:
//
// ['hello', 'mellow'].join() // returns 'hellomellow'
// ['hello', 'mellow'].join(' ') // returns 'hello mellow'
// [].join() // returns ''
// [].join('/') // returns ''
// ['hello', 'mellow'].join() // returns 'hellomellow'
// ['hello', 'mellow'].join(' ') // returns 'hello mellow'
// [].join() // returns ''
// [].join('/') // returns ''
//
// LastIndexOf
// # LastIndexOf
//
// Returns the integer index at the start of the last occurrence of the search string. If the
// search string is not found the function returns -1.
@ -89,31 +149,45 @@ import (
// considered as the beginning of the substring match. If the substring is the empty string,
// the index where the search starts is returned (string length or custom).
//
// <string>.lastIndexOf(<string>) -> <int>
// <string>.lastIndexOf(<string>, <int>) -> <int>
// <string>.lastIndexOf(<string>) -> <int>
// <string>.lastIndexOf(<string>, <int>) -> <int>
//
// Examples:
//
// 'hello mellow'.lastIndexOf('') // returns 12
// 'hello mellow'.lastIndexOf('ello') // returns 7
// 'hello mellow'.lastIndexOf('jello') // returns -1
// 'hello mellow'.lastIndexOf('ello', 6) // returns 1
// 'hello mellow'.lastIndexOf('ello', -1) // error
// 'hello mellow'.lastIndexOf('') // returns 12
// 'hello mellow'.lastIndexOf('ello') // returns 7
// 'hello mellow'.lastIndexOf('jello') // returns -1
// 'hello mellow'.lastIndexOf('ello', 6) // returns 1
// 'hello mellow'.lastIndexOf('ello', -1) // error
//
// LowerAscii
// # LowerAscii
//
// Returns a new string where all ASCII characters are lower-cased.
//
// This function does not perform Unicode case-mapping for characters outside the ASCII range.
//
// <string>.lowerAscii() -> <string>
// <string>.lowerAscii() -> <string>
//
// Examples:
//
// 'TacoCat'.lowerAscii() // returns 'tacocat'
// 'TacoCÆt Xii'.lowerAscii() // returns 'tacocÆt xii'
// 'TacoCat'.lowerAscii() // returns 'tacocat'
// 'TacoCÆt Xii'.lowerAscii() // returns 'tacocÆt xii'
//
// Replace
// # Quote
//
// Introduced in version: 1
//
// Takes the given string and makes it safe to print (without any formatting due to escape sequences).
// If any invalid UTF-8 characters are encountered, they are replaced with \uFFFD.
//
// strings.quote(<string>)
//
// Examples:
//
// strings.quote('single-quote with "double quote"') // returns '"single-quote with \"double quote\""'
// strings.quote("two escape sequences \a\n") // returns '"two escape sequences \\a\\n"'
//
// # Replace
//
// Returns a new string based on the target, which replaces the occurrences of a search string
// with a replacement string if present. The function accepts an optional limit on the number of
@ -122,17 +196,17 @@ import (
// When the replacement limit is 0, the result is the original string. When the limit is a negative
// number, the function behaves the same as replace all.
//
// <string>.replace(<string>, <string>) -> <string>
// <string>.replace(<string>, <string>, <int>) -> <string>
// <string>.replace(<string>, <string>) -> <string>
// <string>.replace(<string>, <string>, <int>) -> <string>
//
// Examples:
//
// 'hello hello'.replace('he', 'we') // returns 'wello wello'
// 'hello hello'.replace('he', 'we', -1) // returns 'wello wello'
// 'hello hello'.replace('he', 'we', 1) // returns 'wello hello'
// 'hello hello'.replace('he', 'we', 0) // returns 'hello hello'
// 'hello hello'.replace('he', 'we') // returns 'wello wello'
// 'hello hello'.replace('he', 'we', -1) // returns 'wello wello'
// 'hello hello'.replace('he', 'we', 1) // returns 'wello hello'
// 'hello hello'.replace('he', 'we', 0) // returns 'hello hello'
//
// Split
// # Split
//
// Returns a list of strings split from the input by the given separator. The function accepts
// an optional argument specifying a limit on the number of substrings produced by the split.
@ -141,18 +215,18 @@ import (
// target string to split. When the limit is a negative number, the function behaves the same as
// split all.
//
// <string>.split(<string>) -> <list<string>>
// <string>.split(<string>, <int>) -> <list<string>>
// <string>.split(<string>) -> <list<string>>
// <string>.split(<string>, <int>) -> <list<string>>
//
// Examples:
//
// 'hello hello hello'.split(' ') // returns ['hello', 'hello', 'hello']
// 'hello hello hello'.split(' ', 0) // returns []
// 'hello hello hello'.split(' ', 1) // returns ['hello hello hello']
// 'hello hello hello'.split(' ', 2) // returns ['hello', 'hello hello']
// 'hello hello hello'.split(' ', -1) // returns ['hello', 'hello', 'hello']
// 'hello hello hello'.split(' ') // returns ['hello', 'hello', 'hello']
// 'hello hello hello'.split(' ', 0) // returns []
// 'hello hello hello'.split(' ', 1) // returns ['hello hello hello']
// 'hello hello hello'.split(' ', 2) // returns ['hello', 'hello hello']
// 'hello hello hello'.split(' ', -1) // returns ['hello', 'hello', 'hello']
//
// Substring
// # Substring
//
// Returns the substring given a numeric range corresponding to character positions. Optionally
// may omit the trailing range for a substring from a given character position until the end of
@ -162,48 +236,102 @@ import (
// error to specify an end range that is lower than the start range, or for either the start or end
// index to be negative or exceed the string length.
//
// <string>.substring(<int>) -> <string>
// <string>.substring(<int>, <int>) -> <string>
// <string>.substring(<int>) -> <string>
// <string>.substring(<int>, <int>) -> <string>
//
// Examples:
//
// 'tacocat'.substring(4) // returns 'cat'
// 'tacocat'.substring(0, 4) // returns 'taco'
// 'tacocat'.substring(-1) // error
// 'tacocat'.substring(2, 1) // error
// 'tacocat'.substring(4) // returns 'cat'
// 'tacocat'.substring(0, 4) // returns 'taco'
// 'tacocat'.substring(-1) // error
// 'tacocat'.substring(2, 1) // error
//
// Trim
// # Trim
//
// Returns a new string which removes the leading and trailing whitespace in the target string.
// The trim function uses the Unicode definition of whitespace which does not include the
// zero-width spaces. See: https://en.wikipedia.org/wiki/Whitespace_character#Unicode
//
// <string>.trim() -> <string>
// <string>.trim() -> <string>
//
// Examples:
//
// ' \ttrim\n '.trim() // returns 'trim'
// ' \ttrim\n '.trim() // returns 'trim'
//
// UpperAscii
// # UpperAscii
//
// Returns a new string where all ASCII characters are upper-cased.
//
// This function does not perform Unicode case-mapping for characters outside the ASCII range.
//
// <string>.upperAscii() -> <string>
// <string>.upperAscii() -> <string>
//
// Examples:
//
// 'TacoCat'.upperAscii() // returns 'TACOCAT'
// 'TacoCÆt Xii'.upperAscii() // returns 'TACOCÆT XII'
func Strings() cel.EnvOption {
return cel.Lib(stringLib{})
// 'TacoCat'.upperAscii() // returns 'TACOCAT'
// 'TacoCÆt Xii'.upperAscii() // returns 'TACOCÆT XII'
func Strings(options ...StringsOption) cel.EnvOption {
s := &stringLib{version: math.MaxUint32}
for _, o := range options {
s = o(s)
}
return cel.Lib(s)
}
type stringLib struct{}
type stringLib struct {
locale string
version uint32
}
func (stringLib) CompileOptions() []cel.EnvOption {
return []cel.EnvOption{
// LibraryName implements the SingletonLibrary interface method.
func (*stringLib) LibraryName() string {
return "cel.lib.ext.strings"
}
// StringsOption is a functional interface for configuring the strings library.
type StringsOption func(*stringLib) *stringLib
// StringsLocale configures the library with the given locale. The locale tag will
// be checked for validity at the time that EnvOptions are configured. If this option
// is not passed, string.format will behave as if en_US was passed as the locale.
func StringsLocale(locale string) StringsOption {
return func(sl *stringLib) *stringLib {
sl.locale = locale
return sl
}
}
// StringsVersion configures the version of the string library. The version limits which
// functions are available. Only functions introduced below or equal to the given
// version included in the library. See the library documentation to determine
// which version a function was introduced at. If the documentation does not
// state which version a function was introduced at, it can be assumed to be
// introduced at version 0, when the library was first created.
// If this option is not set, all functions are available.
func StringsVersion(version uint32) func(lib *stringLib) *stringLib {
return func(sl *stringLib) *stringLib {
sl.version = version
return sl
}
}
// CompileOptions implements the Library interface method.
func (sl *stringLib) CompileOptions() []cel.EnvOption {
formatLocale := "en_US"
if sl.locale != "" {
// ensure locale is properly-formed if set
_, err := language.Parse(sl.locale)
if err != nil {
return []cel.EnvOption{
func(e *cel.Env) (*cel.Env, error) {
return nil, fmt.Errorf("failed to parse locale: %w", err)
},
}
}
formatLocale = sl.locale
}
opts := []cel.EnvOption{
cel.Function("charAt",
cel.MemberOverload("string_char_at_int", []*cel.Type{cel.StringType, cel.IntType}, cel.StringType,
cel.BinaryBinding(func(str, ind ref.Val) ref.Val {
@ -303,28 +431,64 @@ func (stringLib) CompileOptions() []cel.EnvOption {
s := str.(types.String)
return stringOrError(upperASCII(string(s)))
}))),
cel.Function("join",
cel.MemberOverload("list_join", []*cel.Type{cel.ListType(cel.StringType)}, cel.StringType,
cel.UnaryBinding(func(list ref.Val) ref.Val {
l, err := list.ConvertToNative(stringListType)
if err != nil {
return types.NewErr(err.Error())
}
return stringOrError(join(l.([]string)))
})),
cel.MemberOverload("list_join_string", []*cel.Type{cel.ListType(cel.StringType), cel.StringType}, cel.StringType,
cel.BinaryBinding(func(list, delim ref.Val) ref.Val {
l, err := list.ConvertToNative(stringListType)
if err != nil {
return types.NewErr(err.Error())
}
d := delim.(types.String)
return stringOrError(joinSeparator(l.([]string), string(d)))
}))),
}
if sl.version >= 1 {
opts = append(opts, cel.Function("format",
cel.MemberOverload("string_format", []*cel.Type{cel.StringType, cel.ListType(cel.DynType)}, cel.StringType,
cel.FunctionBinding(func(args ...ref.Val) ref.Val {
s := string(args[0].(types.String))
formatArgs := args[1].(traits.Lister)
return stringOrError(interpreter.ParseFormatString(s, &stringFormatter{}, &stringArgList{formatArgs}, formatLocale))
}))),
cel.Function("strings.quote", cel.Overload("strings_quote", []*cel.Type{cel.StringType}, cel.StringType,
cel.UnaryBinding(func(str ref.Val) ref.Val {
s := str.(types.String)
return stringOrError(quote(string(s)))
}))))
}
if sl.version >= 2 {
opts = append(opts,
cel.Function("join",
cel.MemberOverload("list_join", []*cel.Type{cel.ListType(cel.StringType)}, cel.StringType,
cel.UnaryBinding(func(list ref.Val) ref.Val {
l := list.(traits.Lister)
return stringOrError(joinValSeparator(l, ""))
})),
cel.MemberOverload("list_join_string", []*cel.Type{cel.ListType(cel.StringType), cel.StringType}, cel.StringType,
cel.BinaryBinding(func(list, delim ref.Val) ref.Val {
l := list.(traits.Lister)
d := delim.(types.String)
return stringOrError(joinValSeparator(l, string(d)))
}))),
)
} else {
opts = append(opts,
cel.Function("join",
cel.MemberOverload("list_join", []*cel.Type{cel.ListType(cel.StringType)}, cel.StringType,
cel.UnaryBinding(func(list ref.Val) ref.Val {
l, err := list.ConvertToNative(stringListType)
if err != nil {
return types.NewErr(err.Error())
}
return stringOrError(join(l.([]string)))
})),
cel.MemberOverload("list_join_string", []*cel.Type{cel.ListType(cel.StringType), cel.StringType}, cel.StringType,
cel.BinaryBinding(func(list, delim ref.Val) ref.Val {
l, err := list.ConvertToNative(stringListType)
if err != nil {
return types.NewErr(err.Error())
}
d := delim.(types.String)
return stringOrError(joinSeparator(l.([]string), string(d)))
}))),
)
}
return opts
}
func (stringLib) ProgramOptions() []cel.ProgramOption {
// ProgramOptions implements the Library interface method.
func (*stringLib) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
@ -478,6 +642,452 @@ func join(strs []string) (string, error) {
return strings.Join(strs, ""), nil
}
func joinValSeparator(strs traits.Lister, separator string) (string, error) {
sz := strs.Size().(types.Int)
var sb strings.Builder
for i := types.Int(0); i < sz; i++ {
if i != 0 {
sb.WriteString(separator)
}
elem := strs.Get(i)
str, ok := elem.(types.String)
if !ok {
return "", fmt.Errorf("join: invalid input: %v", elem)
}
sb.WriteString(string(str))
}
return sb.String(), nil
}
type clauseImpl func(ref.Val, string) (string, error)
func clauseForType(argType ref.Type) (clauseImpl, error) {
switch argType {
case types.IntType, types.UintType:
return formatDecimal, nil
case types.StringType, types.BytesType, types.BoolType, types.NullType, types.TypeType:
return FormatString, nil
case types.TimestampType, types.DurationType:
// special case to ensure timestamps/durations get printed as CEL literals
return func(arg ref.Val, locale string) (string, error) {
argStrVal := arg.ConvertToType(types.StringType)
argStr := argStrVal.Value().(string)
if arg.Type() == types.TimestampType {
return fmt.Sprintf("timestamp(%q)", argStr), nil
}
if arg.Type() == types.DurationType {
return fmt.Sprintf("duration(%q)", argStr), nil
}
return "", fmt.Errorf("cannot convert argument of type %s to timestamp/duration", arg.Type().TypeName())
}, nil
case types.ListType:
return formatList, nil
case types.MapType:
return formatMap, nil
case types.DoubleType:
// avoid formatFixed so we can output a period as the decimal separator in order
// to always be a valid CEL literal
return func(arg ref.Val, locale string) (string, error) {
argDouble, ok := arg.Value().(float64)
if !ok {
return "", fmt.Errorf("couldn't convert %s to float64", arg.Type().TypeName())
}
fmtStr := fmt.Sprintf("%%.%df", defaultPrecision)
return fmt.Sprintf(fmtStr, argDouble), nil
}, nil
case types.TypeType:
return func(arg ref.Val, locale string) (string, error) {
return fmt.Sprintf("type(%s)", arg.Value().(string)), nil
}, nil
default:
return nil, fmt.Errorf("no formatting function for %s", argType.TypeName())
}
}
func formatList(arg ref.Val, locale string) (string, error) {
argList := arg.(traits.Lister)
argIterator := argList.Iterator()
var listStrBuilder strings.Builder
_, err := listStrBuilder.WriteRune('[')
if err != nil {
return "", fmt.Errorf("error writing to list string: %w", err)
}
for argIterator.HasNext() == types.True {
member := argIterator.Next()
memberFormat, err := clauseForType(member.Type())
if err != nil {
return "", err
}
unquotedStr, err := memberFormat(member, locale)
if err != nil {
return "", err
}
str := quoteForCEL(member, unquotedStr)
_, err = listStrBuilder.WriteString(str)
if err != nil {
return "", fmt.Errorf("error writing to list string: %w", err)
}
if argIterator.HasNext() == types.True {
_, err = listStrBuilder.WriteString(", ")
if err != nil {
return "", fmt.Errorf("error writing to list string: %w", err)
}
}
}
_, err = listStrBuilder.WriteRune(']')
if err != nil {
return "", fmt.Errorf("error writing to list string: %w", err)
}
return listStrBuilder.String(), nil
}
func formatMap(arg ref.Val, locale string) (string, error) {
argMap := arg.(traits.Mapper)
argIterator := argMap.Iterator()
type mapPair struct {
key string
value string
}
argPairs := make([]mapPair, argMap.Size().Value().(int64))
i := 0
for argIterator.HasNext() == types.True {
key := argIterator.Next()
var keyFormat clauseImpl
switch key.Type() {
case types.StringType, types.BoolType:
keyFormat = FormatString
case types.IntType, types.UintType:
keyFormat = formatDecimal
default:
return "", fmt.Errorf("no formatting function for map key of type %s", key.Type().TypeName())
}
unquotedKeyStr, err := keyFormat(key, locale)
if err != nil {
return "", err
}
keyStr := quoteForCEL(key, unquotedKeyStr)
value, found := argMap.Find(key)
if !found {
return "", fmt.Errorf("could not find key: %q", key)
}
valueFormat, err := clauseForType(value.Type())
if err != nil {
return "", err
}
unquotedValueStr, err := valueFormat(value, locale)
if err != nil {
return "", err
}
valueStr := quoteForCEL(value, unquotedValueStr)
argPairs[i] = mapPair{keyStr, valueStr}
i++
}
sort.SliceStable(argPairs, func(x, y int) bool {
return argPairs[x].key < argPairs[y].key
})
var mapStrBuilder strings.Builder
_, err := mapStrBuilder.WriteRune('{')
if err != nil {
return "", fmt.Errorf("error writing to map string: %w", err)
}
for i, entry := range argPairs {
_, err = mapStrBuilder.WriteString(fmt.Sprintf("%s:%s", entry.key, entry.value))
if err != nil {
return "", fmt.Errorf("error writing to map string: %w", err)
}
if i < len(argPairs)-1 {
_, err = mapStrBuilder.WriteString(", ")
if err != nil {
return "", fmt.Errorf("error writing to map string: %w", err)
}
}
}
_, err = mapStrBuilder.WriteRune('}')
if err != nil {
return "", fmt.Errorf("error writing to map string: %w", err)
}
return mapStrBuilder.String(), nil
}
// quoteForCEL takes a formatted, unquoted value and quotes it in a manner
// suitable for embedding directly in CEL.
func quoteForCEL(refVal ref.Val, unquotedValue string) string {
switch refVal.Type() {
case types.StringType:
return fmt.Sprintf("%q", unquotedValue)
case types.BytesType:
return fmt.Sprintf("b%q", unquotedValue)
case types.DoubleType:
// special case to handle infinity/NaN
num := refVal.Value().(float64)
if math.IsInf(num, 1) || math.IsInf(num, -1) || math.IsNaN(num) {
return fmt.Sprintf("%q", unquotedValue)
}
return unquotedValue
default:
return unquotedValue
}
}
// FormatString returns the string representation of a CEL value.
// It is used to implement the %s specifier in the (string).format() extension
// function.
func FormatString(arg ref.Val, locale string) (string, error) {
switch arg.Type() {
case types.ListType:
return formatList(arg, locale)
case types.MapType:
return formatMap(arg, locale)
case types.IntType, types.UintType, types.DoubleType,
types.BoolType, types.StringType, types.TimestampType, types.BytesType, types.DurationType, types.TypeType:
argStrVal := arg.ConvertToType(types.StringType)
argStr, ok := argStrVal.Value().(string)
if !ok {
return "", fmt.Errorf("could not convert argument %q to string", argStrVal)
}
return argStr, nil
case types.NullType:
return "null", nil
default:
return "", fmt.Errorf("string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps, was given %s", arg.Type().TypeName())
}
}
func formatDecimal(arg ref.Val, locale string) (string, error) {
switch arg.Type() {
case types.IntType:
argInt, ok := arg.ConvertToType(types.IntType).Value().(int64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to int64", arg.Value())
}
return fmt.Sprintf("%d", argInt), nil
case types.UintType:
argInt, ok := arg.ConvertToType(types.UintType).Value().(uint64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to uint64", arg.Value())
}
return fmt.Sprintf("%d", argInt), nil
default:
return "", fmt.Errorf("decimal clause can only be used on integers, was given %s", arg.Type().TypeName())
}
}
func matchLanguage(locale string) (language.Tag, error) {
matcher, err := makeMatcher(locale)
if err != nil {
return language.Und, err
}
tag, _ := language.MatchStrings(matcher, locale)
return tag, nil
}
func makeMatcher(locale string) (language.Matcher, error) {
tags := make([]language.Tag, 0)
tag, err := language.Parse(locale)
if err != nil {
return nil, err
}
tags = append(tags, tag)
return language.NewMatcher(tags), nil
}
// quote implements a string quoting function. The string will be wrapped in
// double quotes, and all valid CEL escape sequences will be escaped to show up
// literally if printed. If the input contains any invalid UTF-8, the invalid runes
// will be replaced with utf8.RuneError.
func quote(s string) (string, error) {
var quotedStrBuilder strings.Builder
for _, c := range sanitize(s) {
switch c {
case '\a':
quotedStrBuilder.WriteString("\\a")
case '\b':
quotedStrBuilder.WriteString("\\b")
case '\f':
quotedStrBuilder.WriteString("\\f")
case '\n':
quotedStrBuilder.WriteString("\\n")
case '\r':
quotedStrBuilder.WriteString("\\r")
case '\t':
quotedStrBuilder.WriteString("\\t")
case '\v':
quotedStrBuilder.WriteString("\\v")
case '\\':
quotedStrBuilder.WriteString("\\\\")
case '"':
quotedStrBuilder.WriteString("\\\"")
default:
quotedStrBuilder.WriteRune(c)
}
}
escapedStr := quotedStrBuilder.String()
return "\"" + escapedStr + "\"", nil
}
// sanitize replaces all invalid runes in the given string with utf8.RuneError.
func sanitize(s string) string {
var sanitizedStringBuilder strings.Builder
for _, r := range s {
if !utf8.ValidRune(r) {
sanitizedStringBuilder.WriteRune(utf8.RuneError)
} else {
sanitizedStringBuilder.WriteRune(r)
}
}
return sanitizedStringBuilder.String()
}
type stringFormatter struct{}
func (c *stringFormatter) String(arg ref.Val, locale string) (string, error) {
return FormatString(arg, locale)
}
func (c *stringFormatter) Decimal(arg ref.Val, locale string) (string, error) {
return formatDecimal(arg, locale)
}
func (c *stringFormatter) Fixed(precision *int) func(ref.Val, string) (string, error) {
if precision == nil {
precision = new(int)
*precision = defaultPrecision
}
return func(arg ref.Val, locale string) (string, error) {
strException := false
if arg.Type() == types.StringType {
argStr := arg.Value().(string)
if argStr == "NaN" || argStr == "Infinity" || argStr == "-Infinity" {
strException = true
}
}
if arg.Type() != types.DoubleType && !strException {
return "", fmt.Errorf("fixed-point clause can only be used on doubles, was given %s", arg.Type().TypeName())
}
argFloatVal := arg.ConvertToType(types.DoubleType)
argFloat, ok := argFloatVal.Value().(float64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to float64", argFloatVal.Value())
}
fmtStr := fmt.Sprintf("%%.%df", *precision)
matchedLocale, err := matchLanguage(locale)
if err != nil {
return "", fmt.Errorf("error matching locale: %w", err)
}
return message.NewPrinter(matchedLocale).Sprintf(fmtStr, argFloat), nil
}
}
func (c *stringFormatter) Scientific(precision *int) func(ref.Val, string) (string, error) {
if precision == nil {
precision = new(int)
*precision = defaultPrecision
}
return func(arg ref.Val, locale string) (string, error) {
strException := false
if arg.Type() == types.StringType {
argStr := arg.Value().(string)
if argStr == "NaN" || argStr == "Infinity" || argStr == "-Infinity" {
strException = true
}
}
if arg.Type() != types.DoubleType && !strException {
return "", fmt.Errorf("scientific clause can only be used on doubles, was given %s", arg.Type().TypeName())
}
argFloatVal := arg.ConvertToType(types.DoubleType)
argFloat, ok := argFloatVal.Value().(float64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to float64", argFloatVal.Value())
}
matchedLocale, err := matchLanguage(locale)
if err != nil {
return "", fmt.Errorf("error matching locale: %w", err)
}
fmtStr := fmt.Sprintf("%%%de", *precision)
return message.NewPrinter(matchedLocale).Sprintf(fmtStr, argFloat), nil
}
}
func (c *stringFormatter) Binary(arg ref.Val, locale string) (string, error) {
switch arg.Type() {
case types.IntType:
argInt := arg.Value().(int64)
// locale is intentionally unused as integers formatted as binary
// strings are locale-independent
return fmt.Sprintf("%b", argInt), nil
case types.UintType:
argInt := arg.Value().(uint64)
return fmt.Sprintf("%b", argInt), nil
case types.BoolType:
argBool := arg.Value().(bool)
if argBool {
return "1", nil
}
return "0", nil
default:
return "", fmt.Errorf("only integers and bools can be formatted as binary, was given %s", arg.Type().TypeName())
}
}
func (c *stringFormatter) Hex(useUpper bool) func(ref.Val, string) (string, error) {
return func(arg ref.Val, locale string) (string, error) {
fmtStr := "%x"
if useUpper {
fmtStr = "%X"
}
switch arg.Type() {
case types.StringType, types.BytesType:
if arg.Type() == types.BytesType {
return fmt.Sprintf(fmtStr, arg.Value().([]byte)), nil
}
return fmt.Sprintf(fmtStr, arg.Value().(string)), nil
case types.IntType:
argInt, ok := arg.Value().(int64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to int64", arg.Value())
}
return fmt.Sprintf(fmtStr, argInt), nil
case types.UintType:
argInt, ok := arg.Value().(uint64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to uint64", arg.Value())
}
return fmt.Sprintf(fmtStr, argInt), nil
default:
return "", fmt.Errorf("only integers, byte buffers, and strings can be formatted as hex, was given %s", arg.Type().TypeName())
}
}
}
func (c *stringFormatter) Octal(arg ref.Val, locale string) (string, error) {
switch arg.Type() {
case types.IntType:
argInt := arg.Value().(int64)
return fmt.Sprintf("%o", argInt), nil
case types.UintType:
argInt := arg.Value().(uint64)
return fmt.Sprintf("%o", argInt), nil
default:
return "", fmt.Errorf("octal clause can only be used on integers, was given %s", arg.Type().TypeName())
}
}
type stringArgList struct {
args traits.Lister
}
func (c *stringArgList) Arg(index int64) (ref.Val, error) {
if index >= c.args.Size().Value().(int64) {
return nil, fmt.Errorf("index %d out of range", index)
}
return c.args.Get(types.Int(index)), nil
}
func (c *stringArgList) ArgSize() int64 {
return c.args.Size().Value().(int64)
}
var (
stringListType = reflect.TypeOf([]string{})
)

View File

@ -11,10 +11,10 @@ go_library(
"activation.go",
"attribute_patterns.go",
"attributes.go",
"coster.go",
"decorators.go",
"dispatcher.go",
"evalstate.go",
"formatting.go",
"interpretable.go",
"interpreter.go",
"optimizations.go",
@ -32,7 +32,7 @@ go_library(
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"//interpreter/functions:go_default_library",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//types/known/durationpb:go_default_library",
"@org_golang_google_protobuf//types/known/structpb:go_default_library",
@ -49,6 +49,7 @@ go_test(
"attributes_test.go",
"interpreter_test.go",
"prune_test.go",
"runtimecost_test.go",
],
embed = [
":go_default_library",
@ -65,7 +66,7 @@ go_test(
"//test:go_default_library",
"//test/proto2pb:go_default_library",
"//test/proto3pb:go_default_library",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//types/known/anypb:go_default_library",
],

View File

@ -28,7 +28,7 @@ import (
type Activation interface {
// ResolveName returns a value from the activation by qualified name, or false if the name
// could not be found.
ResolveName(name string) (interface{}, bool)
ResolveName(name string) (any, bool)
// Parent returns the parent of the current activation, may be nil.
// If non-nil, the parent will be searched during resolve calls.
@ -43,23 +43,23 @@ func EmptyActivation() Activation {
// emptyActivation is a variable-free activation.
type emptyActivation struct{}
func (emptyActivation) ResolveName(string) (interface{}, bool) { return nil, false }
func (emptyActivation) Parent() Activation { return nil }
func (emptyActivation) ResolveName(string) (any, bool) { return nil, false }
func (emptyActivation) Parent() Activation { return nil }
// NewActivation returns an activation based on a map-based binding where the map keys are
// expected to be qualified names used with ResolveName calls.
//
// The input `bindings` may either be of type `Activation` or `map[string]interface{}`.
// The input `bindings` may either be of type `Activation` or `map[string]any`.
//
// Lazy bindings may be supplied within the map-based input in either of the following forms:
// - func() interface{}
// - func() any
// - func() ref.Val
//
// The output of the lazy binding will overwrite the variable reference in the internal map.
//
// Values which are not represented as ref.Val types on input may be adapted to a ref.Val using
// the ref.TypeAdapter configured in the environment.
func NewActivation(bindings interface{}) (Activation, error) {
func NewActivation(bindings any) (Activation, error) {
if bindings == nil {
return nil, errors.New("bindings must be non-nil")
}
@ -67,7 +67,7 @@ func NewActivation(bindings interface{}) (Activation, error) {
if isActivation {
return a, nil
}
m, isMap := bindings.(map[string]interface{})
m, isMap := bindings.(map[string]any)
if !isMap {
return nil, fmt.Errorf(
"activation input must be an activation or map[string]interface: got %T",
@ -81,7 +81,7 @@ func NewActivation(bindings interface{}) (Activation, error) {
// Named bindings may lazily supply values by providing a function which accepts no arguments and
// produces an interface value.
type mapActivation struct {
bindings map[string]interface{}
bindings map[string]any
}
// Parent implements the Activation interface method.
@ -90,7 +90,7 @@ func (a *mapActivation) Parent() Activation {
}
// ResolveName implements the Activation interface method.
func (a *mapActivation) ResolveName(name string) (interface{}, bool) {
func (a *mapActivation) ResolveName(name string) (any, bool) {
obj, found := a.bindings[name]
if !found {
return nil, false
@ -100,7 +100,7 @@ func (a *mapActivation) ResolveName(name string) (interface{}, bool) {
obj = fn()
a.bindings[name] = obj
}
fnRaw, isLazy := obj.(func() interface{})
fnRaw, isLazy := obj.(func() any)
if isLazy {
obj = fnRaw()
a.bindings[name] = obj
@ -121,7 +121,7 @@ func (a *hierarchicalActivation) Parent() Activation {
}
// ResolveName implements the Activation interface method.
func (a *hierarchicalActivation) ResolveName(name string) (interface{}, bool) {
func (a *hierarchicalActivation) ResolveName(name string) (any, bool) {
if object, found := a.child.ResolveName(name); found {
return object, found
}
@ -138,8 +138,8 @@ func NewHierarchicalActivation(parent Activation, child Activation) Activation {
// representing field and index operations that should result in a 'types.Unknown' result.
//
// The `bindings` value may be any value type supported by the interpreter.NewActivation call,
// but is typically either an existing Activation or map[string]interface{}.
func NewPartialActivation(bindings interface{},
// but is typically either an existing Activation or map[string]any.
func NewPartialActivation(bindings any,
unknowns ...*AttributePattern) (PartialActivation, error) {
a, err := NewActivation(bindings)
if err != nil {
@ -184,7 +184,7 @@ func (v *varActivation) Parent() Activation {
}
// ResolveName implements the Activation interface method.
func (v *varActivation) ResolveName(name string) (interface{}, bool) {
func (v *varActivation) ResolveName(name string) (any, bool) {
if name == v.name {
return v.val, true
}
@ -194,7 +194,7 @@ func (v *varActivation) ResolveName(name string) (interface{}, bool) {
var (
// pool of var activations to reduce allocations during folds.
varActivationPool = &sync.Pool{
New: func() interface{} {
New: func() any {
return &varActivation{}
},
}

View File

@ -15,8 +15,6 @@
package interpreter
import (
"fmt"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
@ -36,9 +34,9 @@ import (
//
// Examples:
//
// 1. ns.myvar["complex-value"]
// 2. ns.myvar["complex-value"][0]
// 3. ns.myvar["complex-value"].*.name
// 1. ns.myvar["complex-value"]
// 2. ns.myvar["complex-value"][0]
// 3. ns.myvar["complex-value"].*.name
//
// The first example is simple: match an attribute where the variable is 'ns.myvar' with a
// field access on 'complex-value'. The second example expands the match to indicate that only
@ -108,7 +106,7 @@ func (apat *AttributePattern) QualifierPatterns() []*AttributeQualifierPattern {
// AttributeQualifierPattern holds a wildcard or valued qualifier pattern.
type AttributeQualifierPattern struct {
wildcard bool
value interface{}
value any
}
// Matches returns true if the qualifier pattern is a wildcard, or the Qualifier implements the
@ -134,44 +132,44 @@ func (qpat *AttributeQualifierPattern) Matches(q Qualifier) bool {
type qualifierValueEquator interface {
// QualifierValueEquals returns true if the input value is equal to the value held in the
// Qualifier.
QualifierValueEquals(value interface{}) bool
QualifierValueEquals(value any) bool
}
// QualifierValueEquals implementation for boolean qualifiers.
func (q *boolQualifier) QualifierValueEquals(value interface{}) bool {
func (q *boolQualifier) QualifierValueEquals(value any) bool {
bval, ok := value.(bool)
return ok && q.value == bval
}
// QualifierValueEquals implementation for field qualifiers.
func (q *fieldQualifier) QualifierValueEquals(value interface{}) bool {
func (q *fieldQualifier) QualifierValueEquals(value any) bool {
sval, ok := value.(string)
return ok && q.Name == sval
}
// QualifierValueEquals implementation for string qualifiers.
func (q *stringQualifier) QualifierValueEquals(value interface{}) bool {
func (q *stringQualifier) QualifierValueEquals(value any) bool {
sval, ok := value.(string)
return ok && q.value == sval
}
// QualifierValueEquals implementation for int qualifiers.
func (q *intQualifier) QualifierValueEquals(value interface{}) bool {
func (q *intQualifier) QualifierValueEquals(value any) bool {
return numericValueEquals(value, q.celValue)
}
// QualifierValueEquals implementation for uint qualifiers.
func (q *uintQualifier) QualifierValueEquals(value interface{}) bool {
func (q *uintQualifier) QualifierValueEquals(value any) bool {
return numericValueEquals(value, q.celValue)
}
// QualifierValueEquals implementation for double qualifiers.
func (q *doubleQualifier) QualifierValueEquals(value interface{}) bool {
func (q *doubleQualifier) QualifierValueEquals(value any) bool {
return numericValueEquals(value, q.celValue)
}
// numericValueEquals uses CEL equality to determine whether two number values are
func numericValueEquals(value interface{}, celValue ref.Val) bool {
func numericValueEquals(value any, celValue ref.Val) bool {
val := types.DefaultTypeAdapter.NativeToValue(value)
return celValue.Equal(val) == types.True
}
@ -272,13 +270,9 @@ func (fac *partialAttributeFactory) matchesUnknownPatterns(
if err != nil {
return nil, err
}
unk, isUnk := val.(types.Unknown)
if isUnk {
return unk, nil
}
// If this resolution behavior ever changes, new implementations of the
// qualifierValueEquator may be required to handle proper resolution.
qual, err = fac.NewQualifier(nil, qual.ID(), val)
qual, err = fac.NewQualifier(nil, qual.ID(), val, attr.IsOptional())
if err != nil {
return nil, err
}
@ -338,24 +332,10 @@ func (m *attributeMatcher) AddQualifier(qual Qualifier) (Attribute, error) {
return m, nil
}
// Resolve is an implementation of the Attribute interface method which uses the
// attributeMatcher TryResolve implementation rather than the embedded NamespacedAttribute
// Resolve implementation.
func (m *attributeMatcher) Resolve(vars Activation) (interface{}, error) {
obj, found, err := m.TryResolve(vars)
if err != nil {
return nil, err
}
if !found {
return nil, fmt.Errorf("no such attribute: %v", m.NamespacedAttribute)
}
return obj, nil
}
// TryResolve is an implementation of the NamespacedAttribute interface method which tests
// Resolve is an implementation of the NamespacedAttribute interface method which tests
// for matching unknown attribute patterns and returns types.Unknown if present. Otherwise,
// the standard Resolve logic applies.
func (m *attributeMatcher) TryResolve(vars Activation) (interface{}, bool, error) {
func (m *attributeMatcher) Resolve(vars Activation) (any, error) {
id := m.NamespacedAttribute.ID()
// Bug in how partial activation is resolved, should search parents as well.
partial, isPartial := toPartialActivation(vars)
@ -366,30 +346,23 @@ func (m *attributeMatcher) TryResolve(vars Activation) (interface{}, bool, error
m.CandidateVariableNames(),
m.qualifiers)
if err != nil {
return nil, true, err
return nil, err
}
if unk != nil {
return unk, true, nil
return unk, nil
}
}
return m.NamespacedAttribute.TryResolve(vars)
return m.NamespacedAttribute.Resolve(vars)
}
// Qualify is an implementation of the Qualifier interface method.
func (m *attributeMatcher) Qualify(vars Activation, obj interface{}) (interface{}, error) {
val, err := m.Resolve(vars)
if err != nil {
return nil, err
}
unk, isUnk := val.(types.Unknown)
if isUnk {
return unk, nil
}
qual, err := m.fac.NewQualifier(nil, m.ID(), val)
if err != nil {
return nil, err
}
return qual.Qualify(vars, obj)
func (m *attributeMatcher) Qualify(vars Activation, obj any) (any, error) {
return attrQualify(m.fac, vars, obj, m)
}
// QualifyIfPresent is an implementation of the Qualifier interface method.
func (m *attributeMatcher) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) {
return attrQualifyIfPresent(m.fac, vars, obj, m, presenceOnly)
}
func toPartialActivation(vars Activation) (PartialActivation, bool) {

File diff suppressed because it is too large Load Diff

View File

@ -29,7 +29,7 @@ type InterpretableDecorator func(Interpretable) (Interpretable, error)
func decObserveEval(observer EvalObserver) InterpretableDecorator {
return func(i Interpretable) (Interpretable, error) {
switch inst := i.(type) {
case *evalWatch, *evalWatchAttr, *evalWatchConst:
case *evalWatch, *evalWatchAttr, *evalWatchConst, *evalWatchConstructor:
// these instruction are already watching, return straight-away.
return i, nil
case InterpretableAttribute:
@ -42,6 +42,11 @@ func decObserveEval(observer EvalObserver) InterpretableDecorator {
InterpretableConst: inst,
observer: observer,
}, nil
case InterpretableConstructor:
return &evalWatchConstructor{
constructor: inst,
observer: observer,
}, nil
default:
return &evalWatch{
Interpretable: i,
@ -224,8 +229,8 @@ func maybeOptimizeSetMembership(i Interpretable, inlist InterpretableCall) (Inte
valueSet := make(map[ref.Val]ref.Val)
for it.HasNext() == types.True {
elem := it.Next()
if !types.IsPrimitiveType(elem) {
// Note, non-primitive type are not yet supported.
if !types.IsPrimitiveType(elem) || elem.Type() == types.BytesType {
// Note, non-primitive type are not yet supported, and []byte isn't hashable.
return i, nil
}
valueSet[elem] = types.True

View File

@ -0,0 +1,383 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package interpreter
import (
"errors"
"fmt"
"strconv"
"strings"
"unicode"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)
type typeVerifier func(int64, ...*types.TypeValue) (bool, error)
// InterpolateFormattedString checks the syntax and cardinality of any string.format calls present in the expression and reports
// any errors at compile time.
func InterpolateFormattedString(verifier typeVerifier) InterpretableDecorator {
return func(inter Interpretable) (Interpretable, error) {
call, ok := inter.(InterpretableCall)
if !ok {
return inter, nil
}
if call.OverloadID() != "string_format" {
return inter, nil
}
args := call.Args()
if len(args) != 2 {
return nil, fmt.Errorf("wrong number of arguments to string.format (expected 2, got %d)", len(args))
}
fmtStrInter, ok := args[0].(InterpretableConst)
if !ok {
return inter, nil
}
var fmtArgsInter InterpretableConstructor
fmtArgsInter, ok = args[1].(InterpretableConstructor)
if !ok {
return inter, nil
}
if fmtArgsInter.Type() != types.ListType {
// don't necessarily return an error since the list may be DynType
return inter, nil
}
formatStr := fmtStrInter.Value().Value().(string)
initVals := fmtArgsInter.InitVals()
formatCheck := &formatCheck{
args: initVals,
verifier: verifier,
}
// use a placeholder locale, since locale doesn't affect syntax
_, err := ParseFormatString(formatStr, formatCheck, formatCheck, "en_US")
if err != nil {
return nil, err
}
seenArgs := formatCheck.argsRequested
if len(initVals) > seenArgs {
return nil, fmt.Errorf("too many arguments supplied to string.format (expected %d, got %d)", seenArgs, len(initVals))
}
return inter, nil
}
}
type formatCheck struct {
args []Interpretable
argsRequested int
curArgIndex int64
enableCheckArgTypes bool
verifier typeVerifier
}
func (c *formatCheck) String(arg ref.Val, locale string) (string, error) {
valid, err := verifyString(c.args[c.curArgIndex], c.verifier)
if err != nil {
return "", err
}
if !valid {
return "", errors.New("string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps")
}
return "", nil
}
func (c *formatCheck) Decimal(arg ref.Val, locale string) (string, error) {
id := c.args[c.curArgIndex].ID()
valid, err := c.verifier(id, types.IntType, types.UintType)
if err != nil {
return "", err
}
if !valid {
return "", errors.New("integer clause can only be used on integers")
}
return "", nil
}
func (c *formatCheck) Fixed(precision *int) func(ref.Val, string) (string, error) {
return func(arg ref.Val, locale string) (string, error) {
id := c.args[c.curArgIndex].ID()
// we allow StringType since "NaN", "Infinity", and "-Infinity" are also valid values
valid, err := c.verifier(id, types.DoubleType, types.StringType)
if err != nil {
return "", err
}
if !valid {
return "", errors.New("fixed-point clause can only be used on doubles")
}
return "", nil
}
}
func (c *formatCheck) Scientific(precision *int) func(ref.Val, string) (string, error) {
return func(arg ref.Val, locale string) (string, error) {
id := c.args[c.curArgIndex].ID()
valid, err := c.verifier(id, types.DoubleType, types.StringType)
if err != nil {
return "", err
}
if !valid {
return "", errors.New("scientific clause can only be used on doubles")
}
return "", nil
}
}
func (c *formatCheck) Binary(arg ref.Val, locale string) (string, error) {
id := c.args[c.curArgIndex].ID()
valid, err := c.verifier(id, types.IntType, types.UintType, types.BoolType)
if err != nil {
return "", err
}
if !valid {
return "", errors.New("only integers and bools can be formatted as binary")
}
return "", nil
}
func (c *formatCheck) Hex(useUpper bool) func(ref.Val, string) (string, error) {
return func(arg ref.Val, locale string) (string, error) {
id := c.args[c.curArgIndex].ID()
valid, err := c.verifier(id, types.IntType, types.UintType, types.StringType, types.BytesType)
if err != nil {
return "", err
}
if !valid {
return "", errors.New("only integers, byte buffers, and strings can be formatted as hex")
}
return "", nil
}
}
func (c *formatCheck) Octal(arg ref.Val, locale string) (string, error) {
id := c.args[c.curArgIndex].ID()
valid, err := c.verifier(id, types.IntType, types.UintType)
if err != nil {
return "", err
}
if !valid {
return "", errors.New("octal clause can only be used on integers")
}
return "", nil
}
func (c *formatCheck) Arg(index int64) (ref.Val, error) {
c.argsRequested++
c.curArgIndex = index
// return a dummy value - this is immediately passed to back to us
// through one of the FormatCallback functions, so anything will do
return types.Int(0), nil
}
func (c *formatCheck) ArgSize() int64 {
return int64(len(c.args))
}
func verifyString(sub Interpretable, verifier typeVerifier) (bool, error) {
subVerified, err := verifier(sub.ID(),
types.ListType, types.MapType, types.IntType, types.UintType, types.DoubleType,
types.BoolType, types.StringType, types.TimestampType, types.BytesType, types.DurationType, types.TypeType, types.NullType)
if err != nil {
return false, err
}
if !subVerified {
return false, nil
}
con, ok := sub.(InterpretableConstructor)
if ok {
members := con.InitVals()
for _, m := range members {
// recursively verify if we're dealing with a list/map
verified, err := verifyString(m, verifier)
if err != nil {
return false, err
}
if !verified {
return false, nil
}
}
}
return true, nil
}
// FormatStringInterpolator is an interface that allows user-defined behavior
// for formatting clause implementations, as well as argument retrieval.
// Each function is expected to support the appropriate types as laid out in
// the string.format documentation, and to return an error if given an inappropriate type.
type FormatStringInterpolator interface {
// String takes a ref.Val and a string representing the current locale identifier
// and returns the Val formatted as a string, or an error if one occurred.
String(ref.Val, string) (string, error)
// Decimal takes a ref.Val and a string representing the current locale identifier
// and returns the Val formatted as a decimal integer, or an error if one occurred.
Decimal(ref.Val, string) (string, error)
// Fixed takes an int pointer representing precision (or nil if none was given) and
// returns a function operating in a similar manner to String and Decimal, taking a
// ref.Val and locale and returning the appropriate string. A closure is returned
// so precision can be set without needing an additional function call/configuration.
Fixed(*int) func(ref.Val, string) (string, error)
// Scientific functions identically to Fixed, except the string returned from the closure
// is expected to be in scientific notation.
Scientific(*int) func(ref.Val, string) (string, error)
// Binary takes a ref.Val and a string representing the current locale identifier
// and returns the Val formatted as a binary integer, or an error if one occurred.
Binary(ref.Val, string) (string, error)
// Hex takes a boolean that, if true, indicates the hex string output by the returned
// closure should use uppercase letters for A-F.
Hex(bool) func(ref.Val, string) (string, error)
// Octal takes a ref.Val and a string representing the current locale identifier and
// returns the Val formatted in octal, or an error if one occurred.
Octal(ref.Val, string) (string, error)
}
// FormatList is an interface that allows user-defined list-like datatypes to be used
// for formatting clause implementations.
type FormatList interface {
// Arg returns the ref.Val at the given index, or an error if one occurred.
Arg(int64) (ref.Val, error)
// ArgSize returns the length of the argument list.
ArgSize() int64
}
type clauseImpl func(ref.Val, string) (string, error)
// ParseFormatString formats a string according to the string.format syntax, taking the clause implementations
// from the provided FormatCallback and the args from the given FormatList.
func ParseFormatString(formatStr string, callback FormatStringInterpolator, list FormatList, locale string) (string, error) {
i := 0
argIndex := 0
var builtStr strings.Builder
for i < len(formatStr) {
if formatStr[i] == '%' {
if i+1 < len(formatStr) && formatStr[i+1] == '%' {
err := builtStr.WriteByte('%')
if err != nil {
return "", fmt.Errorf("error writing format string: %w", err)
}
i += 2
continue
} else {
argAny, err := list.Arg(int64(argIndex))
if err != nil {
return "", err
}
if i+1 >= len(formatStr) {
return "", errors.New("unexpected end of string")
}
if int64(argIndex) >= list.ArgSize() {
return "", fmt.Errorf("index %d out of range", argIndex)
}
numRead, val, refErr := parseAndFormatClause(formatStr[i:], argAny, callback, list, locale)
if refErr != nil {
return "", refErr
}
_, err = builtStr.WriteString(val)
if err != nil {
return "", fmt.Errorf("error writing format string: %w", err)
}
i += numRead
argIndex++
}
} else {
err := builtStr.WriteByte(formatStr[i])
if err != nil {
return "", fmt.Errorf("error writing format string: %w", err)
}
i++
}
}
return builtStr.String(), nil
}
// parseAndFormatClause parses the format clause at the start of the given string with val, and returns
// how many characters were consumed and the substituted string form of val, or an error if one occurred.
func parseAndFormatClause(formatStr string, val ref.Val, callback FormatStringInterpolator, list FormatList, locale string) (int, string, error) {
i := 1
read, formatter, err := parseFormattingClause(formatStr[i:], callback)
i += read
if err != nil {
return -1, "", fmt.Errorf("could not parse formatting clause: %s", err)
}
valStr, err := formatter(val, locale)
if err != nil {
return -1, "", fmt.Errorf("error during formatting: %s", err)
}
return i, valStr, nil
}
func parseFormattingClause(formatStr string, callback FormatStringInterpolator) (int, clauseImpl, error) {
i := 0
read, precision, err := parsePrecision(formatStr[i:])
i += read
if err != nil {
return -1, nil, fmt.Errorf("error while parsing precision: %w", err)
}
r := rune(formatStr[i])
i++
switch r {
case 's':
return i, callback.String, nil
case 'd':
return i, callback.Decimal, nil
case 'f':
return i, callback.Fixed(precision), nil
case 'e':
return i, callback.Scientific(precision), nil
case 'b':
return i, callback.Binary, nil
case 'x', 'X':
return i, callback.Hex(unicode.IsUpper(r)), nil
case 'o':
return i, callback.Octal, nil
default:
return -1, nil, fmt.Errorf("unrecognized formatting clause \"%c\"", r)
}
}
func parsePrecision(formatStr string) (int, *int, error) {
i := 0
if formatStr[i] != '.' {
return i, nil, nil
}
i++
var buffer strings.Builder
for {
if i >= len(formatStr) {
return -1, nil, errors.New("could not find end of precision specifier")
}
if !isASCIIDigit(rune(formatStr[i])) {
break
}
buffer.WriteByte(formatStr[i])
i++
}
precision, err := strconv.Atoi(buffer.String())
if err != nil {
return -1, nil, fmt.Errorf("error while converting precision to integer: %w", err)
}
return i, &precision, nil
}
func isASCIIDigit(r rune) bool {
return r <= unicode.MaxASCII && unicode.IsDigit(r)
}

View File

@ -58,5 +58,5 @@ type UnaryOp func(value ref.Val) ref.Val
type BinaryOp func(lhs ref.Val, rhs ref.Val) ref.Val
// FunctionOp is a function with accepts zero or more arguments and produces
// an value (as interface{}) or error as a result.
// a value or error as a result.
type FunctionOp func(values ...ref.Val) ref.Val

View File

@ -15,7 +15,7 @@
package interpreter
import (
"math"
"fmt"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
@ -64,10 +64,18 @@ type InterpretableAttribute interface {
// Qualify replicates the Attribute.Qualify method to permit extension and interception
// of object qualification.
Qualify(vars Activation, obj interface{}) (interface{}, error)
Qualify(vars Activation, obj any) (any, error)
// QualifyIfPresent qualifies the object if the qualifier is declared or defined on the object.
// The 'presenceOnly' flag indicates that the value is not necessary, just a boolean status as
// to whether the qualifier is present.
QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error)
// IsOptional indicates whether the resulting value is an optional type.
IsOptional() bool
// Resolve returns the value of the Attribute given the current Activation.
Resolve(Activation) (interface{}, error)
Resolve(Activation) (any, error)
}
// InterpretableCall interface for inspecting Interpretable instructions related to function calls.
@ -103,10 +111,8 @@ type InterpretableConstructor interface {
// Core Interpretable implementations used during the program planning phase.
type evalTestOnly struct {
id int64
op Interpretable
field types.String
fieldType *ref.FieldType
id int64
InterpretableAttribute
}
// ID implements the Interpretable interface method.
@ -116,44 +122,55 @@ func (test *evalTestOnly) ID() int64 {
// Eval implements the Interpretable interface method.
func (test *evalTestOnly) Eval(ctx Activation) ref.Val {
// Handle field selection on a proto in the most efficient way possible.
if test.fieldType != nil {
opAttr, ok := test.op.(InterpretableAttribute)
if ok {
opVal, err := opAttr.Resolve(ctx)
if err != nil {
return types.NewErr(err.Error())
}
refVal, ok := opVal.(ref.Val)
if ok {
opVal = refVal.Value()
}
if test.fieldType.IsSet(opVal) {
return types.True
}
return types.False
}
val, err := test.Resolve(ctx)
// Return an error if the resolve step fails
if err != nil {
return types.WrapErr(err)
}
obj := test.op.Eval(ctx)
tester, ok := obj.(traits.FieldTester)
if ok {
return tester.IsSet(test.field)
if optVal, isOpt := val.(*types.Optional); isOpt {
return types.Bool(optVal.HasValue())
}
container, ok := obj.(traits.Container)
if ok {
return container.Contains(test.field)
}
return types.ValOrErr(obj, "invalid type for field selection.")
return test.Adapter().NativeToValue(val)
}
// Cost provides the heuristic cost of a `has(field)` macro. The cost has at least 1 for determining
// if the field exists, apart from the cost of accessing the field.
func (test *evalTestOnly) Cost() (min, max int64) {
min, max = estimateCost(test.op)
min++
max++
return
// AddQualifier appends a qualifier that will always and only perform a presence test.
func (test *evalTestOnly) AddQualifier(q Qualifier) (Attribute, error) {
cq, ok := q.(ConstantQualifier)
if !ok {
return nil, fmt.Errorf("test only expressions must have constant qualifiers: %v", q)
}
return test.InterpretableAttribute.AddQualifier(&testOnlyQualifier{ConstantQualifier: cq})
}
type testOnlyQualifier struct {
ConstantQualifier
}
// Qualify determines whether the test-only qualifier is present on the input object.
func (q *testOnlyQualifier) Qualify(vars Activation, obj any) (any, error) {
out, present, err := q.ConstantQualifier.QualifyIfPresent(vars, obj, true)
if err != nil {
return nil, err
}
if unk, isUnk := out.(types.Unknown); isUnk {
return unk, nil
}
if opt, isOpt := out.(types.Optional); isOpt {
return opt.HasValue(), nil
}
return present, nil
}
// QualifyIfPresent returns whether the target field in the test-only expression is present.
func (q *testOnlyQualifier) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) {
// Only ever test for presence.
return q.ConstantQualifier.QualifyIfPresent(vars, obj, true)
}
// QualifierValueEquals determines whether the test-only constant qualifier equals the input value.
func (q *testOnlyQualifier) QualifierValueEquals(value any) bool {
// The input qualifier will always be of type string
return q.ConstantQualifier.Value().Value() == value
}
// NewConstValue creates a new constant valued Interpretable.
@ -179,11 +196,6 @@ func (cons *evalConst) Eval(ctx Activation) ref.Val {
return cons.val
}
// Cost returns zero for a constant valued Interpretable.
func (cons *evalConst) Cost() (min, max int64) {
return 0, 0
}
// Value implements the InterpretableConst interface method.
func (cons *evalConst) Value() ref.Val {
return cons.val
@ -233,12 +245,6 @@ func (or *evalOr) Eval(ctx Activation) ref.Val {
return types.ValOrErr(rVal, "no such overload")
}
// Cost implements the Coster interface method. The minimum possible cost incurs when the left-hand
// side expr is sufficient in determining the evaluation result.
func (or *evalOr) Cost() (min, max int64) {
return calShortCircuitBinaryOpsCost(or.lhs, or.rhs)
}
type evalAnd struct {
id int64
lhs Interpretable
@ -283,18 +289,6 @@ func (and *evalAnd) Eval(ctx Activation) ref.Val {
return types.ValOrErr(rVal, "no such overload")
}
// Cost implements the Coster interface method. The minimum possible cost incurs when the left-hand
// side expr is sufficient in determining the evaluation result.
func (and *evalAnd) Cost() (min, max int64) {
return calShortCircuitBinaryOpsCost(and.lhs, and.rhs)
}
func calShortCircuitBinaryOpsCost(lhs, rhs Interpretable) (min, max int64) {
lMin, lMax := estimateCost(lhs)
_, rMax := estimateCost(rhs)
return lMin, lMax + rMax + 1
}
type evalEq struct {
id int64
lhs Interpretable
@ -319,11 +313,6 @@ func (eq *evalEq) Eval(ctx Activation) ref.Val {
return types.Equal(lVal, rVal)
}
// Cost implements the Coster interface method.
func (eq *evalEq) Cost() (min, max int64) {
return calExhaustiveBinaryOpsCost(eq.lhs, eq.rhs)
}
// Function implements the InterpretableCall interface method.
func (*evalEq) Function() string {
return operators.Equals
@ -363,11 +352,6 @@ func (ne *evalNe) Eval(ctx Activation) ref.Val {
return types.Bool(types.Equal(lVal, rVal) != types.True)
}
// Cost implements the Coster interface method.
func (ne *evalNe) Cost() (min, max int64) {
return calExhaustiveBinaryOpsCost(ne.lhs, ne.rhs)
}
// Function implements the InterpretableCall interface method.
func (*evalNe) Function() string {
return operators.NotEquals
@ -400,11 +384,6 @@ func (zero *evalZeroArity) Eval(ctx Activation) ref.Val {
return zero.impl()
}
// Cost returns 1 representing the heuristic cost of the function.
func (zero *evalZeroArity) Cost() (min, max int64) {
return 1, 1
}
// Function implements the InterpretableCall interface method.
func (zero *evalZeroArity) Function() string {
return zero.function
@ -456,14 +435,6 @@ func (un *evalUnary) Eval(ctx Activation) ref.Val {
return types.NewErr("no such overload: %s", un.function)
}
// Cost implements the Coster interface method.
func (un *evalUnary) Cost() (min, max int64) {
min, max = estimateCost(un.arg)
min++ // add cost for function
max++
return
}
// Function implements the InterpretableCall interface method.
func (un *evalUnary) Function() string {
return un.function
@ -522,11 +493,6 @@ func (bin *evalBinary) Eval(ctx Activation) ref.Val {
return types.NewErr("no such overload: %s", bin.function)
}
// Cost implements the Coster interface method.
func (bin *evalBinary) Cost() (min, max int64) {
return calExhaustiveBinaryOpsCost(bin.lhs, bin.rhs)
}
// Function implements the InterpretableCall interface method.
func (bin *evalBinary) Function() string {
return bin.function
@ -593,14 +559,6 @@ func (fn *evalVarArgs) Eval(ctx Activation) ref.Val {
return types.NewErr("no such overload: %s", fn.function)
}
// Cost implements the Coster interface method.
func (fn *evalVarArgs) Cost() (min, max int64) {
min, max = sumOfCost(fn.args)
min++ // add cost for function
max++
return
}
// Function implements the InterpretableCall interface method.
func (fn *evalVarArgs) Function() string {
return fn.function
@ -617,9 +575,11 @@ func (fn *evalVarArgs) Args() []Interpretable {
}
type evalList struct {
id int64
elems []Interpretable
adapter ref.TypeAdapter
id int64
elems []Interpretable
optionals []bool
hasOptionals bool
adapter ref.TypeAdapter
}
// ID implements the Interpretable interface method.
@ -629,14 +589,24 @@ func (l *evalList) ID() int64 {
// Eval implements the Interpretable interface method.
func (l *evalList) Eval(ctx Activation) ref.Val {
elemVals := make([]ref.Val, len(l.elems))
elemVals := make([]ref.Val, 0, len(l.elems))
// If any argument is unknown or error early terminate.
for i, elem := range l.elems {
elemVal := elem.Eval(ctx)
if types.IsUnknownOrError(elemVal) {
return elemVal
}
elemVals[i] = elemVal
if l.hasOptionals && l.optionals[i] {
optVal, ok := elemVal.(*types.Optional)
if !ok {
return invalidOptionalElementInit(elemVal)
}
if !optVal.HasValue() {
continue
}
elemVal = optVal.GetValue()
}
elemVals = append(elemVals, elemVal)
}
return l.adapter.NativeToValue(elemVals)
}
@ -649,16 +619,13 @@ func (l *evalList) Type() ref.Type {
return types.ListType
}
// Cost implements the Coster interface method.
func (l *evalList) Cost() (min, max int64) {
return sumOfCost(l.elems)
}
type evalMap struct {
id int64
keys []Interpretable
vals []Interpretable
adapter ref.TypeAdapter
id int64
keys []Interpretable
vals []Interpretable
optionals []bool
hasOptionals bool
adapter ref.TypeAdapter
}
// ID implements the Interpretable interface method.
@ -679,6 +646,17 @@ func (m *evalMap) Eval(ctx Activation) ref.Val {
if types.IsUnknownOrError(valVal) {
return valVal
}
if m.hasOptionals && m.optionals[i] {
optVal, ok := valVal.(*types.Optional)
if !ok {
return invalidOptionalEntryInit(keyVal, valVal)
}
if !optVal.HasValue() {
delete(entries, keyVal)
continue
}
valVal = optVal.GetValue()
}
entries[keyVal] = valVal
}
return m.adapter.NativeToValue(entries)
@ -704,19 +682,14 @@ func (m *evalMap) Type() ref.Type {
return types.MapType
}
// Cost implements the Coster interface method.
func (m *evalMap) Cost() (min, max int64) {
kMin, kMax := sumOfCost(m.keys)
vMin, vMax := sumOfCost(m.vals)
return kMin + vMin, kMax + vMax
}
type evalObj struct {
id int64
typeName string
fields []string
vals []Interpretable
provider ref.TypeProvider
id int64
typeName string
fields []string
vals []Interpretable
optionals []bool
hasOptionals bool
provider ref.TypeProvider
}
// ID implements the Interpretable interface method.
@ -733,6 +706,17 @@ func (o *evalObj) Eval(ctx Activation) ref.Val {
if types.IsUnknownOrError(val) {
return val
}
if o.hasOptionals && o.optionals[i] {
optVal, ok := val.(*types.Optional)
if !ok {
return invalidOptionalEntryInit(field, val)
}
if !optVal.HasValue() {
delete(fieldVals, field)
continue
}
val = optVal.GetValue()
}
fieldVals[field] = val
}
return o.provider.NewValue(o.typeName, fieldVals)
@ -746,21 +730,6 @@ func (o *evalObj) Type() ref.Type {
return types.NewObjectTypeValue(o.typeName)
}
// Cost implements the Coster interface method.
func (o *evalObj) Cost() (min, max int64) {
return sumOfCost(o.vals)
}
func sumOfCost(interps []Interpretable) (min, max int64) {
min, max = 0, 0
for _, in := range interps {
minT, maxT := estimateCost(in)
min += minT
max += maxT
}
return
}
type evalFold struct {
id int64
accuVar string
@ -842,38 +811,6 @@ func (fold *evalFold) Eval(ctx Activation) ref.Val {
return res
}
// Cost implements the Coster interface method.
func (fold *evalFold) Cost() (min, max int64) {
// Compute the cost for evaluating iterRange.
iMin, iMax := estimateCost(fold.iterRange)
// Compute the size of iterRange. If the size depends on the input, return the maximum possible
// cost range.
foldRange := fold.iterRange.Eval(EmptyActivation())
if !foldRange.Type().HasTrait(traits.IterableType) {
return 0, math.MaxInt64
}
var rangeCnt int64
it := foldRange.(traits.Iterable).Iterator()
for it.HasNext() == types.True {
it.Next()
rangeCnt++
}
aMin, aMax := estimateCost(fold.accu)
cMin, cMax := estimateCost(fold.cond)
sMin, sMax := estimateCost(fold.step)
rMin, rMax := estimateCost(fold.result)
if fold.exhaustive {
cMin = cMin * rangeCnt
sMin = sMin * rangeCnt
}
// The cond and step costs are multiplied by size(iterRange). The minimum possible cost incurs
// when the evaluation result can be determined by the first iteration.
return iMin + aMin + cMin + sMin + rMin,
iMax + aMax + cMax*rangeCnt + sMax*rangeCnt + rMax
}
// Optional Interpretable implementations that specialize, subsume, or extend the core evaluation
// plan via decorators.
@ -893,17 +830,15 @@ func (e *evalSetMembership) ID() int64 {
// Eval implements the Interpretable interface method.
func (e *evalSetMembership) Eval(ctx Activation) ref.Val {
val := e.arg.Eval(ctx)
if types.IsUnknownOrError(val) {
return val
}
if ret, found := e.valueSet[val]; found {
return ret
}
return types.False
}
// Cost implements the Coster interface method.
func (e *evalSetMembership) Cost() (min, max int64) {
return estimateCost(e.arg)
}
// evalWatch is an Interpretable implementation that wraps the execution of a given
// expression so that it may observe the computed value and send it to an observer.
type evalWatch struct {
@ -918,15 +853,10 @@ func (e *evalWatch) Eval(ctx Activation) ref.Val {
return val
}
// Cost implements the Coster interface method.
func (e *evalWatch) Cost() (min, max int64) {
return estimateCost(e.Interpretable)
}
// evalWatchAttr describes a watcher of an instAttr Interpretable.
// evalWatchAttr describes a watcher of an InterpretableAttribute Interpretable.
//
// Since the watcher may be selected against at a later stage in program planning, the watcher
// must implement the instAttr interface by proxy.
// must implement the InterpretableAttribute interface by proxy.
type evalWatchAttr struct {
InterpretableAttribute
observer EvalObserver
@ -953,11 +883,6 @@ func (e *evalWatchAttr) AddQualifier(q Qualifier) (Attribute, error) {
return e, err
}
// Cost implements the Coster interface method.
func (e *evalWatchAttr) Cost() (min, max int64) {
return estimateCost(e.InterpretableAttribute)
}
// Eval implements the Interpretable interface method.
func (e *evalWatchAttr) Eval(vars Activation) ref.Val {
val := e.InterpretableAttribute.Eval(vars)
@ -973,17 +898,12 @@ type evalWatchConstQual struct {
adapter ref.TypeAdapter
}
// Cost implements the Coster interface method.
func (e *evalWatchConstQual) Cost() (min, max int64) {
return estimateCost(e.ConstantQualifier)
}
// Qualify observes the qualification of a object via a constant boolean, int, string, or uint.
func (e *evalWatchConstQual) Qualify(vars Activation, obj interface{}) (interface{}, error) {
func (e *evalWatchConstQual) Qualify(vars Activation, obj any) (any, error) {
out, err := e.ConstantQualifier.Qualify(vars, obj)
var val ref.Val
if err != nil {
val = types.NewErr(err.Error())
val = types.WrapErr(err)
} else {
val = e.adapter.NativeToValue(out)
}
@ -991,8 +911,25 @@ func (e *evalWatchConstQual) Qualify(vars Activation, obj interface{}) (interfac
return out, err
}
// QualifyIfPresent conditionally qualifies the variable and only records a value if one is present.
func (e *evalWatchConstQual) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) {
out, present, err := e.ConstantQualifier.QualifyIfPresent(vars, obj, presenceOnly)
var val ref.Val
if err != nil {
val = types.WrapErr(err)
} else if out != nil {
val = e.adapter.NativeToValue(out)
} else if presenceOnly {
val = types.Bool(present)
}
if present || presenceOnly {
e.observer(e.ID(), e.ConstantQualifier, val)
}
return out, present, err
}
// QualifierValueEquals tests whether the incoming value is equal to the qualifying constant.
func (e *evalWatchConstQual) QualifierValueEquals(value interface{}) bool {
func (e *evalWatchConstQual) QualifierValueEquals(value any) bool {
qve, ok := e.ConstantQualifier.(qualifierValueEquator)
return ok && qve.QualifierValueEquals(value)
}
@ -1004,17 +941,12 @@ type evalWatchQual struct {
adapter ref.TypeAdapter
}
// Cost implements the Coster interface method.
func (e *evalWatchQual) Cost() (min, max int64) {
return estimateCost(e.Qualifier)
}
// Qualify observes the qualification of a object via a value computed at runtime.
func (e *evalWatchQual) Qualify(vars Activation, obj interface{}) (interface{}, error) {
func (e *evalWatchQual) Qualify(vars Activation, obj any) (any, error) {
out, err := e.Qualifier.Qualify(vars, obj)
var val ref.Val
if err != nil {
val = types.NewErr(err.Error())
val = types.WrapErr(err)
} else {
val = e.adapter.NativeToValue(out)
}
@ -1022,6 +954,23 @@ func (e *evalWatchQual) Qualify(vars Activation, obj interface{}) (interface{},
return out, err
}
// QualifyIfPresent conditionally qualifies the variable and only records a value if one is present.
func (e *evalWatchQual) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) {
out, present, err := e.Qualifier.QualifyIfPresent(vars, obj, presenceOnly)
var val ref.Val
if err != nil {
val = types.WrapErr(err)
} else if out != nil {
val = e.adapter.NativeToValue(out)
} else if presenceOnly {
val = types.Bool(present)
}
if present || presenceOnly {
e.observer(e.ID(), e.Qualifier, val)
}
return out, present, err
}
// evalWatchConst describes a watcher of an instConst Interpretable.
type evalWatchConst struct {
InterpretableConst
@ -1035,11 +984,6 @@ func (e *evalWatchConst) Eval(vars Activation) ref.Val {
return val
}
// Cost implements the Coster interface method.
func (e *evalWatchConst) Cost() (min, max int64) {
return estimateCost(e.InterpretableConst)
}
// evalExhaustiveOr is just like evalOr, but does not short-circuit argument evaluation.
type evalExhaustiveOr struct {
id int64
@ -1078,12 +1022,7 @@ func (or *evalExhaustiveOr) Eval(ctx Activation) ref.Val {
if types.IsError(lVal) {
return lVal
}
return types.ValOrErr(rVal, "no such overload")
}
// Cost implements the Coster interface method.
func (or *evalExhaustiveOr) Cost() (min, max int64) {
return calExhaustiveBinaryOpsCost(or.lhs, or.rhs)
return types.MaybeNoSuchOverloadErr(rVal)
}
// evalExhaustiveAnd is just like evalAnd, but does not short-circuit argument evaluation.
@ -1124,18 +1063,7 @@ func (and *evalExhaustiveAnd) Eval(ctx Activation) ref.Val {
if types.IsError(lVal) {
return lVal
}
return types.ValOrErr(rVal, "no such overload")
}
// Cost implements the Coster interface method.
func (and *evalExhaustiveAnd) Cost() (min, max int64) {
return calExhaustiveBinaryOpsCost(and.lhs, and.rhs)
}
func calExhaustiveBinaryOpsCost(lhs, rhs Interpretable) (min, max int64) {
lMin, lMax := estimateCost(lhs)
rMin, rMax := estimateCost(rhs)
return lMin + rMin + 1, lMax + rMax + 1
return types.MaybeNoSuchOverloadErr(rVal)
}
// evalExhaustiveConditional is like evalConditional, but does not short-circuit argument
@ -1154,77 +1082,114 @@ func (cond *evalExhaustiveConditional) ID() int64 {
// Eval implements the Interpretable interface method.
func (cond *evalExhaustiveConditional) Eval(ctx Activation) ref.Val {
cVal := cond.attr.expr.Eval(ctx)
tVal, err := cond.attr.truthy.Resolve(ctx)
if err != nil {
return types.NewErr(err.Error())
}
fVal, err := cond.attr.falsy.Resolve(ctx)
if err != nil {
return types.NewErr(err.Error())
}
tVal, tErr := cond.attr.truthy.Resolve(ctx)
fVal, fErr := cond.attr.falsy.Resolve(ctx)
cBool, ok := cVal.(types.Bool)
if !ok {
return types.ValOrErr(cVal, "no such overload")
}
if cBool {
if tErr != nil {
return types.WrapErr(tErr)
}
return cond.adapter.NativeToValue(tVal)
}
if fErr != nil {
return types.WrapErr(fErr)
}
return cond.adapter.NativeToValue(fVal)
}
// Cost implements the Coster interface method.
func (cond *evalExhaustiveConditional) Cost() (min, max int64) {
return cond.attr.Cost()
}
// evalAttr evaluates an Attribute value.
type evalAttr struct {
adapter ref.TypeAdapter
attr Attribute
adapter ref.TypeAdapter
attr Attribute
optional bool
}
var _ InterpretableAttribute = &evalAttr{}
// ID of the attribute instruction.
func (a *evalAttr) ID() int64 {
return a.attr.ID()
}
// AddQualifier implements the instAttr interface method.
// AddQualifier implements the InterpretableAttribute interface method.
func (a *evalAttr) AddQualifier(qual Qualifier) (Attribute, error) {
attr, err := a.attr.AddQualifier(qual)
a.attr = attr
return attr, err
}
// Attr implements the instAttr interface method.
// Attr implements the InterpretableAttribute interface method.
func (a *evalAttr) Attr() Attribute {
return a.attr
}
// Adapter implements the instAttr interface method.
// Adapter implements the InterpretableAttribute interface method.
func (a *evalAttr) Adapter() ref.TypeAdapter {
return a.adapter
}
// Cost implements the Coster interface method.
func (a *evalAttr) Cost() (min, max int64) {
return estimateCost(a.attr)
}
// Eval implements the Interpretable interface method.
func (a *evalAttr) Eval(ctx Activation) ref.Val {
v, err := a.attr.Resolve(ctx)
if err != nil {
return types.NewErr(err.Error())
return types.WrapErr(err)
}
return a.adapter.NativeToValue(v)
}
// Qualify proxies to the Attribute's Qualify method.
func (a *evalAttr) Qualify(ctx Activation, obj interface{}) (interface{}, error) {
func (a *evalAttr) Qualify(ctx Activation, obj any) (any, error) {
return a.attr.Qualify(ctx, obj)
}
// QualifyIfPresent proxies to the Attribute's QualifyIfPresent method.
func (a *evalAttr) QualifyIfPresent(ctx Activation, obj any, presenceOnly bool) (any, bool, error) {
return a.attr.QualifyIfPresent(ctx, obj, presenceOnly)
}
func (a *evalAttr) IsOptional() bool {
return a.optional
}
// Resolve proxies to the Attribute's Resolve method.
func (a *evalAttr) Resolve(ctx Activation) (interface{}, error) {
func (a *evalAttr) Resolve(ctx Activation) (any, error) {
return a.attr.Resolve(ctx)
}
type evalWatchConstructor struct {
constructor InterpretableConstructor
observer EvalObserver
}
// InitVals implements the InterpretableConstructor InitVals function.
func (c *evalWatchConstructor) InitVals() []Interpretable {
return c.constructor.InitVals()
}
// Type implements the InterpretableConstructor Type function.
func (c *evalWatchConstructor) Type() ref.Type {
return c.constructor.Type()
}
// ID implements the Interpretable ID function.
func (c *evalWatchConstructor) ID() int64 {
return c.constructor.ID()
}
// Eval implements the Interpretable Eval function.
func (c *evalWatchConstructor) Eval(ctx Activation) ref.Val {
val := c.constructor.Eval(ctx)
c.observer(c.ID(), c.constructor, val)
return val
}
func invalidOptionalEntryInit(field any, value ref.Val) ref.Val {
return types.NewErr("cannot initialize optional entry '%v' from non-optional value %v", field, value)
}
func invalidOptionalElementInit(value ref.Val) ref.Val {
return types.NewErr("cannot initialize optional list element from non-optional value %v", value)
}

View File

@ -29,19 +29,17 @@ import (
type Interpreter interface {
// NewInterpretable creates an Interpretable from a checked expression and an
// optional list of InterpretableDecorator values.
NewInterpretable(checked *exprpb.CheckedExpr,
decorators ...InterpretableDecorator) (Interpretable, error)
NewInterpretable(checked *exprpb.CheckedExpr, decorators ...InterpretableDecorator) (Interpretable, error)
// NewUncheckedInterpretable returns an Interpretable from a parsed expression
// and an optional list of InterpretableDecorator values.
NewUncheckedInterpretable(expr *exprpb.Expr,
decorators ...InterpretableDecorator) (Interpretable, error)
NewUncheckedInterpretable(expr *exprpb.Expr, decorators ...InterpretableDecorator) (Interpretable, error)
}
// EvalObserver is a functional interface that accepts an expression id and an observed value.
// The id identifies the expression that was evaluated, the programStep is the Interpretable or Qualifier that
// was evaluated and value is the result of the evaluation.
type EvalObserver func(id int64, programStep interface{}, value ref.Val)
type EvalObserver func(id int64, programStep any, value ref.Val)
// Observe constructs a decorator that calls all the provided observers in order after evaluating each Interpretable
// or Qualifier during program evaluation.
@ -49,7 +47,7 @@ func Observe(observers ...EvalObserver) InterpretableDecorator {
if len(observers) == 1 {
return decObserveEval(observers[0])
}
observeFn := func(id int64, programStep interface{}, val ref.Val) {
observeFn := func(id int64, programStep any, val ref.Val) {
for _, observer := range observers {
observer(id, programStep, val)
}
@ -96,7 +94,7 @@ func TrackState(state EvalState) InterpretableDecorator {
// This decorator is not thread-safe, and the EvalState must be reset between Eval()
// calls.
func EvalStateObserver(state EvalState) EvalObserver {
return func(id int64, programStep interface{}, val ref.Val) {
return func(id int64, programStep any, val ref.Val) {
state.SetValue(id, val)
}
}

View File

@ -20,7 +20,6 @@ import (
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter/functions"
@ -189,16 +188,7 @@ func (p *planner) planSelect(expr *exprpb.Expr) (Interpretable, error) {
if err != nil {
return nil, err
}
// Determine the field type if this is a proto message type.
var fieldType *ref.FieldType
opType := p.typeMap[sel.GetOperand().GetId()]
if opType.GetMessageType() != "" {
ft, found := p.provider.FindFieldType(opType.GetMessageType(), sel.GetField())
if found && ft.IsSet != nil && ft.GetFrom != nil {
fieldType = ft
}
}
// If the Select was marked TestOnly, this is a presence test.
//
@ -211,37 +201,31 @@ func (p *planner) planSelect(expr *exprpb.Expr) (Interpretable, error) {
// If a string named 'a.b.c' is declared in the environment and referenced within `has(a.b.c)`,
// it is not clear whether has should error or follow the convention defined for structured
// values.
if sel.TestOnly {
// Return the test only eval expression.
return &evalTestOnly{
id: expr.GetId(),
field: types.String(sel.GetField()),
fieldType: fieldType,
op: op,
}, nil
}
// Build a qualifier.
qual, err := p.attrFactory.NewQualifier(
opType, expr.GetId(), sel.GetField())
if err != nil {
return nil, err
}
// Lastly, create a field selection Interpretable.
// Establish the attribute reference.
attr, isAttr := op.(InterpretableAttribute)
if isAttr {
_, err = attr.AddQualifier(qual)
return attr, err
if !isAttr {
attr, err = p.relativeAttr(op.ID(), op, false)
if err != nil {
return nil, err
}
}
relAttr, err := p.relativeAttr(op.ID(), op)
// Build a qualifier for the attribute.
qual, err := p.attrFactory.NewQualifier(opType, expr.GetId(), sel.GetField(), false)
if err != nil {
return nil, err
}
_, err = relAttr.AddQualifier(qual)
if err != nil {
return nil, err
// Modify the attribute to be test-only.
if sel.GetTestOnly() {
attr = &evalTestOnly{
id: expr.GetId(),
InterpretableAttribute: attr,
}
}
return relAttr, nil
// Append the qualifier on the attribute.
_, err = attr.AddQualifier(qual)
return attr, err
}
// planCall creates a callable Interpretable while specializing for common functions and invocation
@ -286,7 +270,9 @@ func (p *planner) planCall(expr *exprpb.Expr) (Interpretable, error) {
case operators.NotEquals:
return p.planCallNotEqual(expr, args)
case operators.Index:
return p.planCallIndex(expr, args)
return p.planCallIndex(expr, args, false)
case operators.OptSelect, operators.OptIndex:
return p.planCallIndex(expr, args, true)
}
// Otherwise, generate Interpretable calls specialized by argument count.
@ -423,8 +409,7 @@ func (p *planner) planCallVarArgs(expr *exprpb.Expr,
}
// planCallEqual generates an equals (==) Interpretable.
func (p *planner) planCallEqual(expr *exprpb.Expr,
args []Interpretable) (Interpretable, error) {
func (p *planner) planCallEqual(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) {
return &evalEq{
id: expr.GetId(),
lhs: args[0],
@ -433,8 +418,7 @@ func (p *planner) planCallEqual(expr *exprpb.Expr,
}
// planCallNotEqual generates a not equals (!=) Interpretable.
func (p *planner) planCallNotEqual(expr *exprpb.Expr,
args []Interpretable) (Interpretable, error) {
func (p *planner) planCallNotEqual(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) {
return &evalNe{
id: expr.GetId(),
lhs: args[0],
@ -443,8 +427,7 @@ func (p *planner) planCallNotEqual(expr *exprpb.Expr,
}
// planCallLogicalAnd generates a logical and (&&) Interpretable.
func (p *planner) planCallLogicalAnd(expr *exprpb.Expr,
args []Interpretable) (Interpretable, error) {
func (p *planner) planCallLogicalAnd(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) {
return &evalAnd{
id: expr.GetId(),
lhs: args[0],
@ -453,8 +436,7 @@ func (p *planner) planCallLogicalAnd(expr *exprpb.Expr,
}
// planCallLogicalOr generates a logical or (||) Interpretable.
func (p *planner) planCallLogicalOr(expr *exprpb.Expr,
args []Interpretable) (Interpretable, error) {
func (p *planner) planCallLogicalOr(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) {
return &evalOr{
id: expr.GetId(),
lhs: args[0],
@ -463,10 +445,8 @@ func (p *planner) planCallLogicalOr(expr *exprpb.Expr,
}
// planCallConditional generates a conditional / ternary (c ? t : f) Interpretable.
func (p *planner) planCallConditional(expr *exprpb.Expr,
args []Interpretable) (Interpretable, error) {
func (p *planner) planCallConditional(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) {
cond := args[0]
t := args[1]
var tAttr Attribute
truthyAttr, isTruthyAttr := t.(InterpretableAttribute)
@ -493,48 +473,54 @@ func (p *planner) planCallConditional(expr *exprpb.Expr,
// planCallIndex either extends an attribute with the argument to the index operation, or creates
// a relative attribute based on the return of a function call or operation.
func (p *planner) planCallIndex(expr *exprpb.Expr,
args []Interpretable) (Interpretable, error) {
func (p *planner) planCallIndex(expr *exprpb.Expr, args []Interpretable, optional bool) (Interpretable, error) {
op := args[0]
ind := args[1]
opAttr, err := p.relativeAttr(op.ID(), op)
if err != nil {
return nil, err
}
opType := p.typeMap[expr.GetCallExpr().GetTarget().GetId()]
indConst, isIndConst := ind.(InterpretableConst)
if isIndConst {
qual, err := p.attrFactory.NewQualifier(
opType, expr.GetId(), indConst.Value())
// Establish the attribute reference.
var err error
attr, isAttr := op.(InterpretableAttribute)
if !isAttr {
attr, err = p.relativeAttr(op.ID(), op, false)
if err != nil {
return nil, err
}
_, err = opAttr.AddQualifier(qual)
return opAttr, err
}
indAttr, isIndAttr := ind.(InterpretableAttribute)
if isIndAttr {
qual, err := p.attrFactory.NewQualifier(
opType, expr.GetId(), indAttr)
if err != nil {
return nil, err
}
_, err = opAttr.AddQualifier(qual)
return opAttr, err
// Construct the qualifier type.
var qual Qualifier
switch ind := ind.(type) {
case InterpretableConst:
qual, err = p.attrFactory.NewQualifier(opType, expr.GetId(), ind.Value(), optional)
case InterpretableAttribute:
qual, err = p.attrFactory.NewQualifier(opType, expr.GetId(), ind, optional)
default:
qual, err = p.relativeAttr(expr.GetId(), ind, optional)
}
indQual, err := p.relativeAttr(expr.GetId(), ind)
if err != nil {
return nil, err
}
_, err = opAttr.AddQualifier(indQual)
return opAttr, err
// Add the qualifier to the attribute
_, err = attr.AddQualifier(qual)
return attr, err
}
// planCreateList generates a list construction Interpretable.
func (p *planner) planCreateList(expr *exprpb.Expr) (Interpretable, error) {
list := expr.GetListExpr()
elems := make([]Interpretable, len(list.GetElements()))
for i, elem := range list.GetElements() {
optionalIndices := list.GetOptionalIndices()
elements := list.GetElements()
optionals := make([]bool, len(elements))
for _, index := range optionalIndices {
if index < 0 || index >= int32(len(elements)) {
return nil, fmt.Errorf("optional index %d out of element bounds [0, %d]", index, len(elements))
}
optionals[index] = true
}
elems := make([]Interpretable, len(elements))
for i, elem := range elements {
elemVal, err := p.Plan(elem)
if err != nil {
return nil, err
@ -542,9 +528,11 @@ func (p *planner) planCreateList(expr *exprpb.Expr) (Interpretable, error) {
elems[i] = elemVal
}
return &evalList{
id: expr.GetId(),
elems: elems,
adapter: p.adapter,
id: expr.GetId(),
elems: elems,
optionals: optionals,
hasOptionals: len(optionals) != 0,
adapter: p.adapter,
}, nil
}
@ -555,6 +543,7 @@ func (p *planner) planCreateStruct(expr *exprpb.Expr) (Interpretable, error) {
return p.planCreateObj(expr)
}
entries := str.GetEntries()
optionals := make([]bool, len(entries))
keys := make([]Interpretable, len(entries))
vals := make([]Interpretable, len(entries))
for i, entry := range entries {
@ -569,23 +558,27 @@ func (p *planner) planCreateStruct(expr *exprpb.Expr) (Interpretable, error) {
return nil, err
}
vals[i] = valVal
optionals[i] = entry.GetOptionalEntry()
}
return &evalMap{
id: expr.GetId(),
keys: keys,
vals: vals,
adapter: p.adapter,
id: expr.GetId(),
keys: keys,
vals: vals,
optionals: optionals,
hasOptionals: len(optionals) != 0,
adapter: p.adapter,
}, nil
}
// planCreateObj generates an object construction Interpretable.
func (p *planner) planCreateObj(expr *exprpb.Expr) (Interpretable, error) {
obj := expr.GetStructExpr()
typeName, defined := p.resolveTypeName(obj.MessageName)
typeName, defined := p.resolveTypeName(obj.GetMessageName())
if !defined {
return nil, fmt.Errorf("unknown type: %s", typeName)
return nil, fmt.Errorf("unknown type: %s", obj.GetMessageName())
}
entries := obj.GetEntries()
optionals := make([]bool, len(entries))
fields := make([]string, len(entries))
vals := make([]Interpretable, len(entries))
for i, entry := range entries {
@ -595,13 +588,16 @@ func (p *planner) planCreateObj(expr *exprpb.Expr) (Interpretable, error) {
return nil, err
}
vals[i] = val
optionals[i] = entry.GetOptionalEntry()
}
return &evalObj{
id: expr.GetId(),
typeName: typeName,
fields: fields,
vals: vals,
provider: p.provider,
id: expr.GetId(),
typeName: typeName,
fields: fields,
vals: vals,
optionals: optionals,
hasOptionals: len(optionals) != 0,
provider: p.provider,
}, nil
}
@ -753,14 +749,18 @@ func (p *planner) resolveFunction(expr *exprpb.Expr) (*exprpb.Expr, string, stri
return target, fnName, ""
}
func (p *planner) relativeAttr(id int64, eval Interpretable) (InterpretableAttribute, error) {
// relativeAttr indicates that the attribute in this case acts as a qualifier and as such needs to
// be observed to ensure that it's evaluation value is properly recorded for state tracking.
func (p *planner) relativeAttr(id int64, eval Interpretable, opt bool) (InterpretableAttribute, error) {
eAttr, ok := eval.(InterpretableAttribute)
if !ok {
eAttr = &evalAttr{
adapter: p.adapter,
attr: p.attrFactory.RelativeAttribute(id, eval),
adapter: p.adapter,
attr: p.attrFactory.RelativeAttribute(id, eval),
optional: opt,
}
}
// This looks like it should either decorate the new evalAttr node, or early return the InterpretableAttribute
decAttr, err := p.decorate(eAttr, nil)
if err != nil {
return nil, err

View File

@ -16,6 +16,7 @@ package interpreter
import (
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
@ -26,6 +27,7 @@ import (
type astPruner struct {
expr *exprpb.Expr
macroCalls map[int64]*exprpb.Expr
state EvalState
nextExprID int64
}
@ -65,13 +67,22 @@ type astPruner struct {
// compiled and constant folded expressions, but is not willing to constant
// fold(and thus cache results of) some external calls, then they can prepare
// the overloads accordingly.
func PruneAst(expr *exprpb.Expr, state EvalState) *exprpb.Expr {
func PruneAst(expr *exprpb.Expr, macroCalls map[int64]*exprpb.Expr, state EvalState) *exprpb.ParsedExpr {
pruneState := NewEvalState()
for _, id := range state.IDs() {
v, _ := state.Value(id)
pruneState.SetValue(id, v)
}
pruner := &astPruner{
expr: expr,
state: state,
nextExprID: 1}
newExpr, _ := pruner.prune(expr)
return newExpr
macroCalls: macroCalls,
state: pruneState,
nextExprID: getMaxID(expr)}
newExpr, _ := pruner.maybePrune(expr)
return &exprpb.ParsedExpr{
Expr: newExpr,
SourceInfo: &exprpb.SourceInfo{MacroCalls: pruner.macroCalls},
}
}
func (p *astPruner) createLiteral(id int64, val *exprpb.Constant) *exprpb.Expr {
@ -84,28 +95,50 @@ func (p *astPruner) createLiteral(id int64, val *exprpb.Constant) *exprpb.Expr {
}
func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, bool) {
switch val.Type() {
case types.BoolType:
switch v := val.(type) {
case types.Bool:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: val.Value().(bool)}}), true
case types.IntType:
&exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: bool(v)}}), true
case types.Bytes:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: val.Value().(int64)}}), true
case types.UintType:
&exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: []byte(v)}}), true
case types.Double:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: val.Value().(uint64)}}), true
case types.StringType:
&exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: float64(v)}}), true
case types.Duration:
p.state.SetValue(id, val)
durationString := string(v.ConvertToType(types.StringType).(types.String))
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_CallExpr{
CallExpr: &exprpb.Expr_Call{
Function: overloads.TypeConvertDuration,
Args: []*exprpb.Expr{
p.createLiteral(p.nextID(),
&exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: durationString}}),
},
},
},
}, true
case types.Int:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: val.Value().(string)}}), true
case types.DoubleType:
&exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: int64(v)}}), true
case types.Uint:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: val.Value().(float64)}}), true
case types.BytesType:
&exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: uint64(v)}}), true
case types.String:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: val.Value().([]byte)}}), true
case types.NullType:
&exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: string(v)}}), true
case types.Null:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: val.Value().(structpb.NullValue)}}), true
&exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: v.Value().(structpb.NullValue)}}), true
}
// Attempt to build a list literal.
@ -123,6 +156,7 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo
}
elemExprs[i] = elemExpr
}
p.state.SetValue(id, val)
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_ListExpr{
@ -162,6 +196,7 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo
entries[i] = entry
i++
}
p.state.SetValue(id, val)
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_StructExpr{
@ -177,70 +212,147 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo
return nil, false
}
func (p *astPruner) maybePruneAndOr(node *exprpb.Expr) (*exprpb.Expr, bool) {
if !p.existsWithUnknownValue(node.GetId()) {
func (p *astPruner) maybePruneOptional(elem *exprpb.Expr) (*exprpb.Expr, bool) {
elemVal, found := p.value(elem.GetId())
if found && elemVal.Type() == types.OptionalType {
opt := elemVal.(*types.Optional)
if !opt.HasValue() {
return nil, true
}
if newElem, pruned := p.maybeCreateLiteral(elem.GetId(), opt.GetValue()); pruned {
return newElem, true
}
}
return elem, false
}
func (p *astPruner) maybePruneIn(node *exprpb.Expr) (*exprpb.Expr, bool) {
// elem in list
call := node.GetCallExpr()
val, exists := p.maybeValue(call.GetArgs()[1].GetId())
if !exists {
return nil, false
}
if sz, ok := val.(traits.Sizer); ok && sz.Size() == types.IntZero {
return p.maybeCreateLiteral(node.GetId(), types.False)
}
return nil, false
}
func (p *astPruner) maybePruneLogicalNot(node *exprpb.Expr) (*exprpb.Expr, bool) {
call := node.GetCallExpr()
arg := call.GetArgs()[0]
val, exists := p.maybeValue(arg.GetId())
if !exists {
return nil, false
}
if b, ok := val.(types.Bool); ok {
return p.maybeCreateLiteral(node.GetId(), !b)
}
return nil, false
}
func (p *astPruner) maybePruneOr(node *exprpb.Expr) (*exprpb.Expr, bool) {
call := node.GetCallExpr()
// We know result is unknown, so we have at least one unknown arg
// and if one side is a known value, we know we can ignore it.
if p.existsWithKnownValue(call.Args[0].GetId()) {
return call.Args[1], true
if v, exists := p.maybeValue(call.GetArgs()[0].GetId()); exists {
if v == types.True {
return p.maybeCreateLiteral(node.GetId(), types.True)
}
return call.GetArgs()[1], true
}
if p.existsWithKnownValue(call.Args[1].GetId()) {
return call.Args[0], true
if v, exists := p.maybeValue(call.GetArgs()[1].GetId()); exists {
if v == types.True {
return p.maybeCreateLiteral(node.GetId(), types.True)
}
return call.GetArgs()[0], true
}
return nil, false
}
func (p *astPruner) maybePruneAnd(node *exprpb.Expr) (*exprpb.Expr, bool) {
call := node.GetCallExpr()
// We know result is unknown, so we have at least one unknown arg
// and if one side is a known value, we know we can ignore it.
if v, exists := p.maybeValue(call.GetArgs()[0].GetId()); exists {
if v == types.False {
return p.maybeCreateLiteral(node.GetId(), types.False)
}
return call.GetArgs()[1], true
}
if v, exists := p.maybeValue(call.GetArgs()[1].GetId()); exists {
if v == types.False {
return p.maybeCreateLiteral(node.GetId(), types.False)
}
return call.GetArgs()[0], true
}
return nil, false
}
func (p *astPruner) maybePruneConditional(node *exprpb.Expr) (*exprpb.Expr, bool) {
if !p.existsWithUnknownValue(node.GetId()) {
return nil, false
}
call := node.GetCallExpr()
condVal, condValueExists := p.value(call.Args[0].GetId())
if !condValueExists || types.IsUnknownOrError(condVal) {
cond, exists := p.maybeValue(call.GetArgs()[0].GetId())
if !exists {
return nil, false
}
if condVal.Value().(bool) {
return call.Args[1], true
if cond.Value().(bool) {
return call.GetArgs()[1], true
}
return call.Args[2], true
return call.GetArgs()[2], true
}
func (p *astPruner) maybePruneFunction(node *exprpb.Expr) (*exprpb.Expr, bool) {
if _, exists := p.value(node.GetId()); !exists {
return nil, false
}
call := node.GetCallExpr()
if call.Function == operators.LogicalOr || call.Function == operators.LogicalAnd {
return p.maybePruneAndOr(node)
if call.Function == operators.LogicalOr {
return p.maybePruneOr(node)
}
if call.Function == operators.LogicalAnd {
return p.maybePruneAnd(node)
}
if call.Function == operators.Conditional {
return p.maybePruneConditional(node)
}
if call.Function == operators.In {
return p.maybePruneIn(node)
}
if call.Function == operators.LogicalNot {
return p.maybePruneLogicalNot(node)
}
return nil, false
}
func (p *astPruner) maybePrune(node *exprpb.Expr) (*exprpb.Expr, bool) {
return p.prune(node)
}
func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
if node == nil {
return node, false
}
val, valueExists := p.value(node.GetId())
if valueExists && !types.IsUnknownOrError(val) {
val, valueExists := p.maybeValue(node.GetId())
if valueExists {
if newNode, ok := p.maybeCreateLiteral(node.GetId(), val); ok {
delete(p.macroCalls, node.GetId())
return newNode, true
}
}
if macro, found := p.macroCalls[node.GetId()]; found {
// prune the expression in terms of the macro call instead of the expanded form.
if newMacro, pruned := p.prune(macro); pruned {
p.macroCalls[node.GetId()] = newMacro
}
}
// We have either an unknown/error value, or something we don't want to
// transform, or expression was not evaluated. If possible, drill down
// more.
switch node.GetExprKind().(type) {
case *exprpb.Expr_SelectExpr:
if operand, pruned := p.prune(node.GetSelectExpr().GetOperand()); pruned {
if operand, pruned := p.maybePrune(node.GetSelectExpr().GetOperand()); pruned {
return &exprpb.Expr{
Id: node.GetId(),
ExprKind: &exprpb.Expr_SelectExpr{
@ -253,10 +365,6 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
}, true
}
case *exprpb.Expr_CallExpr:
if newExpr, pruned := p.maybePruneFunction(node); pruned {
newExpr, _ = p.prune(newExpr)
return newExpr, true
}
var prunedCall bool
call := node.GetCallExpr()
args := call.GetArgs()
@ -268,40 +376,75 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
}
for i, arg := range args {
newArgs[i] = arg
if newArg, prunedArg := p.prune(arg); prunedArg {
if newArg, prunedArg := p.maybePrune(arg); prunedArg {
prunedCall = true
newArgs[i] = newArg
}
}
if newTarget, prunedTarget := p.prune(call.GetTarget()); prunedTarget {
if newTarget, prunedTarget := p.maybePrune(call.GetTarget()); prunedTarget {
prunedCall = true
newCall.Target = newTarget
}
newNode := &exprpb.Expr{
Id: node.GetId(),
ExprKind: &exprpb.Expr_CallExpr{
CallExpr: newCall,
},
}
if newExpr, pruned := p.maybePruneFunction(newNode); pruned {
newExpr, _ = p.maybePrune(newExpr)
return newExpr, true
}
if prunedCall {
return &exprpb.Expr{
Id: node.GetId(),
ExprKind: &exprpb.Expr_CallExpr{
CallExpr: newCall,
},
}, true
return newNode, true
}
case *exprpb.Expr_ListExpr:
elems := node.GetListExpr().GetElements()
newElems := make([]*exprpb.Expr, len(elems))
optIndices := node.GetListExpr().GetOptionalIndices()
optIndexMap := map[int32]bool{}
for _, i := range optIndices {
optIndexMap[i] = true
}
newOptIndexMap := make(map[int32]bool, len(optIndexMap))
newElems := make([]*exprpb.Expr, 0, len(elems))
var prunedList bool
prunedIdx := 0
for i, elem := range elems {
newElems[i] = elem
if newElem, prunedElem := p.prune(elem); prunedElem {
newElems[i] = newElem
prunedList = true
_, isOpt := optIndexMap[int32(i)]
if isOpt {
newElem, pruned := p.maybePruneOptional(elem)
if pruned {
prunedList = true
if newElem != nil {
newElems = append(newElems, newElem)
prunedIdx++
}
continue
}
newOptIndexMap[int32(prunedIdx)] = true
}
if newElem, prunedElem := p.maybePrune(elem); prunedElem {
newElems = append(newElems, newElem)
prunedList = true
} else {
newElems = append(newElems, elem)
}
prunedIdx++
}
optIndices = make([]int32, len(newOptIndexMap))
idx := 0
for i := range newOptIndexMap {
optIndices[idx] = i
idx++
}
if prunedList {
return &exprpb.Expr{
Id: node.GetId(),
ExprKind: &exprpb.Expr_ListExpr{
ListExpr: &exprpb.Expr_CreateList{
Elements: newElems,
Elements: newElems,
OptionalIndices: optIndices,
},
},
}, true
@ -313,8 +456,8 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
newEntries := make([]*exprpb.Expr_CreateStruct_Entry, len(entries))
for i, entry := range entries {
newEntries[i] = entry
newKey, prunedKey := p.prune(entry.GetMapKey())
newValue, prunedValue := p.prune(entry.GetValue())
newKey, prunedKey := p.maybePrune(entry.GetMapKey())
newValue, prunedValue := p.maybePrune(entry.GetValue())
if !prunedKey && !prunedValue {
continue
}
@ -331,6 +474,7 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
MapKey: newKey,
}
}
newEntry.OptionalEntry = entry.GetOptionalEntry()
newEntries[i] = newEntry
}
if prunedStruct {
@ -344,27 +488,6 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
},
}, true
}
case *exprpb.Expr_ComprehensionExpr:
compre := node.GetComprehensionExpr()
// Only the range of the comprehension is pruned since the state tracking only records
// the last iteration of the comprehension and not each step in the evaluation which
// means that the any residuals computed in between might be inaccurate.
if newRange, pruned := p.prune(compre.GetIterRange()); pruned {
return &exprpb.Expr{
Id: node.GetId(),
ExprKind: &exprpb.Expr_ComprehensionExpr{
ComprehensionExpr: &exprpb.Expr_Comprehension{
IterVar: compre.GetIterVar(),
IterRange: newRange,
AccuVar: compre.GetAccuVar(),
AccuInit: compre.GetAccuInit(),
LoopCondition: compre.GetLoopCondition(),
LoopStep: compre.GetLoopStep(),
Result: compre.GetResult(),
},
},
}, true
}
}
return node, false
}
@ -374,24 +497,82 @@ func (p *astPruner) value(id int64) (ref.Val, bool) {
return val, (found && val != nil)
}
func (p *astPruner) existsWithUnknownValue(id int64) bool {
val, valueExists := p.value(id)
return valueExists && types.IsUnknown(val)
}
func (p *astPruner) existsWithKnownValue(id int64) bool {
val, valueExists := p.value(id)
return valueExists && !types.IsUnknown(val)
func (p *astPruner) maybeValue(id int64) (ref.Val, bool) {
val, found := p.value(id)
if !found || types.IsUnknownOrError(val) {
return nil, false
}
return val, true
}
func (p *astPruner) nextID() int64 {
for {
_, found := p.state.Value(p.nextExprID)
if !found {
next := p.nextExprID
p.nextExprID++
return next
}
p.nextExprID++
next := p.nextExprID
p.nextExprID++
return next
}
type astVisitor struct {
// visitEntry is called on every expr node, including those within a map/struct entry.
visitExpr func(expr *exprpb.Expr)
// visitEntry is called before entering the key, value of a map/struct entry.
visitEntry func(entry *exprpb.Expr_CreateStruct_Entry)
}
func getMaxID(expr *exprpb.Expr) int64 {
maxID := int64(1)
visit(expr, maxIDVisitor(&maxID))
return maxID
}
func maxIDVisitor(maxID *int64) astVisitor {
return astVisitor{
visitExpr: func(e *exprpb.Expr) {
if e.GetId() >= *maxID {
*maxID = e.GetId() + 1
}
},
visitEntry: func(e *exprpb.Expr_CreateStruct_Entry) {
if e.GetId() >= *maxID {
*maxID = e.GetId() + 1
}
},
}
}
func visit(expr *exprpb.Expr, visitor astVisitor) {
exprs := []*exprpb.Expr{expr}
for len(exprs) != 0 {
e := exprs[0]
visitor.visitExpr(e)
exprs = exprs[1:]
switch e.GetExprKind().(type) {
case *exprpb.Expr_SelectExpr:
exprs = append(exprs, e.GetSelectExpr().GetOperand())
case *exprpb.Expr_CallExpr:
call := e.GetCallExpr()
if call.GetTarget() != nil {
exprs = append(exprs, call.GetTarget())
}
exprs = append(exprs, call.GetArgs()...)
case *exprpb.Expr_ComprehensionExpr:
compre := e.GetComprehensionExpr()
exprs = append(exprs,
compre.GetIterRange(),
compre.GetAccuInit(),
compre.GetLoopCondition(),
compre.GetLoopStep(),
compre.GetResult())
case *exprpb.Expr_ListExpr:
list := e.GetListExpr()
exprs = append(exprs, list.GetElements()...)
case *exprpb.Expr_StructExpr:
for _, entry := range e.GetStructExpr().GetEntries() {
visitor.visitEntry(entry)
if entry.GetMapKey() != nil {
exprs = append(exprs, entry.GetMapKey())
}
exprs = append(exprs, entry.GetValue())
}
}
}
}

View File

@ -36,7 +36,7 @@ type ActualCostEstimator interface {
// CostObserver provides an observer that tracks runtime cost.
func CostObserver(tracker *CostTracker) EvalObserver {
observer := func(id int64, programStep interface{}, val ref.Val) {
observer := func(id int64, programStep any, val ref.Val) {
switch t := programStep.(type) {
case ConstantQualifier:
// TODO: Push identifiers on to the stack before observing constant qualifiers that apply to them
@ -53,6 +53,11 @@ func CostObserver(tracker *CostTracker) EvalObserver {
tracker.stack.drop(t.Attr().ID())
tracker.cost += common.SelectAndIdentCost
}
if !tracker.presenceTestHasCost {
if _, isTestOnly := programStep.(*evalTestOnly); isTestOnly {
tracker.cost -= common.SelectAndIdentCost
}
}
case *evalExhaustiveConditional:
// Ternary has no direct cost. All cost is from the conditional and the true/false branch expressions.
tracker.stack.drop(t.attr.falsy.ID(), t.attr.truthy.ID(), t.attr.expr.ID())
@ -95,21 +100,58 @@ func CostObserver(tracker *CostTracker) EvalObserver {
return observer
}
// CostTracker represents the information needed for tacking runtime cost
// CostTrackerOption configures the behavior of CostTracker objects.
type CostTrackerOption func(*CostTracker) error
// CostTrackerLimit sets the runtime limit on the evaluation cost during execution and will terminate the expression
// evaluation if the limit is exceeded.
func CostTrackerLimit(limit uint64) CostTrackerOption {
return func(tracker *CostTracker) error {
tracker.Limit = &limit
return nil
}
}
// PresenceTestHasCost determines whether presence testing has a cost of one or zero.
// Defaults to presence test has a cost of one.
func PresenceTestHasCost(hasCost bool) CostTrackerOption {
return func(tracker *CostTracker) error {
tracker.presenceTestHasCost = hasCost
return nil
}
}
// NewCostTracker creates a new CostTracker with a given estimator and a set of functional CostTrackerOption values.
func NewCostTracker(estimator ActualCostEstimator, opts ...CostTrackerOption) (*CostTracker, error) {
tracker := &CostTracker{
Estimator: estimator,
presenceTestHasCost: true,
}
for _, opt := range opts {
err := opt(tracker)
if err != nil {
return nil, err
}
}
return tracker, nil
}
// CostTracker represents the information needed for tracking runtime cost.
type CostTracker struct {
Estimator ActualCostEstimator
Limit *uint64
Estimator ActualCostEstimator
Limit *uint64
presenceTestHasCost bool
cost uint64
stack refValStack
}
// ActualCost returns the runtime cost
func (c CostTracker) ActualCost() uint64 {
func (c *CostTracker) ActualCost() uint64 {
return c.cost
}
func (c CostTracker) costCall(call InterpretableCall, argValues []ref.Val, result ref.Val) uint64 {
func (c *CostTracker) costCall(call InterpretableCall, argValues []ref.Val, result ref.Val) uint64 {
var cost uint64
if c.Estimator != nil {
callCost := c.Estimator.CallCost(call.Function(), call.OverloadID(), argValues, result)
@ -122,7 +164,7 @@ func (c CostTracker) costCall(call InterpretableCall, argValues []ref.Val, resul
// if user has their own implementation of ActualCostEstimator, make sure to cover the mapping between overloadId and cost calculation
switch call.OverloadID() {
// O(n) functions
case overloads.StartsWithString, overloads.EndsWithString, overloads.StringToBytes, overloads.BytesToString:
case overloads.StartsWithString, overloads.EndsWithString, overloads.StringToBytes, overloads.BytesToString, overloads.ExtQuoteString, overloads.ExtFormatString:
cost += uint64(math.Ceil(float64(c.actualSize(argValues[0])) * common.StringTraversalCostFactor))
case overloads.InList:
// If a list is composed entirely of constant values this is O(1), but we don't account for that here.
@ -179,7 +221,7 @@ func (c CostTracker) costCall(call InterpretableCall, argValues []ref.Val, resul
}
// actualSize returns the size of value
func (c CostTracker) actualSize(value ref.Val) uint64 {
func (c *CostTracker) actualSize(value ref.Val) uint64 {
if sz, ok := value.(traits.Sizer); ok {
return uint64(sz.Size().(types.Int))
}

View File

@ -23,8 +23,8 @@ go_library(
"//common/operators:go_default_library",
"//common/runes:go_default_library",
"//parser/gen:go_default_library",
"@com_github_antlr_antlr4_runtime_go_antlr//:go_default_library",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@com_github_antlr_antlr4_runtime_go_antlr_v4//:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//types/known/structpb:go_default_library",
],
@ -34,6 +34,7 @@ go_test(
name = "go_default_test",
size = "small",
srcs = [
"helper_test.go",
"parser_test.go",
"unescape_test.go",
"unparser_test.go",
@ -45,7 +46,8 @@ go_test(
"//common/debug:go_default_library",
"//parser/gen:go_default_library",
"//test:go_default_library",
"@com_github_antlr_antlr4_runtime_go_antlr//:go_default_library",
"@com_github_antlr_antlr4_runtime_go_antlr_v4//:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//testing/protocmp:go_default_library",
],
)

View File

@ -21,6 +21,6 @@ go_library(
],
importpath = "github.com/google/cel-go/parser/gen",
deps = [
"@com_github_antlr_antlr4_runtime_go_antlr//:go_default_library",
"@com_github_antlr_antlr4_runtime_go_antlr_v4//:go_default_library",
],
)

View File

@ -52,16 +52,18 @@ unary
member
: primary # PrimaryExpr
| member op='.' id=IDENTIFIER (open='(' args=exprList? ')')? # SelectOrCall
| member op='[' index=expr ']' # Index
| member op='{' entries=fieldInitializerList? ','? '}' # CreateMessage
| member op='.' (opt='?')? id=IDENTIFIER # Select
| member op='.' id=IDENTIFIER open='(' args=exprList? ')' # MemberCall
| member op='[' (opt='?')? index=expr ']' # Index
;
primary
: leadingDot='.'? id=IDENTIFIER (op='(' args=exprList? ')')? # IdentOrGlobalCall
| '(' e=expr ')' # Nested
| op='[' elems=exprList? ','? ']' # CreateList
| op='[' elems=listInit? ','? ']' # CreateList
| op='{' entries=mapInitializerList? ','? '}' # CreateStruct
| leadingDot='.'? ids+=IDENTIFIER (ops+='.' ids+=IDENTIFIER)*
op='{' entries=fieldInitializerList? ','? '}' # CreateMessage
| literal # ConstantLiteral
;
@ -69,23 +71,35 @@ exprList
: e+=expr (',' e+=expr)*
;
listInit
: elems+=optExpr (',' elems+=optExpr)*
;
fieldInitializerList
: fields+=IDENTIFIER cols+=':' values+=expr (',' fields+=IDENTIFIER cols+=':' values+=expr)*
: fields+=optField cols+=':' values+=expr (',' fields+=optField cols+=':' values+=expr)*
;
optField
: (opt='?')? IDENTIFIER
;
mapInitializerList
: keys+=expr cols+=':' values+=expr (',' keys+=expr cols+=':' values+=expr)*
: keys+=optExpr cols+=':' values+=expr (',' keys+=optExpr cols+=':' values+=expr)*
;
optExpr
: (opt='?')? e=expr
;
literal
: sign=MINUS? tok=NUM_INT # Int
| tok=NUM_UINT # Uint
| tok=NUM_UINT # Uint
| sign=MINUS? tok=NUM_FLOAT # Double
| tok=STRING # String
| tok=BYTES # Bytes
| tok=CEL_TRUE # BoolTrue
| tok=CEL_FALSE # BoolFalse
| tok=NUL # Null
| tok=STRING # String
| tok=BYTES # Bytes
| tok=CEL_TRUE # BoolTrue
| tok=CEL_FALSE # BoolFalse
| tok=NUL # Null
;
// Lexer Rules

File diff suppressed because one or more lines are too long

View File

@ -1,7 +1,7 @@
// Code generated from /Users/tswadell/go/src/github.com/google/cel-go/parser/gen/CEL.g4 by ANTLR 4.10.1. DO NOT EDIT.
// Code generated from /Users/tswadell/go/src/github.com/google/cel-go/parser/gen/CEL.g4 by ANTLR 4.12.0. DO NOT EDIT.
package gen // CEL
import "github.com/antlr/antlr4/runtime/Go/antlr"
import "github.com/antlr/antlr4/runtime/Go/antlr/v4"
// BaseCELListener is a complete listener for a parse tree produced by CELParser.
type BaseCELListener struct{}
@ -74,11 +74,17 @@ func (s *BaseCELListener) EnterNegate(ctx *NegateContext) {}
// ExitNegate is called when production Negate is exited.
func (s *BaseCELListener) ExitNegate(ctx *NegateContext) {}
// EnterSelectOrCall is called when production SelectOrCall is entered.
func (s *BaseCELListener) EnterSelectOrCall(ctx *SelectOrCallContext) {}
// EnterMemberCall is called when production MemberCall is entered.
func (s *BaseCELListener) EnterMemberCall(ctx *MemberCallContext) {}
// ExitSelectOrCall is called when production SelectOrCall is exited.
func (s *BaseCELListener) ExitSelectOrCall(ctx *SelectOrCallContext) {}
// ExitMemberCall is called when production MemberCall is exited.
func (s *BaseCELListener) ExitMemberCall(ctx *MemberCallContext) {}
// EnterSelect is called when production Select is entered.
func (s *BaseCELListener) EnterSelect(ctx *SelectContext) {}
// ExitSelect is called when production Select is exited.
func (s *BaseCELListener) ExitSelect(ctx *SelectContext) {}
// EnterPrimaryExpr is called when production PrimaryExpr is entered.
func (s *BaseCELListener) EnterPrimaryExpr(ctx *PrimaryExprContext) {}
@ -92,12 +98,6 @@ func (s *BaseCELListener) EnterIndex(ctx *IndexContext) {}
// ExitIndex is called when production Index is exited.
func (s *BaseCELListener) ExitIndex(ctx *IndexContext) {}
// EnterCreateMessage is called when production CreateMessage is entered.
func (s *BaseCELListener) EnterCreateMessage(ctx *CreateMessageContext) {}
// ExitCreateMessage is called when production CreateMessage is exited.
func (s *BaseCELListener) ExitCreateMessage(ctx *CreateMessageContext) {}
// EnterIdentOrGlobalCall is called when production IdentOrGlobalCall is entered.
func (s *BaseCELListener) EnterIdentOrGlobalCall(ctx *IdentOrGlobalCallContext) {}
@ -122,6 +122,12 @@ func (s *BaseCELListener) EnterCreateStruct(ctx *CreateStructContext) {}
// ExitCreateStruct is called when production CreateStruct is exited.
func (s *BaseCELListener) ExitCreateStruct(ctx *CreateStructContext) {}
// EnterCreateMessage is called when production CreateMessage is entered.
func (s *BaseCELListener) EnterCreateMessage(ctx *CreateMessageContext) {}
// ExitCreateMessage is called when production CreateMessage is exited.
func (s *BaseCELListener) ExitCreateMessage(ctx *CreateMessageContext) {}
// EnterConstantLiteral is called when production ConstantLiteral is entered.
func (s *BaseCELListener) EnterConstantLiteral(ctx *ConstantLiteralContext) {}
@ -134,18 +140,36 @@ func (s *BaseCELListener) EnterExprList(ctx *ExprListContext) {}
// ExitExprList is called when production exprList is exited.
func (s *BaseCELListener) ExitExprList(ctx *ExprListContext) {}
// EnterListInit is called when production listInit is entered.
func (s *BaseCELListener) EnterListInit(ctx *ListInitContext) {}
// ExitListInit is called when production listInit is exited.
func (s *BaseCELListener) ExitListInit(ctx *ListInitContext) {}
// EnterFieldInitializerList is called when production fieldInitializerList is entered.
func (s *BaseCELListener) EnterFieldInitializerList(ctx *FieldInitializerListContext) {}
// ExitFieldInitializerList is called when production fieldInitializerList is exited.
func (s *BaseCELListener) ExitFieldInitializerList(ctx *FieldInitializerListContext) {}
// EnterOptField is called when production optField is entered.
func (s *BaseCELListener) EnterOptField(ctx *OptFieldContext) {}
// ExitOptField is called when production optField is exited.
func (s *BaseCELListener) ExitOptField(ctx *OptFieldContext) {}
// EnterMapInitializerList is called when production mapInitializerList is entered.
func (s *BaseCELListener) EnterMapInitializerList(ctx *MapInitializerListContext) {}
// ExitMapInitializerList is called when production mapInitializerList is exited.
func (s *BaseCELListener) ExitMapInitializerList(ctx *MapInitializerListContext) {}
// EnterOptExpr is called when production optExpr is entered.
func (s *BaseCELListener) EnterOptExpr(ctx *OptExprContext) {}
// ExitOptExpr is called when production optExpr is exited.
func (s *BaseCELListener) ExitOptExpr(ctx *OptExprContext) {}
// EnterInt is called when production Int is entered.
func (s *BaseCELListener) EnterInt(ctx *IntContext) {}

View File

@ -1,7 +1,7 @@
// Code generated from /Users/tswadell/go/src/github.com/google/cel-go/parser/gen/CEL.g4 by ANTLR 4.10.1. DO NOT EDIT.
// Code generated from /Users/tswadell/go/src/github.com/google/cel-go/parser/gen/CEL.g4 by ANTLR 4.12.0. DO NOT EDIT.
package gen // CEL
import "github.com/antlr/antlr4/runtime/Go/antlr"
import "github.com/antlr/antlr4/runtime/Go/antlr/v4"
type BaseCELVisitor struct {
*antlr.BaseParseTreeVisitor
@ -43,7 +43,11 @@ func (v *BaseCELVisitor) VisitNegate(ctx *NegateContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BaseCELVisitor) VisitSelectOrCall(ctx *SelectOrCallContext) interface{} {
func (v *BaseCELVisitor) VisitMemberCall(ctx *MemberCallContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BaseCELVisitor) VisitSelect(ctx *SelectContext) interface{} {
return v.VisitChildren(ctx)
}
@ -55,10 +59,6 @@ func (v *BaseCELVisitor) VisitIndex(ctx *IndexContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BaseCELVisitor) VisitCreateMessage(ctx *CreateMessageContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BaseCELVisitor) VisitIdentOrGlobalCall(ctx *IdentOrGlobalCallContext) interface{} {
return v.VisitChildren(ctx)
}
@ -75,6 +75,10 @@ func (v *BaseCELVisitor) VisitCreateStruct(ctx *CreateStructContext) interface{}
return v.VisitChildren(ctx)
}
func (v *BaseCELVisitor) VisitCreateMessage(ctx *CreateMessageContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BaseCELVisitor) VisitConstantLiteral(ctx *ConstantLiteralContext) interface{} {
return v.VisitChildren(ctx)
}
@ -83,14 +87,26 @@ func (v *BaseCELVisitor) VisitExprList(ctx *ExprListContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BaseCELVisitor) VisitListInit(ctx *ListInitContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BaseCELVisitor) VisitFieldInitializerList(ctx *FieldInitializerListContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BaseCELVisitor) VisitOptField(ctx *OptFieldContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BaseCELVisitor) VisitMapInitializerList(ctx *MapInitializerListContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BaseCELVisitor) VisitOptExpr(ctx *OptExprContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BaseCELVisitor) VisitInt(ctx *IntContext) interface{} {
return v.VisitChildren(ctx)
}

View File

@ -1,4 +1,4 @@
// Code generated from /Users/tswadell/go/src/github.com/google/cel-go/parser/gen/CEL.g4 by ANTLR 4.10.1. DO NOT EDIT.
// Code generated from /Users/tswadell/go/src/github.com/google/cel-go/parser/gen/CEL.g4 by ANTLR 4.12.0. DO NOT EDIT.
package gen
@ -7,7 +7,7 @@ import (
"sync"
"unicode"
"github.com/antlr/antlr4/runtime/Go/antlr"
"github.com/antlr/antlr4/runtime/Go/antlr/v4"
)
// Suppress unused import error

View File

@ -1,7 +1,7 @@
// Code generated from /Users/tswadell/go/src/github.com/google/cel-go/parser/gen/CEL.g4 by ANTLR 4.10.1. DO NOT EDIT.
// Code generated from /Users/tswadell/go/src/github.com/google/cel-go/parser/gen/CEL.g4 by ANTLR 4.12.0. DO NOT EDIT.
package gen // CEL
import "github.com/antlr/antlr4/runtime/Go/antlr"
import "github.com/antlr/antlr4/runtime/Go/antlr/v4"
// CELListener is a complete listener for a parse tree produced by CELParser.
type CELListener interface {
@ -34,8 +34,11 @@ type CELListener interface {
// EnterNegate is called when entering the Negate production.
EnterNegate(c *NegateContext)
// EnterSelectOrCall is called when entering the SelectOrCall production.
EnterSelectOrCall(c *SelectOrCallContext)
// EnterMemberCall is called when entering the MemberCall production.
EnterMemberCall(c *MemberCallContext)
// EnterSelect is called when entering the Select production.
EnterSelect(c *SelectContext)
// EnterPrimaryExpr is called when entering the PrimaryExpr production.
EnterPrimaryExpr(c *PrimaryExprContext)
@ -43,9 +46,6 @@ type CELListener interface {
// EnterIndex is called when entering the Index production.
EnterIndex(c *IndexContext)
// EnterCreateMessage is called when entering the CreateMessage production.
EnterCreateMessage(c *CreateMessageContext)
// EnterIdentOrGlobalCall is called when entering the IdentOrGlobalCall production.
EnterIdentOrGlobalCall(c *IdentOrGlobalCallContext)
@ -58,18 +58,30 @@ type CELListener interface {
// EnterCreateStruct is called when entering the CreateStruct production.
EnterCreateStruct(c *CreateStructContext)
// EnterCreateMessage is called when entering the CreateMessage production.
EnterCreateMessage(c *CreateMessageContext)
// EnterConstantLiteral is called when entering the ConstantLiteral production.
EnterConstantLiteral(c *ConstantLiteralContext)
// EnterExprList is called when entering the exprList production.
EnterExprList(c *ExprListContext)
// EnterListInit is called when entering the listInit production.
EnterListInit(c *ListInitContext)
// EnterFieldInitializerList is called when entering the fieldInitializerList production.
EnterFieldInitializerList(c *FieldInitializerListContext)
// EnterOptField is called when entering the optField production.
EnterOptField(c *OptFieldContext)
// EnterMapInitializerList is called when entering the mapInitializerList production.
EnterMapInitializerList(c *MapInitializerListContext)
// EnterOptExpr is called when entering the optExpr production.
EnterOptExpr(c *OptExprContext)
// EnterInt is called when entering the Int production.
EnterInt(c *IntContext)
@ -121,8 +133,11 @@ type CELListener interface {
// ExitNegate is called when exiting the Negate production.
ExitNegate(c *NegateContext)
// ExitSelectOrCall is called when exiting the SelectOrCall production.
ExitSelectOrCall(c *SelectOrCallContext)
// ExitMemberCall is called when exiting the MemberCall production.
ExitMemberCall(c *MemberCallContext)
// ExitSelect is called when exiting the Select production.
ExitSelect(c *SelectContext)
// ExitPrimaryExpr is called when exiting the PrimaryExpr production.
ExitPrimaryExpr(c *PrimaryExprContext)
@ -130,9 +145,6 @@ type CELListener interface {
// ExitIndex is called when exiting the Index production.
ExitIndex(c *IndexContext)
// ExitCreateMessage is called when exiting the CreateMessage production.
ExitCreateMessage(c *CreateMessageContext)
// ExitIdentOrGlobalCall is called when exiting the IdentOrGlobalCall production.
ExitIdentOrGlobalCall(c *IdentOrGlobalCallContext)
@ -145,18 +157,30 @@ type CELListener interface {
// ExitCreateStruct is called when exiting the CreateStruct production.
ExitCreateStruct(c *CreateStructContext)
// ExitCreateMessage is called when exiting the CreateMessage production.
ExitCreateMessage(c *CreateMessageContext)
// ExitConstantLiteral is called when exiting the ConstantLiteral production.
ExitConstantLiteral(c *ConstantLiteralContext)
// ExitExprList is called when exiting the exprList production.
ExitExprList(c *ExprListContext)
// ExitListInit is called when exiting the listInit production.
ExitListInit(c *ListInitContext)
// ExitFieldInitializerList is called when exiting the fieldInitializerList production.
ExitFieldInitializerList(c *FieldInitializerListContext)
// ExitOptField is called when exiting the optField production.
ExitOptField(c *OptFieldContext)
// ExitMapInitializerList is called when exiting the mapInitializerList production.
ExitMapInitializerList(c *MapInitializerListContext)
// ExitOptExpr is called when exiting the optExpr production.
ExitOptExpr(c *OptExprContext)
// ExitInt is called when exiting the Int production.
ExitInt(c *IntContext)

File diff suppressed because it is too large Load Diff

View File

@ -1,7 +1,7 @@
// Code generated from /Users/tswadell/go/src/github.com/google/cel-go/parser/gen/CEL.g4 by ANTLR 4.10.1. DO NOT EDIT.
// Code generated from /Users/tswadell/go/src/github.com/google/cel-go/parser/gen/CEL.g4 by ANTLR 4.12.0. DO NOT EDIT.
package gen // CEL
import "github.com/antlr/antlr4/runtime/Go/antlr"
import "github.com/antlr/antlr4/runtime/Go/antlr/v4"
// A complete Visitor for a parse tree produced by CELParser.
type CELVisitor interface {
@ -34,8 +34,11 @@ type CELVisitor interface {
// Visit a parse tree produced by CELParser#Negate.
VisitNegate(ctx *NegateContext) interface{}
// Visit a parse tree produced by CELParser#SelectOrCall.
VisitSelectOrCall(ctx *SelectOrCallContext) interface{}
// Visit a parse tree produced by CELParser#MemberCall.
VisitMemberCall(ctx *MemberCallContext) interface{}
// Visit a parse tree produced by CELParser#Select.
VisitSelect(ctx *SelectContext) interface{}
// Visit a parse tree produced by CELParser#PrimaryExpr.
VisitPrimaryExpr(ctx *PrimaryExprContext) interface{}
@ -43,9 +46,6 @@ type CELVisitor interface {
// Visit a parse tree produced by CELParser#Index.
VisitIndex(ctx *IndexContext) interface{}
// Visit a parse tree produced by CELParser#CreateMessage.
VisitCreateMessage(ctx *CreateMessageContext) interface{}
// Visit a parse tree produced by CELParser#IdentOrGlobalCall.
VisitIdentOrGlobalCall(ctx *IdentOrGlobalCallContext) interface{}
@ -58,18 +58,30 @@ type CELVisitor interface {
// Visit a parse tree produced by CELParser#CreateStruct.
VisitCreateStruct(ctx *CreateStructContext) interface{}
// Visit a parse tree produced by CELParser#CreateMessage.
VisitCreateMessage(ctx *CreateMessageContext) interface{}
// Visit a parse tree produced by CELParser#ConstantLiteral.
VisitConstantLiteral(ctx *ConstantLiteralContext) interface{}
// Visit a parse tree produced by CELParser#exprList.
VisitExprList(ctx *ExprListContext) interface{}
// Visit a parse tree produced by CELParser#listInit.
VisitListInit(ctx *ListInitContext) interface{}
// Visit a parse tree produced by CELParser#fieldInitializerList.
VisitFieldInitializerList(ctx *FieldInitializerListContext) interface{}
// Visit a parse tree produced by CELParser#optField.
VisitOptField(ctx *OptFieldContext) interface{}
// Visit a parse tree produced by CELParser#mapInitializerList.
VisitMapInitializerList(ctx *MapInitializerListContext) interface{}
// Visit a parse tree produced by CELParser#optExpr.
VisitOptExpr(ctx *OptExprContext) interface{}
// Visit a parse tree produced by CELParser#Int.
VisitInt(ctx *IntContext) interface{}

View File

@ -27,7 +27,7 @@
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
# Generate AntLR artifacts.
java -Xmx500M -cp ${DIR}/antlr-4.10.1-complete.jar org.antlr.v4.Tool \
java -Xmx500M -cp ${DIR}/antlr-4.12.0-complete.jar org.antlr.v4.Tool \
-Dlanguage=Go \
-package gen \
-o ${DIR} \

View File

@ -17,7 +17,8 @@ package parser
import (
"sync"
"github.com/antlr/antlr4/runtime/Go/antlr"
antlr "github.com/antlr/antlr4/runtime/Go/antlr/v4"
"github.com/google/cel-go/common"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
@ -47,115 +48,122 @@ func (p *parserHelper) getSourceInfo() *exprpb.SourceInfo {
MacroCalls: p.macroCalls}
}
func (p *parserHelper) newLiteral(ctx interface{}, value *exprpb.Constant) *exprpb.Expr {
func (p *parserHelper) newLiteral(ctx any, value *exprpb.Constant) *exprpb.Expr {
exprNode := p.newExpr(ctx)
exprNode.ExprKind = &exprpb.Expr_ConstExpr{ConstExpr: value}
return exprNode
}
func (p *parserHelper) newLiteralBool(ctx interface{}, value bool) *exprpb.Expr {
func (p *parserHelper) newLiteralBool(ctx any, value bool) *exprpb.Expr {
return p.newLiteral(ctx,
&exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: value}})
}
func (p *parserHelper) newLiteralString(ctx interface{}, value string) *exprpb.Expr {
func (p *parserHelper) newLiteralString(ctx any, value string) *exprpb.Expr {
return p.newLiteral(ctx,
&exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: value}})
}
func (p *parserHelper) newLiteralBytes(ctx interface{}, value []byte) *exprpb.Expr {
func (p *parserHelper) newLiteralBytes(ctx any, value []byte) *exprpb.Expr {
return p.newLiteral(ctx,
&exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: value}})
}
func (p *parserHelper) newLiteralInt(ctx interface{}, value int64) *exprpb.Expr {
func (p *parserHelper) newLiteralInt(ctx any, value int64) *exprpb.Expr {
return p.newLiteral(ctx,
&exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: value}})
}
func (p *parserHelper) newLiteralUint(ctx interface{}, value uint64) *exprpb.Expr {
func (p *parserHelper) newLiteralUint(ctx any, value uint64) *exprpb.Expr {
return p.newLiteral(ctx, &exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: value}})
}
func (p *parserHelper) newLiteralDouble(ctx interface{}, value float64) *exprpb.Expr {
func (p *parserHelper) newLiteralDouble(ctx any, value float64) *exprpb.Expr {
return p.newLiteral(ctx,
&exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: value}})
}
func (p *parserHelper) newIdent(ctx interface{}, name string) *exprpb.Expr {
func (p *parserHelper) newIdent(ctx any, name string) *exprpb.Expr {
exprNode := p.newExpr(ctx)
exprNode.ExprKind = &exprpb.Expr_IdentExpr{IdentExpr: &exprpb.Expr_Ident{Name: name}}
return exprNode
}
func (p *parserHelper) newSelect(ctx interface{}, operand *exprpb.Expr, field string) *exprpb.Expr {
func (p *parserHelper) newSelect(ctx any, operand *exprpb.Expr, field string) *exprpb.Expr {
exprNode := p.newExpr(ctx)
exprNode.ExprKind = &exprpb.Expr_SelectExpr{
SelectExpr: &exprpb.Expr_Select{Operand: operand, Field: field}}
return exprNode
}
func (p *parserHelper) newPresenceTest(ctx interface{}, operand *exprpb.Expr, field string) *exprpb.Expr {
func (p *parserHelper) newPresenceTest(ctx any, operand *exprpb.Expr, field string) *exprpb.Expr {
exprNode := p.newExpr(ctx)
exprNode.ExprKind = &exprpb.Expr_SelectExpr{
SelectExpr: &exprpb.Expr_Select{Operand: operand, Field: field, TestOnly: true}}
return exprNode
}
func (p *parserHelper) newGlobalCall(ctx interface{}, function string, args ...*exprpb.Expr) *exprpb.Expr {
func (p *parserHelper) newGlobalCall(ctx any, function string, args ...*exprpb.Expr) *exprpb.Expr {
exprNode := p.newExpr(ctx)
exprNode.ExprKind = &exprpb.Expr_CallExpr{
CallExpr: &exprpb.Expr_Call{Function: function, Args: args}}
return exprNode
}
func (p *parserHelper) newReceiverCall(ctx interface{}, function string, target *exprpb.Expr, args ...*exprpb.Expr) *exprpb.Expr {
func (p *parserHelper) newReceiverCall(ctx any, function string, target *exprpb.Expr, args ...*exprpb.Expr) *exprpb.Expr {
exprNode := p.newExpr(ctx)
exprNode.ExprKind = &exprpb.Expr_CallExpr{
CallExpr: &exprpb.Expr_Call{Function: function, Target: target, Args: args}}
return exprNode
}
func (p *parserHelper) newList(ctx interface{}, elements ...*exprpb.Expr) *exprpb.Expr {
func (p *parserHelper) newList(ctx any, elements []*exprpb.Expr, optionals ...int32) *exprpb.Expr {
exprNode := p.newExpr(ctx)
exprNode.ExprKind = &exprpb.Expr_ListExpr{
ListExpr: &exprpb.Expr_CreateList{Elements: elements}}
ListExpr: &exprpb.Expr_CreateList{
Elements: elements,
OptionalIndices: optionals,
}}
return exprNode
}
func (p *parserHelper) newMap(ctx interface{}, entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr {
func (p *parserHelper) newMap(ctx any, entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr {
exprNode := p.newExpr(ctx)
exprNode.ExprKind = &exprpb.Expr_StructExpr{
StructExpr: &exprpb.Expr_CreateStruct{Entries: entries}}
return exprNode
}
func (p *parserHelper) newMapEntry(entryID int64, key *exprpb.Expr, value *exprpb.Expr) *exprpb.Expr_CreateStruct_Entry {
func (p *parserHelper) newMapEntry(entryID int64, key *exprpb.Expr, value *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry {
return &exprpb.Expr_CreateStruct_Entry{
Id: entryID,
KeyKind: &exprpb.Expr_CreateStruct_Entry_MapKey{MapKey: key},
Value: value}
Id: entryID,
KeyKind: &exprpb.Expr_CreateStruct_Entry_MapKey{MapKey: key},
Value: value,
OptionalEntry: optional,
}
}
func (p *parserHelper) newObject(ctx interface{},
typeName string,
entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr {
func (p *parserHelper) newObject(ctx any, typeName string, entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr {
exprNode := p.newExpr(ctx)
exprNode.ExprKind = &exprpb.Expr_StructExpr{
StructExpr: &exprpb.Expr_CreateStruct{
MessageName: typeName,
Entries: entries}}
Entries: entries,
},
}
return exprNode
}
func (p *parserHelper) newObjectField(fieldID int64, field string, value *exprpb.Expr) *exprpb.Expr_CreateStruct_Entry {
func (p *parserHelper) newObjectField(fieldID int64, field string, value *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry {
return &exprpb.Expr_CreateStruct_Entry{
Id: fieldID,
KeyKind: &exprpb.Expr_CreateStruct_Entry_FieldKey{FieldKey: field},
Value: value}
Id: fieldID,
KeyKind: &exprpb.Expr_CreateStruct_Entry_FieldKey{FieldKey: field},
Value: value,
OptionalEntry: optional,
}
}
func (p *parserHelper) newComprehension(ctx interface{}, iterVar string,
func (p *parserHelper) newComprehension(ctx any, iterVar string,
iterRange *exprpb.Expr,
accuVar string,
accuInit *exprpb.Expr,
@ -175,7 +183,7 @@ func (p *parserHelper) newComprehension(ctx interface{}, iterVar string,
return exprNode
}
func (p *parserHelper) newExpr(ctx interface{}) *exprpb.Expr {
func (p *parserHelper) newExpr(ctx any) *exprpb.Expr {
id, isID := ctx.(int64)
if isID {
return &exprpb.Expr{Id: id}
@ -183,7 +191,7 @@ func (p *parserHelper) newExpr(ctx interface{}) *exprpb.Expr {
return &exprpb.Expr{Id: p.id(ctx)}
}
func (p *parserHelper) id(ctx interface{}) int64 {
func (p *parserHelper) id(ctx any) int64 {
var location common.Location
switch ctx.(type) {
case antlr.ParserRuleContext:
@ -251,7 +259,8 @@ func (p *parserHelper) buildMacroCallArg(expr *exprpb.Expr) *exprpb.Expr {
Id: expr.GetId(),
ExprKind: &exprpb.Expr_ListExpr{
ListExpr: &exprpb.Expr_CreateList{
Elements: macroListArgs,
Elements: macroListArgs,
OptionalIndices: listExpr.GetOptionalIndices(),
},
},
}
@ -360,6 +369,95 @@ func (e *exprHelper) nextMacroID() int64 {
return e.parserHelper.id(e.parserHelper.getLocation(e.id))
}
// Copy implements the ExprHelper interface method by producing a copy of the input Expr value
// with a fresh set of numeric identifiers the Expr and all its descendents.
func (e *exprHelper) Copy(expr *exprpb.Expr) *exprpb.Expr {
copy := e.parserHelper.newExpr(e.parserHelper.getLocation(expr.GetId()))
switch expr.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
copy.ExprKind = &exprpb.Expr_ConstExpr{ConstExpr: expr.GetConstExpr()}
case *exprpb.Expr_IdentExpr:
copy.ExprKind = &exprpb.Expr_IdentExpr{IdentExpr: expr.GetIdentExpr()}
case *exprpb.Expr_SelectExpr:
op := expr.GetSelectExpr().GetOperand()
copy.ExprKind = &exprpb.Expr_SelectExpr{SelectExpr: &exprpb.Expr_Select{
Operand: e.Copy(op),
Field: expr.GetSelectExpr().GetField(),
TestOnly: expr.GetSelectExpr().GetTestOnly(),
}}
case *exprpb.Expr_CallExpr:
call := expr.GetCallExpr()
target := call.GetTarget()
if target != nil {
target = e.Copy(target)
}
args := call.GetArgs()
argsCopy := make([]*exprpb.Expr, len(args))
for i, arg := range args {
argsCopy[i] = e.Copy(arg)
}
copy.ExprKind = &exprpb.Expr_CallExpr{
CallExpr: &exprpb.Expr_Call{
Function: call.GetFunction(),
Target: target,
Args: argsCopy,
},
}
case *exprpb.Expr_ListExpr:
elems := expr.GetListExpr().GetElements()
elemsCopy := make([]*exprpb.Expr, len(elems))
for i, elem := range elems {
elemsCopy[i] = e.Copy(elem)
}
copy.ExprKind = &exprpb.Expr_ListExpr{
ListExpr: &exprpb.Expr_CreateList{Elements: elemsCopy},
}
case *exprpb.Expr_StructExpr:
entries := expr.GetStructExpr().GetEntries()
entriesCopy := make([]*exprpb.Expr_CreateStruct_Entry, len(entries))
for i, entry := range entries {
entryCopy := &exprpb.Expr_CreateStruct_Entry{}
entryCopy.Id = e.nextMacroID()
switch entry.GetKeyKind().(type) {
case *exprpb.Expr_CreateStruct_Entry_FieldKey:
entryCopy.KeyKind = &exprpb.Expr_CreateStruct_Entry_FieldKey{
FieldKey: entry.GetFieldKey(),
}
case *exprpb.Expr_CreateStruct_Entry_MapKey:
entryCopy.KeyKind = &exprpb.Expr_CreateStruct_Entry_MapKey{
MapKey: e.Copy(entry.GetMapKey()),
}
}
entryCopy.Value = e.Copy(entry.GetValue())
entriesCopy[i] = entryCopy
}
copy.ExprKind = &exprpb.Expr_StructExpr{
StructExpr: &exprpb.Expr_CreateStruct{
MessageName: expr.GetStructExpr().GetMessageName(),
Entries: entriesCopy,
},
}
case *exprpb.Expr_ComprehensionExpr:
iterRange := e.Copy(expr.GetComprehensionExpr().GetIterRange())
accuInit := e.Copy(expr.GetComprehensionExpr().GetAccuInit())
cond := e.Copy(expr.GetComprehensionExpr().GetLoopCondition())
step := e.Copy(expr.GetComprehensionExpr().GetLoopStep())
result := e.Copy(expr.GetComprehensionExpr().GetResult())
copy.ExprKind = &exprpb.Expr_ComprehensionExpr{
ComprehensionExpr: &exprpb.Expr_Comprehension{
IterRange: iterRange,
IterVar: expr.GetComprehensionExpr().GetIterVar(),
AccuInit: accuInit,
AccuVar: expr.GetComprehensionExpr().GetAccuVar(),
LoopCondition: cond,
LoopStep: step,
Result: result,
},
}
}
return copy
}
// LiteralBool implements the ExprHelper interface method.
func (e *exprHelper) LiteralBool(value bool) *exprpb.Expr {
return e.parserHelper.newLiteralBool(e.nextMacroID(), value)
@ -392,7 +490,7 @@ func (e *exprHelper) LiteralUint(value uint64) *exprpb.Expr {
// NewList implements the ExprHelper interface method.
func (e *exprHelper) NewList(elems ...*exprpb.Expr) *exprpb.Expr {
return e.parserHelper.newList(e.nextMacroID(), elems...)
return e.parserHelper.newList(e.nextMacroID(), elems)
}
// NewMap implements the ExprHelper interface method.
@ -401,21 +499,18 @@ func (e *exprHelper) NewMap(entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.
}
// NewMapEntry implements the ExprHelper interface method.
func (e *exprHelper) NewMapEntry(key *exprpb.Expr,
val *exprpb.Expr) *exprpb.Expr_CreateStruct_Entry {
return e.parserHelper.newMapEntry(e.nextMacroID(), key, val)
func (e *exprHelper) NewMapEntry(key *exprpb.Expr, val *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry {
return e.parserHelper.newMapEntry(e.nextMacroID(), key, val, optional)
}
// NewObject implements the ExprHelper interface method.
func (e *exprHelper) NewObject(typeName string,
fieldInits ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr {
func (e *exprHelper) NewObject(typeName string, fieldInits ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr {
return e.parserHelper.newObject(e.nextMacroID(), typeName, fieldInits...)
}
// NewObjectFieldInit implements the ExprHelper interface method.
func (e *exprHelper) NewObjectFieldInit(field string,
init *exprpb.Expr) *exprpb.Expr_CreateStruct_Entry {
return e.parserHelper.newObjectField(e.nextMacroID(), field, init)
func (e *exprHelper) NewObjectFieldInit(field string, init *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry {
return e.parserHelper.newObjectField(e.nextMacroID(), field, init, optional)
}
// Fold implements the ExprHelper interface method.
@ -471,7 +566,7 @@ func (e *exprHelper) OffsetLocation(exprID int64) common.Location {
var (
// Thread-safe pool of ExprHelper values to minimize alloc overhead of ExprHelper creations.
exprHelperPool = &sync.Pool{
New: func() interface{} {
New: func() any {
return &exprHelper{}
},
}

View File

@ -15,7 +15,8 @@
package parser
import (
"github.com/antlr/antlr4/runtime/Go/antlr"
antlr "github.com/antlr/antlr4/runtime/Go/antlr/v4"
"github.com/google/cel-go/common/runes"
)

View File

@ -132,8 +132,11 @@ func makeVarArgMacroKey(name string, receiverStyle bool) string {
return fmt.Sprintf("%s:*:%v", name, receiverStyle)
}
// MacroExpander converts a call and its associated arguments into a new CEL abstract syntax tree, or an error
// if the input arguments are not suitable for the expansion requirements for the macro in question.
// MacroExpander converts a call and its associated arguments into a new CEL abstract syntax tree.
//
// If the MacroExpander determines within the implementation that an expansion is not needed it may return
// a nil Expr value to indicate a non-match. However, if an expansion is to be performed, but the arguments
// are not well-formed, the result of the expansion will be an error.
//
// The MacroExpander accepts as arguments a MacroExprHelper as well as the arguments used in the function call
// and produces as output an Expr ast node.
@ -147,6 +150,9 @@ type MacroExpander func(eh ExprHelper,
// consistent with the source position and expression id generation code leveraged by both
// the parser and type-checker.
type ExprHelper interface {
// Copy the input expression with a brand new set of identifiers.
Copy(*exprpb.Expr) *exprpb.Expr
// LiteralBool creates an Expr value for a bool literal.
LiteralBool(value bool) *exprpb.Expr
@ -174,14 +180,14 @@ type ExprHelper interface {
NewMap(entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr
// NewMapEntry creates a Map Entry for the key, value pair.
NewMapEntry(key *exprpb.Expr, val *exprpb.Expr) *exprpb.Expr_CreateStruct_Entry
NewMapEntry(key *exprpb.Expr, val *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry
// NewObject creates a CreateStruct instruction for an object with a given type name and
// optional set of field initializers.
NewObject(typeName string, fieldInits ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr
// NewObjectFieldInit creates a new Object field initializer from the field name and value.
NewObjectFieldInit(field string, init *exprpb.Expr) *exprpb.Expr_CreateStruct_Entry
NewObjectFieldInit(field string, init *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry
// Fold creates a fold comprehension instruction.
//
@ -309,8 +315,10 @@ func MakeExistsOne(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*ex
// input to produce an output list.
//
// There are two call patterns supported by map:
// <iterRange>.map(<iterVar>, <transform>)
// <iterRange>.map(<iterVar>, <predicate>, <transform>)
//
// <iterRange>.map(<iterVar>, <transform>)
// <iterRange>.map(<iterVar>, <predicate>, <transform>)
//
// In the second form only iterVar values which return true when provided to the predicate expression
// are transformed.
func MakeMap(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {

View File

@ -18,11 +18,13 @@ import "fmt"
type options struct {
maxRecursionDepth int
errorReportingLimit int
errorRecoveryTokenLookaheadLimit int
errorRecoveryLimit int
expressionSizeCodePointLimit int
macros map[string]Macro
populateMacroCalls bool
enableOptionalSyntax bool
}
// Option configures the behavior of the parser.
@ -45,7 +47,7 @@ func MaxRecursionDepth(limit int) Option {
// successfully resume. In some pathological cases, the parser can look through quite a large set of input which
// in turn generates a lot of back-tracking and performance degredation.
//
// The limit must be > 1, and is recommended to be less than the default of 256.
// The limit must be >= 1, and is recommended to be less than the default of 256.
func ErrorRecoveryLookaheadTokenLimit(limit int) Option {
return func(opts *options) error {
if limit < 1 {
@ -67,6 +69,19 @@ func ErrorRecoveryLimit(limit int) Option {
}
}
// ErrorReportingLimit limits the number of syntax error reports before terminating parsing.
//
// The limit must be at least 1. If unset, the limit will be 100.
func ErrorReportingLimit(limit int) Option {
return func(opts *options) error {
if limit < 1 {
return fmt.Errorf("error reporting limit must be at least 1: %d", limit)
}
opts.errorReportingLimit = limit
return nil
}
}
// ExpressionSizeCodePointLimit is an option which limits the maximum code point count of an
// expression.
func ExpressionSizeCodePointLimit(expressionSizeCodePointLimit int) Option {
@ -102,3 +117,11 @@ func PopulateMacroCalls(populateMacroCalls bool) Option {
return nil
}
}
// EnableOptionalSyntax enables syntax for optional field and index selection.
func EnableOptionalSyntax(optionalSyntax bool) Option {
return func(opts *options) error {
opts.enableOptionalSyntax = optionalSyntax
return nil
}
}

View File

@ -18,11 +18,13 @@ package parser
import (
"fmt"
"regexp"
"strconv"
"strings"
"sync"
"github.com/antlr/antlr4/runtime/Go/antlr"
antlr "github.com/antlr/antlr4/runtime/Go/antlr/v4"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/runes"
@ -45,6 +47,9 @@ func NewParser(opts ...Option) (*Parser, error) {
return nil, err
}
}
if p.errorReportingLimit == 0 {
p.errorReportingLimit = 100
}
if p.maxRecursionDepth == 0 {
p.maxRecursionDepth = 250
}
@ -89,9 +94,11 @@ func (p *Parser) Parse(source common.Source) (*exprpb.ParsedExpr, *common.Errors
helper: newParserHelper(source),
macros: p.macros,
maxRecursionDepth: p.maxRecursionDepth,
errorReportingLimit: p.errorReportingLimit,
errorRecoveryLimit: p.errorRecoveryLimit,
errorRecoveryLookaheadTokenLimit: p.errorRecoveryTokenLookaheadLimit,
populateMacroCalls: p.populateMacroCalls,
enableOptionalSyntax: p.enableOptionalSyntax,
}
buf, ok := source.(runes.Buffer)
if !ok {
@ -178,7 +185,7 @@ func (rl *recursionListener) EnterEveryRule(ctx antlr.ParserRuleContext) {
} else {
*depth++
}
if *depth >= rl.maxDepth {
if *depth > rl.maxDepth {
panic(&recursionError{
message: fmt.Sprintf("expression recursion limit exceeded: %d", rl.maxDepth),
})
@ -197,6 +204,16 @@ func (rl *recursionListener) ExitEveryRule(ctx antlr.ParserRuleContext) {
var _ antlr.ParseTreeListener = &recursionListener{}
type tooManyErrors struct {
errorReportingLimit int
}
func (t *tooManyErrors) Error() string {
return fmt.Sprintf("More than %d syntax errors", t.errorReportingLimit)
}
var _ error = &tooManyErrors{}
type recoveryLimitError struct {
message string
}
@ -271,17 +288,20 @@ type parser struct {
helper *parserHelper
macros map[string]Macro
recursionDepth int
errorReports int
maxRecursionDepth int
errorReportingLimit int
errorRecoveryLimit int
errorRecoveryLookaheadTokenLimit int
populateMacroCalls bool
enableOptionalSyntax bool
}
var (
_ gen.CELVisitor = (*parser)(nil)
lexerPool *sync.Pool = &sync.Pool{
New: func() interface{} {
New: func() any {
l := gen.NewCELLexer(nil)
l.RemoveErrorListeners()
return l
@ -289,7 +309,7 @@ var (
}
parserPool *sync.Pool = &sync.Pool{
New: func() interface{} {
New: func() any {
p := gen.NewCELParser(nil)
p.RemoveErrorListeners()
return p
@ -302,14 +322,14 @@ func (p *parser) parse(expr runes.Buffer, desc string) *exprpb.Expr {
lexer := lexerPool.Get().(*gen.CELLexer)
prsr := parserPool.Get().(*gen.CELParser)
// Unfortunately ANTLR Go runtime is missing (*antlr.BaseParser).RemoveParseListeners, so this is
// good enough until that is exported.
prsrListener := &recursionListener{
maxDepth: p.maxRecursionDepth,
ruleTypeDepth: map[int]*int{},
}
defer func() {
// Unfortunately ANTLR Go runtime is missing (*antlr.BaseParser).RemoveParseListeners,
// so this is good enough until that is exported.
// Reset the lexer and parser before putting them back in the pool.
lexer.RemoveErrorListeners()
prsr.RemoveParseListener(prsrListener)
@ -340,6 +360,8 @@ func (p *parser) parse(expr runes.Buffer, desc string) *exprpb.Expr {
p.errors.ReportError(common.NoLocation, err.Error())
case *recursionError:
p.errors.ReportError(common.NoLocation, err.Error())
case *tooManyErrors:
// do nothing
case *recoveryLimitError:
// do nothing, listeners already notified and error reported.
default:
@ -352,57 +374,85 @@ func (p *parser) parse(expr runes.Buffer, desc string) *exprpb.Expr {
}
// Visitor implementations.
func (p *parser) Visit(tree antlr.ParseTree) interface{} {
p.recursionDepth++
if p.recursionDepth > p.maxRecursionDepth {
panic(&recursionError{message: "max recursion depth exceeded"})
}
defer func() {
p.recursionDepth--
}()
switch tree.(type) {
func (p *parser) Visit(tree antlr.ParseTree) any {
t := unnest(tree)
switch tree := t.(type) {
case *gen.StartContext:
return p.VisitStart(tree.(*gen.StartContext))
return p.VisitStart(tree)
case *gen.ExprContext:
return p.VisitExpr(tree.(*gen.ExprContext))
p.checkAndIncrementRecursionDepth()
out := p.VisitExpr(tree)
p.decrementRecursionDepth()
return out
case *gen.ConditionalAndContext:
return p.VisitConditionalAnd(tree.(*gen.ConditionalAndContext))
return p.VisitConditionalAnd(tree)
case *gen.ConditionalOrContext:
return p.VisitConditionalOr(tree.(*gen.ConditionalOrContext))
return p.VisitConditionalOr(tree)
case *gen.RelationContext:
return p.VisitRelation(tree.(*gen.RelationContext))
p.checkAndIncrementRecursionDepth()
out := p.VisitRelation(tree)
p.decrementRecursionDepth()
return out
case *gen.CalcContext:
return p.VisitCalc(tree.(*gen.CalcContext))
p.checkAndIncrementRecursionDepth()
out := p.VisitCalc(tree)
p.decrementRecursionDepth()
return out
case *gen.LogicalNotContext:
return p.VisitLogicalNot(tree.(*gen.LogicalNotContext))
case *gen.MemberExprContext:
return p.VisitMemberExpr(tree.(*gen.MemberExprContext))
case *gen.PrimaryExprContext:
return p.VisitPrimaryExpr(tree.(*gen.PrimaryExprContext))
case *gen.SelectOrCallContext:
return p.VisitSelectOrCall(tree.(*gen.SelectOrCallContext))
return p.VisitLogicalNot(tree)
case *gen.IdentOrGlobalCallContext:
return p.VisitIdentOrGlobalCall(tree)
case *gen.SelectContext:
p.checkAndIncrementRecursionDepth()
out := p.VisitSelect(tree)
p.decrementRecursionDepth()
return out
case *gen.MemberCallContext:
p.checkAndIncrementRecursionDepth()
out := p.VisitMemberCall(tree)
p.decrementRecursionDepth()
return out
case *gen.MapInitializerListContext:
return p.VisitMapInitializerList(tree.(*gen.MapInitializerListContext))
return p.VisitMapInitializerList(tree)
case *gen.NegateContext:
return p.VisitNegate(tree.(*gen.NegateContext))
return p.VisitNegate(tree)
case *gen.IndexContext:
return p.VisitIndex(tree.(*gen.IndexContext))
p.checkAndIncrementRecursionDepth()
out := p.VisitIndex(tree)
p.decrementRecursionDepth()
return out
case *gen.UnaryContext:
return p.VisitUnary(tree.(*gen.UnaryContext))
return p.VisitUnary(tree)
case *gen.CreateListContext:
return p.VisitCreateList(tree.(*gen.CreateListContext))
return p.VisitCreateList(tree)
case *gen.CreateMessageContext:
return p.VisitCreateMessage(tree.(*gen.CreateMessageContext))
return p.VisitCreateMessage(tree)
case *gen.CreateStructContext:
return p.VisitCreateStruct(tree.(*gen.CreateStructContext))
return p.VisitCreateStruct(tree)
case *gen.IntContext:
return p.VisitInt(tree)
case *gen.UintContext:
return p.VisitUint(tree)
case *gen.DoubleContext:
return p.VisitDouble(tree)
case *gen.StringContext:
return p.VisitString(tree)
case *gen.BytesContext:
return p.VisitBytes(tree)
case *gen.BoolFalseContext:
return p.VisitBoolFalse(tree)
case *gen.BoolTrueContext:
return p.VisitBoolTrue(tree)
case *gen.NullContext:
return p.VisitNull(tree)
}
// Report at least one error if the parser reaches an unknown parse element.
// Typically, this happens if the parser has already encountered a syntax error elsewhere.
if len(p.errors.GetErrors()) == 0 {
txt := "<<nil>>"
if tree != nil {
txt = fmt.Sprintf("<<%T>>", tree)
if t != nil {
txt = fmt.Sprintf("<<%T>>", t)
}
return p.reportError(common.NoLocation, "unknown parse element encountered: %s", txt)
}
@ -411,12 +461,12 @@ func (p *parser) Visit(tree antlr.ParseTree) interface{} {
}
// Visit a parse tree produced by CELParser#start.
func (p *parser) VisitStart(ctx *gen.StartContext) interface{} {
func (p *parser) VisitStart(ctx *gen.StartContext) any {
return p.Visit(ctx.Expr())
}
// Visit a parse tree produced by CELParser#expr.
func (p *parser) VisitExpr(ctx *gen.ExprContext) interface{} {
func (p *parser) VisitExpr(ctx *gen.ExprContext) any {
result := p.Visit(ctx.GetE()).(*exprpb.Expr)
if ctx.GetOp() == nil {
return result
@ -428,11 +478,8 @@ func (p *parser) VisitExpr(ctx *gen.ExprContext) interface{} {
}
// Visit a parse tree produced by CELParser#conditionalOr.
func (p *parser) VisitConditionalOr(ctx *gen.ConditionalOrContext) interface{} {
func (p *parser) VisitConditionalOr(ctx *gen.ConditionalOrContext) any {
result := p.Visit(ctx.GetE()).(*exprpb.Expr)
if ctx.GetOps() == nil {
return result
}
b := newBalancer(p.helper, operators.LogicalOr, result)
rest := ctx.GetE1()
for i, op := range ctx.GetOps() {
@ -447,11 +494,8 @@ func (p *parser) VisitConditionalOr(ctx *gen.ConditionalOrContext) interface{} {
}
// Visit a parse tree produced by CELParser#conditionalAnd.
func (p *parser) VisitConditionalAnd(ctx *gen.ConditionalAndContext) interface{} {
func (p *parser) VisitConditionalAnd(ctx *gen.ConditionalAndContext) any {
result := p.Visit(ctx.GetE()).(*exprpb.Expr)
if ctx.GetOps() == nil {
return result
}
b := newBalancer(p.helper, operators.LogicalAnd, result)
rest := ctx.GetE1()
for i, op := range ctx.GetOps() {
@ -466,10 +510,7 @@ func (p *parser) VisitConditionalAnd(ctx *gen.ConditionalAndContext) interface{}
}
// Visit a parse tree produced by CELParser#relation.
func (p *parser) VisitRelation(ctx *gen.RelationContext) interface{} {
if ctx.Calc() != nil {
return p.Visit(ctx.Calc())
}
func (p *parser) VisitRelation(ctx *gen.RelationContext) any {
opText := ""
if ctx.GetOp() != nil {
opText = ctx.GetOp().GetText()
@ -484,10 +525,7 @@ func (p *parser) VisitRelation(ctx *gen.RelationContext) interface{} {
}
// Visit a parse tree produced by CELParser#calc.
func (p *parser) VisitCalc(ctx *gen.CalcContext) interface{} {
if ctx.Unary() != nil {
return p.Visit(ctx.Unary())
}
func (p *parser) VisitCalc(ctx *gen.CalcContext) any {
opText := ""
if ctx.GetOp() != nil {
opText = ctx.GetOp().GetText()
@ -501,27 +539,12 @@ func (p *parser) VisitCalc(ctx *gen.CalcContext) interface{} {
return p.reportError(ctx, "operator not found")
}
func (p *parser) VisitUnary(ctx *gen.UnaryContext) interface{} {
func (p *parser) VisitUnary(ctx *gen.UnaryContext) any {
return p.helper.newLiteralString(ctx, "<<error>>")
}
// Visit a parse tree produced by CELParser#MemberExpr.
func (p *parser) VisitMemberExpr(ctx *gen.MemberExprContext) interface{} {
switch ctx.Member().(type) {
case *gen.PrimaryExprContext:
return p.VisitPrimaryExpr(ctx.Member().(*gen.PrimaryExprContext))
case *gen.SelectOrCallContext:
return p.VisitSelectOrCall(ctx.Member().(*gen.SelectOrCallContext))
case *gen.IndexContext:
return p.VisitIndex(ctx.Member().(*gen.IndexContext))
case *gen.CreateMessageContext:
return p.VisitCreateMessage(ctx.Member().(*gen.CreateMessageContext))
}
return p.reportError(ctx, "unsupported simple expression")
}
// Visit a parse tree produced by CELParser#LogicalNot.
func (p *parser) VisitLogicalNot(ctx *gen.LogicalNotContext) interface{} {
func (p *parser) VisitLogicalNot(ctx *gen.LogicalNotContext) any {
if len(ctx.GetOps())%2 == 0 {
return p.Visit(ctx.Member())
}
@ -530,7 +553,7 @@ func (p *parser) VisitLogicalNot(ctx *gen.LogicalNotContext) interface{} {
return p.globalCallOrMacro(opID, operators.LogicalNot, target)
}
func (p *parser) VisitNegate(ctx *gen.NegateContext) interface{} {
func (p *parser) VisitNegate(ctx *gen.NegateContext) any {
if len(ctx.GetOps())%2 == 0 {
return p.Visit(ctx.Member())
}
@ -539,60 +562,77 @@ func (p *parser) VisitNegate(ctx *gen.NegateContext) interface{} {
return p.globalCallOrMacro(opID, operators.Negate, target)
}
// Visit a parse tree produced by CELParser#SelectOrCall.
func (p *parser) VisitSelectOrCall(ctx *gen.SelectOrCallContext) interface{} {
// VisitSelect visits a parse tree produced by CELParser#Select.
func (p *parser) VisitSelect(ctx *gen.SelectContext) any {
operand := p.Visit(ctx.Member()).(*exprpb.Expr)
// Handle the error case where no valid identifier is specified.
if ctx.GetId() == nil || ctx.GetOp() == nil {
return p.helper.newExpr(ctx)
}
id := ctx.GetId().GetText()
if ctx.GetOpt() != nil {
if !p.enableOptionalSyntax {
return p.reportError(ctx.GetOp(), "unsupported syntax '.?'")
}
return p.helper.newGlobalCall(
ctx.GetOp(),
operators.OptSelect,
operand,
p.helper.newLiteralString(ctx.GetId(), id))
}
return p.helper.newSelect(ctx.GetOp(), operand, id)
}
// VisitMemberCall visits a parse tree produced by CELParser#MemberCall.
func (p *parser) VisitMemberCall(ctx *gen.MemberCallContext) any {
operand := p.Visit(ctx.Member()).(*exprpb.Expr)
// Handle the error case where no valid identifier is specified.
if ctx.GetId() == nil {
return p.helper.newExpr(ctx)
}
id := ctx.GetId().GetText()
if ctx.GetOpen() != nil {
opID := p.helper.id(ctx.GetOpen())
return p.receiverCallOrMacro(opID, id, operand, p.visitList(ctx.GetArgs())...)
}
return p.helper.newSelect(ctx.GetOp(), operand, id)
}
// Visit a parse tree produced by CELParser#PrimaryExpr.
func (p *parser) VisitPrimaryExpr(ctx *gen.PrimaryExprContext) interface{} {
switch ctx.Primary().(type) {
case *gen.NestedContext:
return p.VisitNested(ctx.Primary().(*gen.NestedContext))
case *gen.IdentOrGlobalCallContext:
return p.VisitIdentOrGlobalCall(ctx.Primary().(*gen.IdentOrGlobalCallContext))
case *gen.CreateListContext:
return p.VisitCreateList(ctx.Primary().(*gen.CreateListContext))
case *gen.CreateStructContext:
return p.VisitCreateStruct(ctx.Primary().(*gen.CreateStructContext))
case *gen.ConstantLiteralContext:
return p.VisitConstantLiteral(ctx.Primary().(*gen.ConstantLiteralContext))
}
return p.reportError(ctx, "invalid primary expression")
opID := p.helper.id(ctx.GetOpen())
return p.receiverCallOrMacro(opID, id, operand, p.visitExprList(ctx.GetArgs())...)
}
// Visit a parse tree produced by CELParser#Index.
func (p *parser) VisitIndex(ctx *gen.IndexContext) interface{} {
func (p *parser) VisitIndex(ctx *gen.IndexContext) any {
target := p.Visit(ctx.Member()).(*exprpb.Expr)
// Handle the error case where no valid identifier is specified.
if ctx.GetOp() == nil {
return p.helper.newExpr(ctx)
}
opID := p.helper.id(ctx.GetOp())
index := p.Visit(ctx.GetIndex()).(*exprpb.Expr)
return p.globalCallOrMacro(opID, operators.Index, target, index)
operator := operators.Index
if ctx.GetOpt() != nil {
if !p.enableOptionalSyntax {
return p.reportError(ctx.GetOp(), "unsupported syntax '[?'")
}
operator = operators.OptIndex
}
return p.globalCallOrMacro(opID, operator, target, index)
}
// Visit a parse tree produced by CELParser#CreateMessage.
func (p *parser) VisitCreateMessage(ctx *gen.CreateMessageContext) interface{} {
target := p.Visit(ctx.Member()).(*exprpb.Expr)
objID := p.helper.id(ctx.GetOp())
if messageName, found := p.extractQualifiedName(target); found {
entries := p.VisitIFieldInitializerList(ctx.GetEntries()).([]*exprpb.Expr_CreateStruct_Entry)
return p.helper.newObject(objID, messageName, entries...)
func (p *parser) VisitCreateMessage(ctx *gen.CreateMessageContext) any {
messageName := ""
for _, id := range ctx.GetIds() {
if len(messageName) != 0 {
messageName += "."
}
messageName += id.GetText()
}
return p.helper.newExpr(objID)
if ctx.GetLeadingDot() != nil {
messageName = "." + messageName
}
objID := p.helper.id(ctx.GetOp())
entries := p.VisitIFieldInitializerList(ctx.GetEntries()).([]*exprpb.Expr_CreateStruct_Entry)
return p.helper.newObject(objID, messageName, entries...)
}
// Visit a parse tree of field initializers.
func (p *parser) VisitIFieldInitializerList(ctx gen.IFieldInitializerListContext) interface{} {
func (p *parser) VisitIFieldInitializerList(ctx gen.IFieldInitializerListContext) any {
if ctx == nil || ctx.GetFields() == nil {
// This is the result of a syntax error handled elswhere, return empty.
return []*exprpb.Expr_CreateStruct_Entry{}
@ -607,15 +647,27 @@ func (p *parser) VisitIFieldInitializerList(ctx gen.IFieldInitializerListContext
return []*exprpb.Expr_CreateStruct_Entry{}
}
initID := p.helper.id(cols[i])
optField := f.(*gen.OptFieldContext)
optional := optField.GetOpt() != nil
if !p.enableOptionalSyntax && optional {
p.reportError(optField, "unsupported syntax '?'")
continue
}
// The field may be empty due to a prior error.
id := optField.IDENTIFIER()
if id == nil {
return []*exprpb.Expr_CreateStruct_Entry{}
}
fieldName := id.GetText()
value := p.Visit(vals[i]).(*exprpb.Expr)
field := p.helper.newObjectField(initID, f.GetText(), value)
field := p.helper.newObjectField(initID, fieldName, value, optional)
result[i] = field
}
return result
}
// Visit a parse tree produced by CELParser#IdentOrGlobalCall.
func (p *parser) VisitIdentOrGlobalCall(ctx *gen.IdentOrGlobalCallContext) interface{} {
func (p *parser) VisitIdentOrGlobalCall(ctx *gen.IdentOrGlobalCallContext) any {
identName := ""
if ctx.GetLeadingDot() != nil {
identName = "."
@ -632,24 +684,20 @@ func (p *parser) VisitIdentOrGlobalCall(ctx *gen.IdentOrGlobalCallContext) inter
identName += id
if ctx.GetOp() != nil {
opID := p.helper.id(ctx.GetOp())
return p.globalCallOrMacro(opID, identName, p.visitList(ctx.GetArgs())...)
return p.globalCallOrMacro(opID, identName, p.visitExprList(ctx.GetArgs())...)
}
return p.helper.newIdent(ctx.GetId(), identName)
}
// Visit a parse tree produced by CELParser#Nested.
func (p *parser) VisitNested(ctx *gen.NestedContext) interface{} {
return p.Visit(ctx.GetE())
}
// Visit a parse tree produced by CELParser#CreateList.
func (p *parser) VisitCreateList(ctx *gen.CreateListContext) interface{} {
func (p *parser) VisitCreateList(ctx *gen.CreateListContext) any {
listID := p.helper.id(ctx.GetOp())
return p.helper.newList(listID, p.visitList(ctx.GetElems())...)
elems, optionals := p.visitListInit(ctx.GetElems())
return p.helper.newList(listID, elems, optionals...)
}
// Visit a parse tree produced by CELParser#CreateStruct.
func (p *parser) VisitCreateStruct(ctx *gen.CreateStructContext) interface{} {
func (p *parser) VisitCreateStruct(ctx *gen.CreateStructContext) any {
structID := p.helper.id(ctx.GetOp())
entries := []*exprpb.Expr_CreateStruct_Entry{}
if ctx.GetEntries() != nil {
@ -658,31 +706,8 @@ func (p *parser) VisitCreateStruct(ctx *gen.CreateStructContext) interface{} {
return p.helper.newMap(structID, entries...)
}
// Visit a parse tree produced by CELParser#ConstantLiteral.
func (p *parser) VisitConstantLiteral(ctx *gen.ConstantLiteralContext) interface{} {
switch ctx.Literal().(type) {
case *gen.IntContext:
return p.VisitInt(ctx.Literal().(*gen.IntContext))
case *gen.UintContext:
return p.VisitUint(ctx.Literal().(*gen.UintContext))
case *gen.DoubleContext:
return p.VisitDouble(ctx.Literal().(*gen.DoubleContext))
case *gen.StringContext:
return p.VisitString(ctx.Literal().(*gen.StringContext))
case *gen.BytesContext:
return p.VisitBytes(ctx.Literal().(*gen.BytesContext))
case *gen.BoolFalseContext:
return p.VisitBoolFalse(ctx.Literal().(*gen.BoolFalseContext))
case *gen.BoolTrueContext:
return p.VisitBoolTrue(ctx.Literal().(*gen.BoolTrueContext))
case *gen.NullContext:
return p.VisitNull(ctx.Literal().(*gen.NullContext))
}
return p.reportError(ctx, "invalid literal")
}
// Visit a parse tree produced by CELParser#mapInitializerList.
func (p *parser) VisitMapInitializerList(ctx *gen.MapInitializerListContext) interface{} {
func (p *parser) VisitMapInitializerList(ctx *gen.MapInitializerListContext) any {
if ctx == nil || ctx.GetKeys() == nil {
// This is the result of a syntax error handled elswhere, return empty.
return []*exprpb.Expr_CreateStruct_Entry{}
@ -697,16 +722,22 @@ func (p *parser) VisitMapInitializerList(ctx *gen.MapInitializerListContext) int
// This is the result of a syntax error detected elsewhere.
return []*exprpb.Expr_CreateStruct_Entry{}
}
key := p.Visit(keys[i]).(*exprpb.Expr)
optKey := keys[i]
optional := optKey.GetOpt() != nil
if !p.enableOptionalSyntax && optional {
p.reportError(optKey, "unsupported syntax '?'")
continue
}
key := p.Visit(optKey.GetE()).(*exprpb.Expr)
value := p.Visit(vals[i]).(*exprpb.Expr)
entry := p.helper.newMapEntry(colID, key, value)
entry := p.helper.newMapEntry(colID, key, value, optional)
result[i] = entry
}
return result
}
// Visit a parse tree produced by CELParser#Int.
func (p *parser) VisitInt(ctx *gen.IntContext) interface{} {
func (p *parser) VisitInt(ctx *gen.IntContext) any {
text := ctx.GetTok().GetText()
base := 10
if strings.HasPrefix(text, "0x") {
@ -724,7 +755,7 @@ func (p *parser) VisitInt(ctx *gen.IntContext) interface{} {
}
// Visit a parse tree produced by CELParser#Uint.
func (p *parser) VisitUint(ctx *gen.UintContext) interface{} {
func (p *parser) VisitUint(ctx *gen.UintContext) any {
text := ctx.GetTok().GetText()
// trim the 'u' designator included in the uint literal.
text = text[:len(text)-1]
@ -741,7 +772,7 @@ func (p *parser) VisitUint(ctx *gen.UintContext) interface{} {
}
// Visit a parse tree produced by CELParser#Double.
func (p *parser) VisitDouble(ctx *gen.DoubleContext) interface{} {
func (p *parser) VisitDouble(ctx *gen.DoubleContext) any {
txt := ctx.GetTok().GetText()
if ctx.GetSign() != nil {
txt = ctx.GetSign().GetText() + txt
@ -755,42 +786,66 @@ func (p *parser) VisitDouble(ctx *gen.DoubleContext) interface{} {
}
// Visit a parse tree produced by CELParser#String.
func (p *parser) VisitString(ctx *gen.StringContext) interface{} {
func (p *parser) VisitString(ctx *gen.StringContext) any {
s := p.unquote(ctx, ctx.GetText(), false)
return p.helper.newLiteralString(ctx, s)
}
// Visit a parse tree produced by CELParser#Bytes.
func (p *parser) VisitBytes(ctx *gen.BytesContext) interface{} {
func (p *parser) VisitBytes(ctx *gen.BytesContext) any {
b := []byte(p.unquote(ctx, ctx.GetTok().GetText()[1:], true))
return p.helper.newLiteralBytes(ctx, b)
}
// Visit a parse tree produced by CELParser#BoolTrue.
func (p *parser) VisitBoolTrue(ctx *gen.BoolTrueContext) interface{} {
func (p *parser) VisitBoolTrue(ctx *gen.BoolTrueContext) any {
return p.helper.newLiteralBool(ctx, true)
}
// Visit a parse tree produced by CELParser#BoolFalse.
func (p *parser) VisitBoolFalse(ctx *gen.BoolFalseContext) interface{} {
func (p *parser) VisitBoolFalse(ctx *gen.BoolFalseContext) any {
return p.helper.newLiteralBool(ctx, false)
}
// Visit a parse tree produced by CELParser#Null.
func (p *parser) VisitNull(ctx *gen.NullContext) interface{} {
func (p *parser) VisitNull(ctx *gen.NullContext) any {
return p.helper.newLiteral(ctx,
&exprpb.Constant{
ConstantKind: &exprpb.Constant_NullValue{
NullValue: structpb.NullValue_NULL_VALUE}})
}
func (p *parser) visitList(ctx gen.IExprListContext) []*exprpb.Expr {
func (p *parser) visitExprList(ctx gen.IExprListContext) []*exprpb.Expr {
if ctx == nil {
return []*exprpb.Expr{}
}
return p.visitSlice(ctx.GetE())
}
func (p *parser) visitListInit(ctx gen.IListInitContext) ([]*exprpb.Expr, []int32) {
if ctx == nil {
return []*exprpb.Expr{}, []int32{}
}
elements := ctx.GetElems()
result := make([]*exprpb.Expr, len(elements))
optionals := []int32{}
for i, e := range elements {
ex := p.Visit(e.GetE()).(*exprpb.Expr)
if ex == nil {
return []*exprpb.Expr{}, []int32{}
}
result[i] = ex
if e.GetOpt() != nil {
if !p.enableOptionalSyntax {
p.reportError(e.GetOpt(), "unsupported syntax '?'")
continue
}
optionals = append(optionals, int32(i))
}
}
return result, optionals
}
func (p *parser) visitSlice(expressions []gen.IExprContext) []*exprpb.Expr {
if expressions == nil {
return []*exprpb.Expr{}
@ -803,26 +858,7 @@ func (p *parser) visitSlice(expressions []gen.IExprContext) []*exprpb.Expr {
return result
}
func (p *parser) extractQualifiedName(e *exprpb.Expr) (string, bool) {
if e == nil {
return "", false
}
switch e.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
return e.GetIdentExpr().GetName(), true
case *exprpb.Expr_SelectExpr:
s := e.GetSelectExpr()
if prefix, found := p.extractQualifiedName(s.GetOperand()); found {
return prefix + "." + s.GetField(), true
}
}
// TODO: Add a method to Source to get location from character offset.
location := p.helper.getLocation(e.GetId())
p.reportError(location, "expected a qualified name")
return "", false
}
func (p *parser) unquote(ctx interface{}, value string, isBytes bool) string {
func (p *parser) unquote(ctx any, value string, isBytes bool) string {
text, err := unescape(value, isBytes)
if err != nil {
p.reportError(ctx, "%s", err.Error())
@ -831,7 +867,7 @@ func (p *parser) unquote(ctx interface{}, value string, isBytes bool) string {
return text
}
func (p *parser) reportError(ctx interface{}, format string, args ...interface{}) *exprpb.Expr {
func (p *parser) reportError(ctx any, format string, args ...any) *exprpb.Expr {
var location common.Location
switch ctx.(type) {
case common.Location:
@ -847,10 +883,24 @@ func (p *parser) reportError(ctx interface{}, format string, args ...interface{}
}
// ANTLR Parse listener implementations
func (p *parser) SyntaxError(recognizer antlr.Recognizer, offendingSymbol interface{}, line, column int, msg string, e antlr.RecognitionException) {
// TODO: Snippet
func (p *parser) SyntaxError(recognizer antlr.Recognizer, offendingSymbol any, line, column int, msg string, e antlr.RecognitionException) {
l := p.helper.source.NewLocation(line, column)
p.errors.syntaxError(l, msg)
// Hack to keep existing error messages consistent with previous versions of CEL when a reserved word
// is used as an identifier. This behavior needs to be overhauled to provide consistent, normalized error
// messages out of ANTLR to prevent future breaking changes related to error message content.
if strings.Contains(msg, "no viable alternative") {
msg = reservedIdentifier.ReplaceAllString(msg, mismatchedReservedIdentifier)
}
// Ensure that no more than 100 syntax errors are reported as this will halt attempts to recover from a
// seriously broken expression.
if p.errorReports < p.errorReportingLimit {
p.errorReports++
p.errors.syntaxError(l, msg)
} else {
tme := &tooManyErrors{errorReportingLimit: p.errorReportingLimit}
p.errors.syntaxError(l, tme.Error())
panic(tme)
}
}
func (p *parser) ReportAmbiguity(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex int, exact bool, ambigAlts *antlr.BitSet, configs antlr.ATNConfigSet) {
@ -892,14 +942,95 @@ func (p *parser) expandMacro(exprID int64, function string, target *exprpb.Expr,
eh.parserHelper = p.helper
eh.id = exprID
expr, err := macro.Expander()(eh, target, args)
// An error indicates that the macro was matched, but the arguments were not well-formed.
if err != nil {
if err.Location != nil {
return p.reportError(err.Location, err.Message), true
}
return p.reportError(p.helper.getLocation(exprID), err.Message), true
}
// A nil value from the macro indicates that the macro implementation decided that
// an expansion should not be performed.
if expr == nil {
return nil, false
}
if p.populateMacroCalls {
p.helper.addMacroCall(expr.GetId(), function, target, args...)
}
return expr, true
}
func (p *parser) checkAndIncrementRecursionDepth() {
p.recursionDepth++
if p.recursionDepth > p.maxRecursionDepth {
panic(&recursionError{message: "max recursion depth exceeded"})
}
}
func (p *parser) decrementRecursionDepth() {
p.recursionDepth--
}
// unnest traverses down the left-hand side of the parse graph until it encounters the first compound
// parse node or the first leaf in the parse graph.
func unnest(tree antlr.ParseTree) antlr.ParseTree {
for tree != nil {
switch t := tree.(type) {
case *gen.ExprContext:
// conditionalOr op='?' conditionalOr : expr
if t.GetOp() != nil {
return t
}
// conditionalOr
tree = t.GetE()
case *gen.ConditionalOrContext:
// conditionalAnd (ops=|| conditionalAnd)*
if t.GetOps() != nil && len(t.GetOps()) > 0 {
return t
}
// conditionalAnd
tree = t.GetE()
case *gen.ConditionalAndContext:
// relation (ops=&& relation)*
if t.GetOps() != nil && len(t.GetOps()) > 0 {
return t
}
// relation
tree = t.GetE()
case *gen.RelationContext:
// relation op relation
if t.GetOp() != nil {
return t
}
// calc
tree = t.Calc()
case *gen.CalcContext:
// calc op calc
if t.GetOp() != nil {
return t
}
// unary
tree = t.Unary()
case *gen.MemberExprContext:
// member expands to one of: primary, select, index, or create message
tree = t.Member()
case *gen.PrimaryExprContext:
// primary expands to one of identifier, nested, create list, create struct, literal
tree = t.Primary()
case *gen.NestedContext:
// contains a nested 'expr'
tree = t.GetE()
case *gen.ConstantLiteralContext:
// expands to a primitive literal
tree = t.Literal()
default:
return t
}
}
return tree
}
var (
reservedIdentifier = regexp.MustCompile("no viable alternative at input '.(true|false|null)'")
mismatchedReservedIdentifier = "mismatched input '$1' expecting IDENTIFIER"
)

View File

@ -106,9 +106,15 @@ func (un *unparser) visitCall(expr *exprpb.Expr) error {
// ternary operator
case operators.Conditional:
return un.visitCallConditional(expr)
// optional select operator
case operators.OptSelect:
return un.visitOptSelect(expr)
// index operator
case operators.Index:
return un.visitCallIndex(expr)
// optional index operator
case operators.OptIndex:
return un.visitCallOptIndex(expr)
// unary operators
case operators.LogicalNot, operators.Negate:
return un.visitCallUnary(expr)
@ -218,6 +224,14 @@ func (un *unparser) visitCallFunc(expr *exprpb.Expr) error {
}
func (un *unparser) visitCallIndex(expr *exprpb.Expr) error {
return un.visitCallIndexInternal(expr, "[")
}
func (un *unparser) visitCallOptIndex(expr *exprpb.Expr) error {
return un.visitCallIndexInternal(expr, "[?")
}
func (un *unparser) visitCallIndexInternal(expr *exprpb.Expr, op string) error {
c := expr.GetCallExpr()
args := c.GetArgs()
nested := isBinaryOrTernaryOperator(args[0])
@ -225,7 +239,7 @@ func (un *unparser) visitCallIndex(expr *exprpb.Expr) error {
if err != nil {
return err
}
un.str.WriteString("[")
un.str.WriteString(op)
err = un.visit(args[1])
if err != nil {
return err
@ -262,6 +276,9 @@ func (un *unparser) visitConst(expr *exprpb.Expr) error {
// represent the float using the minimum required digits
d := strconv.FormatFloat(c.GetDoubleValue(), 'g', -1, 64)
un.str.WriteString(d)
if !strings.Contains(d, ".") {
un.str.WriteString(".0")
}
case *exprpb.Constant_Int64Value:
i := strconv.FormatInt(c.GetInt64Value(), 10)
un.str.WriteString(i)
@ -289,8 +306,15 @@ func (un *unparser) visitIdent(expr *exprpb.Expr) error {
func (un *unparser) visitList(expr *exprpb.Expr) error {
l := expr.GetListExpr()
elems := l.GetElements()
optIndices := make(map[int]bool, len(elems))
for _, idx := range l.GetOptionalIndices() {
optIndices[int(idx)] = true
}
un.str.WriteString("[")
for i, elem := range elems {
if optIndices[i] {
un.str.WriteString("?")
}
err := un.visit(elem)
if err != nil {
return err
@ -303,20 +327,32 @@ func (un *unparser) visitList(expr *exprpb.Expr) error {
return nil
}
func (un *unparser) visitOptSelect(expr *exprpb.Expr) error {
c := expr.GetCallExpr()
args := c.GetArgs()
operand := args[0]
field := args[1].GetConstExpr().GetStringValue()
return un.visitSelectInternal(operand, false, ".?", field)
}
func (un *unparser) visitSelect(expr *exprpb.Expr) error {
sel := expr.GetSelectExpr()
return un.visitSelectInternal(sel.GetOperand(), sel.GetTestOnly(), ".", sel.GetField())
}
func (un *unparser) visitSelectInternal(operand *exprpb.Expr, testOnly bool, op string, field string) error {
// handle the case when the select expression was generated by the has() macro.
if sel.GetTestOnly() {
if testOnly {
un.str.WriteString("has(")
}
nested := !sel.GetTestOnly() && isBinaryOrTernaryOperator(sel.GetOperand())
err := un.visitMaybeNested(sel.GetOperand(), nested)
nested := !testOnly && isBinaryOrTernaryOperator(operand)
err := un.visitMaybeNested(operand, nested)
if err != nil {
return err
}
un.str.WriteString(".")
un.str.WriteString(sel.GetField())
if sel.GetTestOnly() {
un.str.WriteString(op)
un.str.WriteString(field)
if testOnly {
un.str.WriteString(")")
}
return nil
@ -339,6 +375,9 @@ func (un *unparser) visitStructMsg(expr *exprpb.Expr) error {
un.str.WriteString("{")
for i, entry := range entries {
f := entry.GetFieldKey()
if entry.GetOptionalEntry() {
un.str.WriteString("?")
}
un.str.WriteString(f)
un.str.WriteString(": ")
v := entry.GetValue()
@ -360,6 +399,9 @@ func (un *unparser) visitStructMap(expr *exprpb.Expr) error {
un.str.WriteString("{")
for i, entry := range entries {
k := entry.GetMapKey()
if entry.GetOptionalEntry() {
un.str.WriteString("?")
}
err := un.visit(k)
if err != nil {
return err
@ -492,11 +534,10 @@ func (un *unparser) writeOperatorWithWrapping(fun string, unmangled string) bool
un.str.WriteString(" ")
}
return true
} else {
un.str.WriteString(" ")
un.str.WriteString(unmangled)
un.str.WriteString(" ")
}
un.str.WriteString(" ")
un.str.WriteString(unmangled)
un.str.WriteString(" ")
return false
}

View File

@ -24,7 +24,7 @@ import (
"github.com/golang/protobuf/ptypes/any"
yaml "gopkg.in/yaml.v3"
extensions "github.com/google/gnostic/extensions"
extensions "github.com/google/gnostic-models/extensions"
)
// ExtensionHandler describes a binary that is called by the compiler to handle specification extensions.

View File

@ -22,7 +22,7 @@ import (
"gopkg.in/yaml.v3"
"github.com/google/gnostic/jsonschema"
"github.com/google/gnostic-models/jsonschema"
)
// compiler helper functions, usually called from generated code

Some files were not shown because too many files have changed in this diff Show More