rebase: update kubernetes to latest

updating the kubernetes release to the
latest in main go.mod

Signed-off-by: Madhu Rajanna <madhupr007@gmail.com>
This commit is contained in:
Madhu Rajanna
2024-08-19 10:01:33 +02:00
committed by mergify[bot]
parent 63c4c05b35
commit 5a66991bb3
2173 changed files with 98906 additions and 61334 deletions

View File

@ -10,9 +10,12 @@ go_library(
"cel.go",
"decls.go",
"env.go",
"folding.go",
"io.go",
"inlining.go",
"library.go",
"macro.go",
"optimizer.go",
"options.go",
"program.go",
"validator.go",
@ -56,7 +59,11 @@ go_test(
"cel_test.go",
"decls_test.go",
"env_test.go",
"folding_test.go",
"io_test.go",
"inlining_test.go",
"optimizer_test.go",
"validator_test.go",
],
data = [
"//cel/testdata:gen_test_fds",
@ -70,6 +77,7 @@ go_test(
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"//ext:go_default_library",
"//test:go_default_library",
"//test/proto2pb:go_default_library",
"//test/proto3pb:go_default_library",

View File

@ -353,43 +353,3 @@ func ExprDeclToDeclaration(d *exprpb.Decl) (EnvOption, error) {
return nil, fmt.Errorf("unsupported decl: %v", d)
}
}
func typeValueToKind(tv ref.Type) (Kind, error) {
switch tv {
case types.BoolType:
return BoolKind, nil
case types.DoubleType:
return DoubleKind, nil
case types.IntType:
return IntKind, nil
case types.UintType:
return UintKind, nil
case types.ListType:
return ListKind, nil
case types.MapType:
return MapKind, nil
case types.StringType:
return StringKind, nil
case types.BytesType:
return BytesKind, nil
case types.DurationType:
return DurationKind, nil
case types.TimestampType:
return TimestampKind, nil
case types.NullType:
return NullTypeKind, nil
case types.TypeType:
return TypeKind, nil
default:
switch tv.TypeName() {
case "dyn":
return DynKind, nil
case "google.protobuf.Any":
return AnyKind, nil
case "optional":
return OpaqueKind, nil
default:
return 0, fmt.Errorf("no known conversion for type of %s", tv.TypeName())
}
}
}

View File

@ -38,26 +38,42 @@ type Source = common.Source
// Ast representing the checked or unchecked expression, its source, and related metadata such as
// source position information.
type Ast struct {
expr *exprpb.Expr
info *exprpb.SourceInfo
source Source
refMap map[int64]*celast.ReferenceInfo
typeMap map[int64]*types.Type
source Source
impl *celast.AST
}
// NativeRep converts the AST to a Go-native representation.
func (ast *Ast) NativeRep() *celast.AST {
return ast.impl
}
// Expr returns the proto serializable instance of the parsed/checked expression.
//
// Deprecated: prefer cel.AstToCheckedExpr() or cel.AstToParsedExpr() and call GetExpr()
// the result instead.
func (ast *Ast) Expr() *exprpb.Expr {
return ast.expr
if ast == nil {
return nil
}
pbExpr, _ := celast.ExprToProto(ast.impl.Expr())
return pbExpr
}
// IsChecked returns whether the Ast value has been successfully type-checked.
func (ast *Ast) IsChecked() bool {
return ast.typeMap != nil && len(ast.typeMap) > 0
if ast == nil {
return false
}
return ast.impl.IsChecked()
}
// SourceInfo returns character offset and newline position information about expression elements.
func (ast *Ast) SourceInfo() *exprpb.SourceInfo {
return ast.info
if ast == nil {
return nil
}
pbInfo, _ := celast.SourceInfoToProto(ast.impl.SourceInfo())
return pbInfo
}
// ResultType returns the output type of the expression if the Ast has been type-checked, else
@ -65,9 +81,6 @@ func (ast *Ast) SourceInfo() *exprpb.SourceInfo {
//
// Deprecated: use OutputType
func (ast *Ast) ResultType() *exprpb.Type {
if !ast.IsChecked() {
return chkdecls.Dyn
}
out := ast.OutputType()
t, err := TypeToExprType(out)
if err != nil {
@ -79,16 +92,18 @@ func (ast *Ast) ResultType() *exprpb.Type {
// OutputType returns the output type of the expression if the Ast has been type-checked, else
// returns cel.DynType as the parse step cannot infer types.
func (ast *Ast) OutputType() *Type {
t, found := ast.typeMap[ast.expr.GetId()]
if !found {
return DynType
if ast == nil {
return types.ErrorType
}
return t
return ast.impl.GetType(ast.impl.Expr().ID())
}
// Source returns a view of the input used to create the Ast. This source may be complete or
// constructed from the SourceInfo.
func (ast *Ast) Source() Source {
if ast == nil {
return nil
}
return ast.source
}
@ -198,29 +213,28 @@ func NewCustomEnv(opts ...EnvOption) (*Env, error) {
// It is possible to have both non-nil Ast and Issues values returned from this call: however,
// the mere presence of an Ast does not imply that it is valid for use.
func (e *Env) Check(ast *Ast) (*Ast, *Issues) {
// Note, errors aren't currently possible on the Ast to ParsedExpr conversion.
pe, _ := AstToParsedExpr(ast)
// Construct the internal checker env, erroring if there is an issue adding the declarations.
chk, err := e.initChecker()
if err != nil {
errs := common.NewErrors(ast.Source())
errs.ReportError(common.NoLocation, err.Error())
return nil, NewIssuesWithSourceInfo(errs, ast.SourceInfo())
return nil, NewIssuesWithSourceInfo(errs, ast.impl.SourceInfo())
}
res, errs := checker.Check(pe, ast.Source(), chk)
checked, errs := checker.Check(ast.impl, ast.Source(), chk)
if len(errs.GetErrors()) > 0 {
return nil, NewIssuesWithSourceInfo(errs, ast.SourceInfo())
return nil, NewIssuesWithSourceInfo(errs, ast.impl.SourceInfo())
}
// Manually create the Ast to ensure that the Ast source information (which may be more
// detailed than the information provided by Check), is returned to the caller.
ast = &Ast{
source: ast.Source(),
expr: res.Expr,
info: res.SourceInfo,
refMap: res.ReferenceMap,
typeMap: res.TypeMap}
source: ast.Source(),
impl: checked}
// Avoid creating a validator config if it's not needed.
if len(e.validators) == 0 {
return ast, nil
}
// Generate a validator configuration from the set of configured validators.
vConfig := newValidatorConfig()
@ -230,9 +244,9 @@ func (e *Env) Check(ast *Ast) (*Ast, *Issues) {
}
}
// Apply additional validators on the type-checked result.
iss := NewIssuesWithSourceInfo(errs, ast.SourceInfo())
iss := NewIssuesWithSourceInfo(errs, ast.impl.SourceInfo())
for _, v := range e.validators {
v.Validate(e, vConfig, res, iss)
v.Validate(e, vConfig, checked, iss)
}
if iss.Err() != nil {
return nil, iss
@ -429,16 +443,11 @@ func (e *Env) Parse(txt string) (*Ast, *Issues) {
// It is possible to have both non-nil Ast and Issues values returned from this call; however,
// the mere presence of an Ast does not imply that it is valid for use.
func (e *Env) ParseSource(src Source) (*Ast, *Issues) {
res, errs := e.prsr.Parse(src)
parsed, errs := e.prsr.Parse(src)
if len(errs.GetErrors()) > 0 {
return nil, &Issues{errs: errs}
}
// Manually create the Ast to ensure that the text source information is propagated on
// subsequent calls to Check.
return &Ast{
source: src,
expr: res.GetExpr(),
info: res.GetSourceInfo()}, nil
return &Ast{source: src, impl: parsed}, nil
}
// Program generates an evaluable instance of the Ast within the environment (Env).
@ -534,8 +543,9 @@ func (e *Env) PartialVars(vars any) (interpreter.PartialActivation, error) {
// TODO: Consider adding an option to generate a Program.Residual to avoid round-tripping to an
// Ast format and then Program again.
func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) {
pruned := interpreter.PruneAst(a.Expr(), a.SourceInfo().GetMacroCalls(), details.State())
expr, err := AstToString(ParsedExprToAst(pruned))
pruned := interpreter.PruneAst(a.impl.Expr(), a.impl.SourceInfo().MacroCalls(), details.State())
newAST := &Ast{source: a.Source(), impl: pruned}
expr, err := AstToString(newAST)
if err != nil {
return nil, err
}
@ -556,16 +566,10 @@ func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) {
// EstimateCost estimates the cost of a type checked CEL expression using the length estimates of input data and
// extension functions provided by estimator.
func (e *Env) EstimateCost(ast *Ast, estimator checker.CostEstimator, opts ...checker.CostOption) (checker.CostEstimate, error) {
checked := &celast.CheckedAST{
Expr: ast.Expr(),
SourceInfo: ast.SourceInfo(),
TypeMap: ast.typeMap,
ReferenceMap: ast.refMap,
}
extendedOpts := make([]checker.CostOption, 0, len(e.costOptions))
extendedOpts = append(extendedOpts, opts...)
extendedOpts = append(extendedOpts, e.costOptions...)
return checker.Cost(checked, estimator, extendedOpts...)
return checker.Cost(ast.impl, estimator, extendedOpts...)
}
// configure applies a series of EnvOptions to the current environment.
@ -707,7 +711,7 @@ type Error = common.Error
// Note: in the future, non-fatal warnings and notices may be inspectable via the Issues struct.
type Issues struct {
errs *common.Errors
info *exprpb.SourceInfo
info *celast.SourceInfo
}
// NewIssues returns an Issues struct from a common.Errors object.
@ -718,7 +722,7 @@ func NewIssues(errs *common.Errors) *Issues {
// NewIssuesWithSourceInfo returns an Issues struct from a common.Errors object with SourceInfo metatata
// which can be used with the `ReportErrorAtID` method for additional error reports within the context
// information that's inferred from an expression id.
func NewIssuesWithSourceInfo(errs *common.Errors, info *exprpb.SourceInfo) *Issues {
func NewIssuesWithSourceInfo(errs *common.Errors, info *celast.SourceInfo) *Issues {
return &Issues{
errs: errs,
info: info,
@ -768,30 +772,7 @@ func (i *Issues) String() string {
// The source metadata for the expression at `id`, if present, is attached to the error report.
// To ensure that source metadata is attached to error reports, use NewIssuesWithSourceInfo.
func (i *Issues) ReportErrorAtID(id int64, message string, args ...any) {
i.errs.ReportErrorAtID(id, locationByID(id, i.info), message, args...)
}
// locationByID returns a common.Location given an expression id.
//
// TODO: move this functionality into the native SourceInfo and an overhaul of the common.Source
// as this implementation relies on the abstractions present in the protobuf SourceInfo object,
// and is replicated in the checker.
func locationByID(id int64, sourceInfo *exprpb.SourceInfo) common.Location {
positions := sourceInfo.GetPositions()
var line = 1
if offset, found := positions[id]; found {
col := int(offset)
for _, lineOffset := range sourceInfo.GetLineOffsets() {
if lineOffset < offset {
line++
col = int(offset - lineOffset)
} else {
break
}
}
return common.NewLocation(line, col)
}
return common.NoLocation
i.errs.ReportErrorAtID(id, i.info.GetStartLocation(id), message, args...)
}
// getStdEnv lazy initializes the CEL standard environment.
@ -822,6 +803,13 @@ func (p *interopCELTypeProvider) FindStructType(typeName string) (*types.Type, b
return nil, false
}
// FindStructFieldNames returns an empty set of field for the interop provider.
//
// To inspect the field names, migrate to a `types.Provider` implementation.
func (p *interopCELTypeProvider) FindStructFieldNames(typeName string) ([]string, bool) {
return []string{}, false
}
// FindStructFieldType returns a types.FieldType instance for the given fully-qualified typeName and field
// name, if one exists.
//

559
vendor/github.com/google/cel-go/cel/folding.go generated vendored Normal file
View File

@ -0,0 +1,559 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cel
import (
"fmt"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
)
// ConstantFoldingOption defines a functional option for configuring constant folding.
type ConstantFoldingOption func(opt *constantFoldingOptimizer) (*constantFoldingOptimizer, error)
// MaxConstantFoldIterations limits the number of times literals may be folding during optimization.
//
// Defaults to 100 if not set.
func MaxConstantFoldIterations(limit int) ConstantFoldingOption {
return func(opt *constantFoldingOptimizer) (*constantFoldingOptimizer, error) {
opt.maxFoldIterations = limit
return opt, nil
}
}
// NewConstantFoldingOptimizer creates an optimizer which inlines constant scalar an aggregate
// literal values within function calls and select statements with their evaluated result.
func NewConstantFoldingOptimizer(opts ...ConstantFoldingOption) (ASTOptimizer, error) {
folder := &constantFoldingOptimizer{
maxFoldIterations: defaultMaxConstantFoldIterations,
}
var err error
for _, o := range opts {
folder, err = o(folder)
if err != nil {
return nil, err
}
}
return folder, nil
}
type constantFoldingOptimizer struct {
maxFoldIterations int
}
// Optimize queries the expression graph for scalar and aggregate literal expressions within call and
// select statements and then evaluates them and replaces the call site with the literal result.
//
// Note: only values which can be represented as literals in CEL syntax are supported.
func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.AST {
root := ast.NavigateAST(a)
// Walk the list of foldable expression and continue to fold until there are no more folds left.
// All of the fold candidates returned by the constantExprMatcher should succeed unless there's
// a logic bug with the selection of expressions.
foldableExprs := ast.MatchDescendants(root, constantExprMatcher)
foldCount := 0
for len(foldableExprs) != 0 && foldCount < opt.maxFoldIterations {
for _, fold := range foldableExprs {
// If the expression could be folded because it's a non-strict call, and the
// branches are pruned, continue to the next fold.
if fold.Kind() == ast.CallKind && maybePruneBranches(ctx, fold) {
continue
}
// Otherwise, assume all context is needed to evaluate the expression.
err := tryFold(ctx, a, fold)
if err != nil {
ctx.ReportErrorAtID(fold.ID(), "constant-folding evaluation failed: %v", err.Error())
return a
}
}
foldCount++
foldableExprs = ast.MatchDescendants(root, constantExprMatcher)
}
// Once all of the constants have been folded, try to run through the remaining comprehensions
// one last time. In this case, there's no guarantee they'll run, so we only update the
// target comprehension node with the literal value if the evaluation succeeds.
for _, compre := range ast.MatchDescendants(root, ast.KindMatcher(ast.ComprehensionKind)) {
tryFold(ctx, a, compre)
}
// If the output is a list, map, or struct which contains optional entries, then prune it
// to make sure that the optionals, if resolved, do not surface in the output literal.
pruneOptionalElements(ctx, root)
// Ensure that all intermediate values in the folded expression can be represented as valid
// CEL literals within the AST structure. Use `PostOrderVisit` rather than `MatchDescendents`
// to avoid extra allocations during this final pass through the AST.
ast.PostOrderVisit(root, ast.NewExprVisitor(func(e ast.Expr) {
if e.Kind() != ast.LiteralKind {
return
}
val := e.AsLiteral()
adapted, err := adaptLiteral(ctx, val)
if err != nil {
ctx.ReportErrorAtID(root.ID(), "constant-folding evaluation failed: %v", err.Error())
return
}
ctx.UpdateExpr(e, adapted)
}))
return a
}
// tryFold attempts to evaluate a sub-expression to a literal.
//
// If the evaluation succeeds, the input expr value will be modified to become a literal, otherwise
// the method will return an error.
func tryFold(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) error {
// Assume all context is needed to evaluate the expression.
subAST := &Ast{
impl: ast.NewCheckedAST(ast.NewAST(expr, a.SourceInfo()), a.TypeMap(), a.ReferenceMap()),
}
prg, err := ctx.Program(subAST)
if err != nil {
return err
}
out, _, err := prg.Eval(NoVars())
if err != nil {
return err
}
// Update the fold expression to be a literal.
ctx.UpdateExpr(expr, ctx.NewLiteral(out))
return nil
}
// maybePruneBranches inspects the non-strict call expression to determine whether
// a branch can be removed. Evaluation will naturally prune logical and / or calls,
// but conditional will not be pruned cleanly, so this is one small area where the
// constant folding step reimplements a portion of the evaluator.
func maybePruneBranches(ctx *OptimizerContext, expr ast.NavigableExpr) bool {
call := expr.AsCall()
args := call.Args()
switch call.FunctionName() {
case operators.LogicalAnd, operators.LogicalOr:
return maybeShortcircuitLogic(ctx, call.FunctionName(), args, expr)
case operators.Conditional:
cond := args[0]
truthy := args[1]
falsy := args[2]
if cond.Kind() != ast.LiteralKind {
return false
}
if cond.AsLiteral() == types.True {
ctx.UpdateExpr(expr, truthy)
} else {
ctx.UpdateExpr(expr, falsy)
}
return true
case operators.In:
haystack := args[1]
if haystack.Kind() == ast.ListKind && haystack.AsList().Size() == 0 {
ctx.UpdateExpr(expr, ctx.NewLiteral(types.False))
return true
}
needle := args[0]
if needle.Kind() == ast.LiteralKind && haystack.Kind() == ast.ListKind {
needleValue := needle.AsLiteral()
list := haystack.AsList()
for _, e := range list.Elements() {
if e.Kind() == ast.LiteralKind && e.AsLiteral().Equal(needleValue) == types.True {
ctx.UpdateExpr(expr, ctx.NewLiteral(types.True))
return true
}
}
}
}
return false
}
func maybeShortcircuitLogic(ctx *OptimizerContext, function string, args []ast.Expr, expr ast.NavigableExpr) bool {
shortcircuit := types.False
skip := types.True
if function == operators.LogicalOr {
shortcircuit = types.True
skip = types.False
}
newArgs := []ast.Expr{}
for _, arg := range args {
if arg.Kind() != ast.LiteralKind {
newArgs = append(newArgs, arg)
continue
}
if arg.AsLiteral() == skip {
continue
}
if arg.AsLiteral() == shortcircuit {
ctx.UpdateExpr(expr, arg)
return true
}
}
if len(newArgs) == 0 {
newArgs = append(newArgs, args[0])
ctx.UpdateExpr(expr, newArgs[0])
return true
}
if len(newArgs) == 1 {
ctx.UpdateExpr(expr, newArgs[0])
return true
}
ctx.UpdateExpr(expr, ctx.NewCall(function, newArgs...))
return true
}
// pruneOptionalElements works from the bottom up to resolve optional elements within
// aggregate literals.
//
// Note, many aggregate literals will be resolved as arguments to functions or select
// statements, so this method exists to handle the case where the literal could not be
// fully resolved or exists outside of a call, select, or comprehension context.
func pruneOptionalElements(ctx *OptimizerContext, root ast.NavigableExpr) {
aggregateLiterals := ast.MatchDescendants(root, aggregateLiteralMatcher)
for _, lit := range aggregateLiterals {
switch lit.Kind() {
case ast.ListKind:
pruneOptionalListElements(ctx, lit)
case ast.MapKind:
pruneOptionalMapEntries(ctx, lit)
case ast.StructKind:
pruneOptionalStructFields(ctx, lit)
}
}
}
func pruneOptionalListElements(ctx *OptimizerContext, e ast.Expr) {
l := e.AsList()
elems := l.Elements()
optIndices := l.OptionalIndices()
if len(optIndices) == 0 {
return
}
updatedElems := []ast.Expr{}
updatedIndices := []int32{}
newOptIndex := -1
for _, e := range elems {
newOptIndex++
if !l.IsOptional(int32(newOptIndex)) {
updatedElems = append(updatedElems, e)
continue
}
if e.Kind() != ast.LiteralKind {
updatedElems = append(updatedElems, e)
updatedIndices = append(updatedIndices, int32(newOptIndex))
continue
}
optElemVal, ok := e.AsLiteral().(*types.Optional)
if !ok {
updatedElems = append(updatedElems, e)
updatedIndices = append(updatedIndices, int32(newOptIndex))
continue
}
if !optElemVal.HasValue() {
newOptIndex-- // Skipping causes the list to get smaller.
continue
}
ctx.UpdateExpr(e, ctx.NewLiteral(optElemVal.GetValue()))
updatedElems = append(updatedElems, e)
}
ctx.UpdateExpr(e, ctx.NewList(updatedElems, updatedIndices))
}
func pruneOptionalMapEntries(ctx *OptimizerContext, e ast.Expr) {
m := e.AsMap()
entries := m.Entries()
updatedEntries := []ast.EntryExpr{}
modified := false
for _, e := range entries {
entry := e.AsMapEntry()
key := entry.Key()
val := entry.Value()
// If the entry is not optional, or the value-side of the optional hasn't
// been resolved to a literal, then preserve the entry as-is.
if !entry.IsOptional() || val.Kind() != ast.LiteralKind {
updatedEntries = append(updatedEntries, e)
continue
}
optElemVal, ok := val.AsLiteral().(*types.Optional)
if !ok {
updatedEntries = append(updatedEntries, e)
continue
}
// When the key is not a literal, but the value is, then it needs to be
// restored to an optional value.
if key.Kind() != ast.LiteralKind {
undoOptVal, err := adaptLiteral(ctx, optElemVal)
if err != nil {
ctx.ReportErrorAtID(val.ID(), "invalid map value literal %v: %v", optElemVal, err)
}
ctx.UpdateExpr(val, undoOptVal)
updatedEntries = append(updatedEntries, e)
continue
}
modified = true
if !optElemVal.HasValue() {
continue
}
ctx.UpdateExpr(val, ctx.NewLiteral(optElemVal.GetValue()))
updatedEntry := ctx.NewMapEntry(key, val, false)
updatedEntries = append(updatedEntries, updatedEntry)
}
if modified {
ctx.UpdateExpr(e, ctx.NewMap(updatedEntries))
}
}
func pruneOptionalStructFields(ctx *OptimizerContext, e ast.Expr) {
s := e.AsStruct()
fields := s.Fields()
updatedFields := []ast.EntryExpr{}
modified := false
for _, f := range fields {
field := f.AsStructField()
val := field.Value()
if !field.IsOptional() || val.Kind() != ast.LiteralKind {
updatedFields = append(updatedFields, f)
continue
}
optElemVal, ok := val.AsLiteral().(*types.Optional)
if !ok {
updatedFields = append(updatedFields, f)
continue
}
modified = true
if !optElemVal.HasValue() {
continue
}
ctx.UpdateExpr(val, ctx.NewLiteral(optElemVal.GetValue()))
updatedField := ctx.NewStructField(field.Name(), val, false)
updatedFields = append(updatedFields, updatedField)
}
if modified {
ctx.UpdateExpr(e, ctx.NewStruct(s.TypeName(), updatedFields))
}
}
// adaptLiteral converts a runtime CEL value to its equivalent literal expression.
//
// For strongly typed values, the type-provider will be used to reconstruct the fields
// which are present in the literal and their equivalent initialization values.
func adaptLiteral(ctx *OptimizerContext, val ref.Val) (ast.Expr, error) {
switch t := val.Type().(type) {
case *types.Type:
switch t {
case types.BoolType, types.BytesType, types.DoubleType, types.IntType,
types.NullType, types.StringType, types.UintType:
return ctx.NewLiteral(val), nil
case types.DurationType:
return ctx.NewCall(
overloads.TypeConvertDuration,
ctx.NewLiteral(val.ConvertToType(types.StringType)),
), nil
case types.TimestampType:
return ctx.NewCall(
overloads.TypeConvertTimestamp,
ctx.NewLiteral(val.ConvertToType(types.StringType)),
), nil
case types.OptionalType:
opt := val.(*types.Optional)
if !opt.HasValue() {
return ctx.NewCall("optional.none"), nil
}
target, err := adaptLiteral(ctx, opt.GetValue())
if err != nil {
return nil, err
}
return ctx.NewCall("optional.of", target), nil
case types.TypeType:
return ctx.NewIdent(val.(*types.Type).TypeName()), nil
case types.ListType:
l, ok := val.(traits.Lister)
if !ok {
return nil, fmt.Errorf("failed to adapt %v to literal", val)
}
elems := make([]ast.Expr, l.Size().(types.Int))
idx := 0
it := l.Iterator()
for it.HasNext() == types.True {
elemVal := it.Next()
elemExpr, err := adaptLiteral(ctx, elemVal)
if err != nil {
return nil, err
}
elems[idx] = elemExpr
idx++
}
return ctx.NewList(elems, []int32{}), nil
case types.MapType:
m, ok := val.(traits.Mapper)
if !ok {
return nil, fmt.Errorf("failed to adapt %v to literal", val)
}
entries := make([]ast.EntryExpr, m.Size().(types.Int))
idx := 0
it := m.Iterator()
for it.HasNext() == types.True {
keyVal := it.Next()
keyExpr, err := adaptLiteral(ctx, keyVal)
if err != nil {
return nil, err
}
valVal := m.Get(keyVal)
valExpr, err := adaptLiteral(ctx, valVal)
if err != nil {
return nil, err
}
entries[idx] = ctx.NewMapEntry(keyExpr, valExpr, false)
idx++
}
return ctx.NewMap(entries), nil
default:
provider := ctx.CELTypeProvider()
fields, found := provider.FindStructFieldNames(t.TypeName())
if !found {
return nil, fmt.Errorf("failed to adapt %v to literal", val)
}
tester := val.(traits.FieldTester)
indexer := val.(traits.Indexer)
fieldInits := []ast.EntryExpr{}
for _, f := range fields {
field := types.String(f)
if tester.IsSet(field) != types.True {
continue
}
fieldVal := indexer.Get(field)
fieldExpr, err := adaptLiteral(ctx, fieldVal)
if err != nil {
return nil, err
}
fieldInits = append(fieldInits, ctx.NewStructField(f, fieldExpr, false))
}
return ctx.NewStruct(t.TypeName(), fieldInits), nil
}
}
return nil, fmt.Errorf("failed to adapt %v to literal", val)
}
// constantExprMatcher matches calls, select statements, and comprehensions whose arguments
// are all constant scalar or aggregate literal values.
//
// Only comprehensions which are not nested are included as possible constant folds, and only
// if all variables referenced in the comprehension stack exist are only iteration or
// accumulation variables.
func constantExprMatcher(e ast.NavigableExpr) bool {
switch e.Kind() {
case ast.CallKind:
return constantCallMatcher(e)
case ast.SelectKind:
sel := e.AsSelect() // guaranteed to be a navigable value
return constantMatcher(sel.Operand().(ast.NavigableExpr))
case ast.ComprehensionKind:
if isNestedComprehension(e) {
return false
}
vars := map[string]bool{}
constantExprs := true
visitor := ast.NewExprVisitor(func(e ast.Expr) {
if e.Kind() == ast.ComprehensionKind {
nested := e.AsComprehension()
vars[nested.AccuVar()] = true
vars[nested.IterVar()] = true
}
if e.Kind() == ast.IdentKind && !vars[e.AsIdent()] {
constantExprs = false
}
})
ast.PreOrderVisit(e, visitor)
return constantExprs
default:
return false
}
}
// constantCallMatcher identifies strict and non-strict calls which can be folded.
func constantCallMatcher(e ast.NavigableExpr) bool {
call := e.AsCall()
children := e.Children()
fnName := call.FunctionName()
if fnName == operators.LogicalAnd {
for _, child := range children {
if child.Kind() == ast.LiteralKind {
return true
}
}
}
if fnName == operators.LogicalOr {
for _, child := range children {
if child.Kind() == ast.LiteralKind {
return true
}
}
}
if fnName == operators.Conditional {
cond := children[0]
if cond.Kind() == ast.LiteralKind && cond.AsLiteral().Type() == types.BoolType {
return true
}
}
if fnName == operators.In {
haystack := children[1]
if haystack.Kind() == ast.ListKind && haystack.AsList().Size() == 0 {
return true
}
needle := children[0]
if needle.Kind() == ast.LiteralKind && haystack.Kind() == ast.ListKind {
needleValue := needle.AsLiteral()
list := haystack.AsList()
for _, e := range list.Elements() {
if e.Kind() == ast.LiteralKind && e.AsLiteral().Equal(needleValue) == types.True {
return true
}
}
}
}
// convert all other calls with constant arguments
for _, child := range children {
if !constantMatcher(child) {
return false
}
}
return true
}
func isNestedComprehension(e ast.NavigableExpr) bool {
parent, found := e.Parent()
for found {
if parent.Kind() == ast.ComprehensionKind {
return true
}
parent, found = parent.Parent()
}
return false
}
func aggregateLiteralMatcher(e ast.NavigableExpr) bool {
return e.Kind() == ast.ListKind || e.Kind() == ast.MapKind || e.Kind() == ast.StructKind
}
var (
constantMatcher = ast.ConstantValueMatcher()
)
const (
defaultMaxConstantFoldIterations = 100
)

228
vendor/github.com/google/cel-go/cel/inlining.go generated vendored Normal file
View File

@ -0,0 +1,228 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cel
import (
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/traits"
)
// InlineVariable holds a variable name to be matched and an AST representing
// the expression graph which should be used to replace it.
type InlineVariable struct {
name string
alias string
def *ast.AST
}
// Name returns the qualified variable or field selection to replace.
func (v *InlineVariable) Name() string {
return v.name
}
// Alias returns the alias to use when performing cel.bind() calls during inlining.
func (v *InlineVariable) Alias() string {
return v.alias
}
// Expr returns the inlined expression value.
func (v *InlineVariable) Expr() ast.Expr {
return v.def.Expr()
}
// Type indicates the inlined expression type.
func (v *InlineVariable) Type() *Type {
return v.def.GetType(v.def.Expr().ID())
}
// NewInlineVariable declares a variable name to be replaced by a checked expression.
func NewInlineVariable(name string, definition *Ast) *InlineVariable {
return NewInlineVariableWithAlias(name, name, definition)
}
// NewInlineVariableWithAlias declares a variable name to be replaced by a checked expression.
// If the variable occurs more than once, the provided alias will be used to replace the expressions
// where the variable name occurs.
func NewInlineVariableWithAlias(name, alias string, definition *Ast) *InlineVariable {
return &InlineVariable{name: name, alias: alias, def: definition.impl}
}
// NewInliningOptimizer creates and optimizer which replaces variables with expression definitions.
//
// If a variable occurs one time, the variable is replaced by the inline definition. If the
// variable occurs more than once, the variable occurences are replaced by a cel.bind() call.
func NewInliningOptimizer(inlineVars ...*InlineVariable) ASTOptimizer {
return &inliningOptimizer{variables: inlineVars}
}
type inliningOptimizer struct {
variables []*InlineVariable
}
func (opt *inliningOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.AST {
root := ast.NavigateAST(a)
for _, inlineVar := range opt.variables {
matches := ast.MatchDescendants(root, opt.matchVariable(inlineVar.Name()))
// Skip cases where the variable isn't in the expression graph
if len(matches) == 0 {
continue
}
// For a single match, do a direct replacement of the expression sub-graph.
if len(matches) == 1 || !isBindable(matches, inlineVar.Expr(), inlineVar.Type()) {
for _, match := range matches {
// Copy the inlined AST expr and source info.
copyExpr := ctx.CopyASTAndMetadata(inlineVar.def)
opt.inlineExpr(ctx, match, copyExpr, inlineVar.Type())
}
continue
}
// For multiple matches, find the least common ancestor (lca) and insert the
// variable as a cel.bind() macro.
var lca ast.NavigableExpr = root
lcaAncestorCount := 0
ancestors := map[int64]int{}
for _, match := range matches {
// Update the identifier matches with the provided alias.
parent, found := match, true
for found {
ancestorCount, hasAncestor := ancestors[parent.ID()]
if !hasAncestor {
ancestors[parent.ID()] = 1
parent, found = parent.Parent()
continue
}
if lcaAncestorCount < ancestorCount || (lcaAncestorCount == ancestorCount && lca.Depth() < parent.Depth()) {
lca = parent
lcaAncestorCount = ancestorCount
}
ancestors[parent.ID()] = ancestorCount + 1
parent, found = parent.Parent()
}
aliasExpr := ctx.NewIdent(inlineVar.Alias())
opt.inlineExpr(ctx, match, aliasExpr, inlineVar.Type())
}
// Copy the inlined AST expr and source info.
copyExpr := ctx.CopyASTAndMetadata(inlineVar.def)
// Update the least common ancestor by inserting a cel.bind() call to the alias.
inlined, bindMacro := ctx.NewBindMacro(lca.ID(), inlineVar.Alias(), copyExpr, lca)
opt.inlineExpr(ctx, lca, inlined, inlineVar.Type())
ctx.SetMacroCall(lca.ID(), bindMacro)
}
return a
}
// inlineExpr replaces the current expression with the inlined one, unless the location of the inlining
// happens within a presence test, e.g. has(a.b.c) -> inline alpha for a.b.c in which case an attempt is
// made to determine whether the inlined value can be presence or existence tested.
func (opt *inliningOptimizer) inlineExpr(ctx *OptimizerContext, prev ast.NavigableExpr, inlined ast.Expr, inlinedType *Type) {
switch prev.Kind() {
case ast.SelectKind:
sel := prev.AsSelect()
if !sel.IsTestOnly() {
ctx.UpdateExpr(prev, inlined)
return
}
opt.rewritePresenceExpr(ctx, prev, inlined, inlinedType)
default:
ctx.UpdateExpr(prev, inlined)
}
}
// rewritePresenceExpr converts the inlined expression, when it occurs within a has() macro, to type-safe
// expression appropriate for the inlined type, if possible.
//
// If the rewrite is not possible an error is reported at the inline expression site.
func (opt *inliningOptimizer) rewritePresenceExpr(ctx *OptimizerContext, prev, inlined ast.Expr, inlinedType *Type) {
// If the input inlined expression is not a select expression it won't work with the has()
// macro. Attempt to rewrite the presence test in terms of the typed input, otherwise error.
if inlined.Kind() == ast.SelectKind {
presenceTest, hasMacro := ctx.NewHasMacro(prev.ID(), inlined)
ctx.UpdateExpr(prev, presenceTest)
ctx.SetMacroCall(prev.ID(), hasMacro)
return
}
ctx.ClearMacroCall(prev.ID())
if inlinedType.IsAssignableType(NullType) {
ctx.UpdateExpr(prev,
ctx.NewCall(operators.NotEquals,
inlined,
ctx.NewLiteral(types.NullValue),
))
return
}
if inlinedType.HasTrait(traits.SizerType) {
ctx.UpdateExpr(prev,
ctx.NewCall(operators.NotEquals,
ctx.NewMemberCall(overloads.Size, inlined),
ctx.NewLiteral(types.IntZero),
))
return
}
ctx.ReportErrorAtID(prev.ID(), "unable to inline expression type %v into presence test", inlinedType)
}
// isBindable indicates whether the inlined type can be used within a cel.bind() if the expression
// being replaced occurs within a presence test. Value types with a size() method or field selection
// support can be bound.
//
// In future iterations, support may also be added for indexer types which can be rewritten as an `in`
// expression; however, this would imply a rewrite of the inlined expression that may not be necessary
// in most cases.
func isBindable(matches []ast.NavigableExpr, inlined ast.Expr, inlinedType *Type) bool {
if inlinedType.IsAssignableType(NullType) ||
inlinedType.HasTrait(traits.SizerType) {
return true
}
for _, m := range matches {
if m.Kind() != ast.SelectKind {
continue
}
sel := m.AsSelect()
if sel.IsTestOnly() {
return false
}
}
return true
}
// matchVariable matches simple identifiers, select expressions, and presence test expressions
// which match the (potentially) qualified variable name provided as input.
//
// Note, this function does not support inlining against select expressions which includes optional
// field selection. This may be a future refinement.
func (opt *inliningOptimizer) matchVariable(varName string) ast.ExprMatcher {
return func(e ast.NavigableExpr) bool {
if e.Kind() == ast.IdentKind && e.AsIdent() == varName {
return true
}
if e.Kind() == ast.SelectKind {
sel := e.AsSelect()
// While the `ToQualifiedName` call could take the select directly, this
// would skip presence tests from possible matches, which we would like
// to include.
qualName, found := containers.ToQualifiedName(sel.Operand())
return found && qualName+"."+sel.FieldName() == varName
}
return false
}
}

View File

@ -47,17 +47,11 @@ func CheckedExprToAst(checkedExpr *exprpb.CheckedExpr) *Ast {
//
// Prefer CheckedExprToAst if loading expressions from storage.
func CheckedExprToAstWithSource(checkedExpr *exprpb.CheckedExpr, src Source) (*Ast, error) {
checkedAST, err := ast.CheckedExprToCheckedAST(checkedExpr)
checked, err := ast.ToAST(checkedExpr)
if err != nil {
return nil, err
}
return &Ast{
expr: checkedAST.Expr,
info: checkedAST.SourceInfo,
source: src,
refMap: checkedAST.ReferenceMap,
typeMap: checkedAST.TypeMap,
}, nil
return &Ast{source: src, impl: checked}, nil
}
// AstToCheckedExpr converts an Ast to an protobuf CheckedExpr value.
@ -67,13 +61,7 @@ func AstToCheckedExpr(a *Ast) (*exprpb.CheckedExpr, error) {
if !a.IsChecked() {
return nil, fmt.Errorf("cannot convert unchecked ast")
}
cAst := &ast.CheckedAST{
Expr: a.expr,
SourceInfo: a.info,
ReferenceMap: a.refMap,
TypeMap: a.typeMap,
}
return ast.CheckedASTToCheckedExpr(cAst)
return ast.ToProto(a.impl)
}
// ParsedExprToAst converts a parsed expression proto message to an Ast.
@ -89,18 +77,12 @@ func ParsedExprToAst(parsedExpr *exprpb.ParsedExpr) *Ast {
//
// Prefer ParsedExprToAst if loading expressions from storage.
func ParsedExprToAstWithSource(parsedExpr *exprpb.ParsedExpr, src Source) *Ast {
si := parsedExpr.GetSourceInfo()
if si == nil {
si = &exprpb.SourceInfo{}
}
info, _ := ast.ProtoToSourceInfo(parsedExpr.GetSourceInfo())
if src == nil {
src = common.NewInfoSource(si)
}
return &Ast{
expr: parsedExpr.GetExpr(),
info: si,
source: src,
src = common.NewInfoSource(parsedExpr.GetSourceInfo())
}
e, _ := ast.ProtoToExpr(parsedExpr.GetExpr())
return &Ast{source: src, impl: ast.NewAST(e, info)}
}
// AstToParsedExpr converts an Ast to an protobuf ParsedExpr value.
@ -116,9 +98,7 @@ func AstToParsedExpr(a *Ast) (*exprpb.ParsedExpr, error) {
// Note, the conversion may not be an exact replica of the original expression, but will produce
// a string that is semantically equivalent and whose textual representation is stable.
func AstToString(a *Ast) (string, error) {
expr := a.Expr()
info := a.SourceInfo()
return parser.Unparse(expr, info)
return parser.Unparse(a.impl.Expr(), a.impl.SourceInfo())
}
// RefValueToValue converts between ref.Val and api.expr.Value.

View File

@ -20,6 +20,7 @@ import (
"strings"
"time"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/stdlib"
@ -28,8 +29,6 @@ import (
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/interpreter"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
const (
@ -313,7 +312,7 @@ func (lib *optionalLib) CompileOptions() []EnvOption {
Types(types.OptionalType),
// Configure the optMap and optFlatMap macros.
Macros(NewReceiverMacro(optMapMacro, 2, optMap)),
Macros(ReceiverMacro(optMapMacro, 2, optMap)),
// Global and member functions for working with optional values.
Function(optionalOfFunc,
@ -374,7 +373,7 @@ func (lib *optionalLib) CompileOptions() []EnvOption {
Overload("optional_map_index_value", []*Type{OptionalType(mapTypeKV), paramTypeK}, optionalTypeV)),
}
if lib.version >= 1 {
opts = append(opts, Macros(NewReceiverMacro(optFlatMapMacro, 2, optFlatMap)))
opts = append(opts, Macros(ReceiverMacro(optFlatMapMacro, 2, optFlatMap)))
}
return opts
}
@ -386,57 +385,57 @@ func (lib *optionalLib) ProgramOptions() []ProgramOption {
}
}
func optMap(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
func optMap(meh MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *Error) {
varIdent := args[0]
varName := ""
switch varIdent.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
varName = varIdent.GetIdentExpr().GetName()
switch varIdent.Kind() {
case ast.IdentKind:
varName = varIdent.AsIdent()
default:
return nil, meh.NewError(varIdent.GetId(), "optMap() variable name must be a simple identifier")
return nil, meh.NewError(varIdent.ID(), "optMap() variable name must be a simple identifier")
}
mapExpr := args[1]
return meh.GlobalCall(
return meh.NewCall(
operators.Conditional,
meh.ReceiverCall(hasValueFunc, target),
meh.GlobalCall(optionalOfFunc,
meh.Fold(
unusedIterVar,
meh.NewMemberCall(hasValueFunc, target),
meh.NewCall(optionalOfFunc,
meh.NewComprehension(
meh.NewList(),
unusedIterVar,
varName,
meh.ReceiverCall(valueFunc, target),
meh.LiteralBool(false),
meh.Ident(varName),
meh.NewMemberCall(valueFunc, target),
meh.NewLiteral(types.False),
meh.NewIdent(varName),
mapExpr,
),
),
meh.GlobalCall(optionalNoneFunc),
meh.NewCall(optionalNoneFunc),
), nil
}
func optFlatMap(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
func optFlatMap(meh MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *Error) {
varIdent := args[0]
varName := ""
switch varIdent.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
varName = varIdent.GetIdentExpr().GetName()
switch varIdent.Kind() {
case ast.IdentKind:
varName = varIdent.AsIdent()
default:
return nil, meh.NewError(varIdent.GetId(), "optFlatMap() variable name must be a simple identifier")
return nil, meh.NewError(varIdent.ID(), "optFlatMap() variable name must be a simple identifier")
}
mapExpr := args[1]
return meh.GlobalCall(
return meh.NewCall(
operators.Conditional,
meh.ReceiverCall(hasValueFunc, target),
meh.Fold(
unusedIterVar,
meh.NewMemberCall(hasValueFunc, target),
meh.NewComprehension(
meh.NewList(),
unusedIterVar,
varName,
meh.ReceiverCall(valueFunc, target),
meh.LiteralBool(false),
meh.Ident(varName),
meh.NewMemberCall(valueFunc, target),
meh.NewLiteral(types.False),
meh.NewIdent(varName),
mapExpr,
),
meh.GlobalCall(optionalNoneFunc),
meh.NewCall(optionalNoneFunc),
), nil
}

View File

@ -15,6 +15,11 @@
package cel
import (
"fmt"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
@ -26,7 +31,14 @@ import (
// a Macro should be created per arg-count or as a var arg macro.
type Macro = parser.Macro
// MacroExpander converts a call and its associated arguments into a new CEL abstract syntax tree.
// MacroFactory defines an expansion function which converts a call and its arguments to a cel.Expr value.
type MacroFactory = parser.MacroExpander
// MacroExprFactory assists with the creation of Expr values in a manner which is consistent
// the internal semantics and id generation behaviors of the parser and checker libraries.
type MacroExprFactory = parser.ExprHelper
// MacroExpander converts a call and its associated arguments into a protobuf Expr representation.
//
// If the MacroExpander determines within the implementation that an expansion is not needed it may return
// a nil Expr value to indicate a non-match. However, if an expansion is to be performed, but the arguments
@ -36,48 +48,197 @@ type Macro = parser.Macro
// and produces as output an Expr ast node.
//
// Note: when the Macro.IsReceiverStyle() method returns true, the target argument will be nil.
type MacroExpander = parser.MacroExpander
type MacroExpander func(eh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error)
// MacroExprHelper exposes helper methods for creating new expressions within a CEL abstract syntax tree.
type MacroExprHelper = parser.ExprHelper
// ExprHelper assists with the manipulation of proto-based Expr values in a manner which is
// consistent with the source position and expression id generation code leveraged by both
// the parser and type-checker.
type MacroExprHelper interface {
// Copy the input expression with a brand new set of identifiers.
Copy(*exprpb.Expr) *exprpb.Expr
// LiteralBool creates an Expr value for a bool literal.
LiteralBool(value bool) *exprpb.Expr
// LiteralBytes creates an Expr value for a byte literal.
LiteralBytes(value []byte) *exprpb.Expr
// LiteralDouble creates an Expr value for double literal.
LiteralDouble(value float64) *exprpb.Expr
// LiteralInt creates an Expr value for an int literal.
LiteralInt(value int64) *exprpb.Expr
// LiteralString creates am Expr value for a string literal.
LiteralString(value string) *exprpb.Expr
// LiteralUint creates an Expr value for a uint literal.
LiteralUint(value uint64) *exprpb.Expr
// NewList creates a CreateList instruction where the list is comprised of the optional set
// of elements provided as arguments.
NewList(elems ...*exprpb.Expr) *exprpb.Expr
// NewMap creates a CreateStruct instruction for a map where the map is comprised of the
// optional set of key, value entries.
NewMap(entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr
// NewMapEntry creates a Map Entry for the key, value pair.
NewMapEntry(key *exprpb.Expr, val *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry
// NewObject creates a CreateStruct instruction for an object with a given type name and
// optional set of field initializers.
NewObject(typeName string, fieldInits ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr
// NewObjectFieldInit creates a new Object field initializer from the field name and value.
NewObjectFieldInit(field string, init *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry
// Fold creates a fold comprehension instruction.
//
// - iterVar is the iteration variable name.
// - iterRange represents the expression that resolves to a list or map where the elements or
// keys (respectively) will be iterated over.
// - accuVar is the accumulation variable name, typically parser.AccumulatorName.
// - accuInit is the initial expression whose value will be set for the accuVar prior to
// folding.
// - condition is the expression to test to determine whether to continue folding.
// - step is the expression to evaluation at the conclusion of a single fold iteration.
// - result is the computation to evaluate at the conclusion of the fold.
//
// The accuVar should not shadow variable names that you would like to reference within the
// environment in the step and condition expressions. Presently, the name __result__ is commonly
// used by built-in macros but this may change in the future.
Fold(iterVar string,
iterRange *exprpb.Expr,
accuVar string,
accuInit *exprpb.Expr,
condition *exprpb.Expr,
step *exprpb.Expr,
result *exprpb.Expr) *exprpb.Expr
// Ident creates an identifier Expr value.
Ident(name string) *exprpb.Expr
// AccuIdent returns an accumulator identifier for use with comprehension results.
AccuIdent() *exprpb.Expr
// GlobalCall creates a function call Expr value for a global (free) function.
GlobalCall(function string, args ...*exprpb.Expr) *exprpb.Expr
// ReceiverCall creates a function call Expr value for a receiver-style function.
ReceiverCall(function string, target *exprpb.Expr, args ...*exprpb.Expr) *exprpb.Expr
// PresenceTest creates a Select TestOnly Expr value for modelling has() semantics.
PresenceTest(operand *exprpb.Expr, field string) *exprpb.Expr
// Select create a field traversal Expr value.
Select(operand *exprpb.Expr, field string) *exprpb.Expr
// OffsetLocation returns the Location of the expression identifier.
OffsetLocation(exprID int64) common.Location
// NewError associates an error message with a given expression id.
NewError(exprID int64, message string) *Error
}
// GlobalMacro creates a Macro for a global function with the specified arg count.
func GlobalMacro(function string, argCount int, factory MacroFactory) Macro {
return parser.NewGlobalMacro(function, argCount, factory)
}
// ReceiverMacro creates a Macro for a receiver function matching the specified arg count.
func ReceiverMacro(function string, argCount int, factory MacroFactory) Macro {
return parser.NewReceiverMacro(function, argCount, factory)
}
// GlobalVarArgMacro creates a Macro for a global function with a variable arg count.
func GlobalVarArgMacro(function string, factory MacroFactory) Macro {
return parser.NewGlobalVarArgMacro(function, factory)
}
// ReceiverVarArgMacro creates a Macro for a receiver function matching a variable arg count.
func ReceiverVarArgMacro(function string, factory MacroFactory) Macro {
return parser.NewReceiverVarArgMacro(function, factory)
}
// NewGlobalMacro creates a Macro for a global function with the specified arg count.
//
// Deprecated: use GlobalMacro
func NewGlobalMacro(function string, argCount int, expander MacroExpander) Macro {
return parser.NewGlobalMacro(function, argCount, expander)
expand := adaptingExpander{expander}
return parser.NewGlobalMacro(function, argCount, expand.Expander)
}
// NewReceiverMacro creates a Macro for a receiver function matching the specified arg count.
//
// Deprecated: use ReceiverMacro
func NewReceiverMacro(function string, argCount int, expander MacroExpander) Macro {
return parser.NewReceiverMacro(function, argCount, expander)
expand := adaptingExpander{expander}
return parser.NewReceiverMacro(function, argCount, expand.Expander)
}
// NewGlobalVarArgMacro creates a Macro for a global function with a variable arg count.
//
// Deprecated: use GlobalVarArgMacro
func NewGlobalVarArgMacro(function string, expander MacroExpander) Macro {
return parser.NewGlobalVarArgMacro(function, expander)
expand := adaptingExpander{expander}
return parser.NewGlobalVarArgMacro(function, expand.Expander)
}
// NewReceiverVarArgMacro creates a Macro for a receiver function matching a variable arg count.
//
// Deprecated: use ReceiverVarArgMacro
func NewReceiverVarArgMacro(function string, expander MacroExpander) Macro {
return parser.NewReceiverVarArgMacro(function, expander)
expand := adaptingExpander{expander}
return parser.NewReceiverVarArgMacro(function, expand.Expander)
}
// HasMacroExpander expands the input call arguments into a presence test, e.g. has(<operand>.field)
func HasMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return parser.MakeHas(meh, target, args)
ph, err := toParserHelper(meh)
if err != nil {
return nil, err
}
arg, err := adaptToExpr(args[0])
if err != nil {
return nil, err
}
if arg.Kind() == ast.SelectKind {
s := arg.AsSelect()
return adaptToProto(ph.NewPresenceTest(s.Operand(), s.FieldName()))
}
return nil, ph.NewError(arg.ID(), "invalid argument to has() macro")
}
// ExistsMacroExpander expands the input call arguments into a comprehension that returns true if any of the
// elements in the range match the predicate expressions:
// <iterRange>.exists(<iterVar>, <predicate>)
func ExistsMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return parser.MakeExists(meh, target, args)
ph, err := toParserHelper(meh)
if err != nil {
return nil, err
}
out, err := parser.MakeExists(ph, mustAdaptToExpr(target), mustAdaptToExprs(args))
if err != nil {
return nil, err
}
return adaptToProto(out)
}
// ExistsOneMacroExpander expands the input call arguments into a comprehension that returns true if exactly
// one of the elements in the range match the predicate expressions:
// <iterRange>.exists_one(<iterVar>, <predicate>)
func ExistsOneMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return parser.MakeExistsOne(meh, target, args)
ph, err := toParserHelper(meh)
if err != nil {
return nil, err
}
out, err := parser.MakeExistsOne(ph, mustAdaptToExpr(target), mustAdaptToExprs(args))
if err != nil {
return nil, err
}
return adaptToProto(out)
}
// MapMacroExpander expands the input call arguments into a comprehension that transforms each element in the
@ -91,14 +252,30 @@ func ExistsOneMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*ex
// In the second form only iterVar values which return true when provided to the predicate expression
// are transformed.
func MapMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return parser.MakeMap(meh, target, args)
ph, err := toParserHelper(meh)
if err != nil {
return nil, err
}
out, err := parser.MakeMap(ph, mustAdaptToExpr(target), mustAdaptToExprs(args))
if err != nil {
return nil, err
}
return adaptToProto(out)
}
// FilterMacroExpander expands the input call arguments into a comprehension which produces a list which contains
// only elements which match the provided predicate expression:
// <iterRange>.filter(<iterVar>, <predicate>)
func FilterMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return parser.MakeFilter(meh, target, args)
ph, err := toParserHelper(meh)
if err != nil {
return nil, err
}
out, err := parser.MakeFilter(ph, mustAdaptToExpr(target), mustAdaptToExprs(args))
if err != nil {
return nil, err
}
return adaptToProto(out)
}
var (
@ -142,3 +319,258 @@ var (
// NoMacros provides an alias to an empty list of macros
NoMacros = []Macro{}
)
type adaptingExpander struct {
legacyExpander MacroExpander
}
func (adapt *adaptingExpander) Expander(eh parser.ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) {
var legacyTarget *exprpb.Expr = nil
var err *Error = nil
if target != nil {
legacyTarget, err = adaptToProto(target)
if err != nil {
return nil, err
}
}
legacyArgs := make([]*exprpb.Expr, len(args))
for i, arg := range args {
legacyArgs[i], err = adaptToProto(arg)
if err != nil {
return nil, err
}
}
ah := &adaptingHelper{modernHelper: eh}
legacyExpr, err := adapt.legacyExpander(ah, legacyTarget, legacyArgs)
if err != nil {
return nil, err
}
ex, err := adaptToExpr(legacyExpr)
if err != nil {
return nil, err
}
return ex, nil
}
func wrapErr(id int64, message string, err error) *common.Error {
return &common.Error{
Location: common.NoLocation,
Message: fmt.Sprintf("%s: %v", message, err),
ExprID: id,
}
}
type adaptingHelper struct {
modernHelper parser.ExprHelper
}
// Copy the input expression with a brand new set of identifiers.
func (ah *adaptingHelper) Copy(e *exprpb.Expr) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.Copy(mustAdaptToExpr(e)))
}
// LiteralBool creates an Expr value for a bool literal.
func (ah *adaptingHelper) LiteralBool(value bool) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Bool(value)))
}
// LiteralBytes creates an Expr value for a byte literal.
func (ah *adaptingHelper) LiteralBytes(value []byte) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Bytes(value)))
}
// LiteralDouble creates an Expr value for double literal.
func (ah *adaptingHelper) LiteralDouble(value float64) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Double(value)))
}
// LiteralInt creates an Expr value for an int literal.
func (ah *adaptingHelper) LiteralInt(value int64) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Int(value)))
}
// LiteralString creates am Expr value for a string literal.
func (ah *adaptingHelper) LiteralString(value string) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewLiteral(types.String(value)))
}
// LiteralUint creates an Expr value for a uint literal.
func (ah *adaptingHelper) LiteralUint(value uint64) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Uint(value)))
}
// NewList creates a CreateList instruction where the list is comprised of the optional set
// of elements provided as arguments.
func (ah *adaptingHelper) NewList(elems ...*exprpb.Expr) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewList(mustAdaptToExprs(elems)...))
}
// NewMap creates a CreateStruct instruction for a map where the map is comprised of the
// optional set of key, value entries.
func (ah *adaptingHelper) NewMap(entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr {
adaptedEntries := make([]ast.EntryExpr, len(entries))
for i, e := range entries {
adaptedEntries[i] = mustAdaptToEntryExpr(e)
}
return mustAdaptToProto(ah.modernHelper.NewMap(adaptedEntries...))
}
// NewMapEntry creates a Map Entry for the key, value pair.
func (ah *adaptingHelper) NewMapEntry(key *exprpb.Expr, val *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry {
return mustAdaptToProtoEntry(
ah.modernHelper.NewMapEntry(mustAdaptToExpr(key), mustAdaptToExpr(val), optional))
}
// NewObject creates a CreateStruct instruction for an object with a given type name and
// optional set of field initializers.
func (ah *adaptingHelper) NewObject(typeName string, fieldInits ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr {
adaptedEntries := make([]ast.EntryExpr, len(fieldInits))
for i, e := range fieldInits {
adaptedEntries[i] = mustAdaptToEntryExpr(e)
}
return mustAdaptToProto(ah.modernHelper.NewStruct(typeName, adaptedEntries...))
}
// NewObjectFieldInit creates a new Object field initializer from the field name and value.
func (ah *adaptingHelper) NewObjectFieldInit(field string, init *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry {
return mustAdaptToProtoEntry(
ah.modernHelper.NewStructField(field, mustAdaptToExpr(init), optional))
}
// Fold creates a fold comprehension instruction.
//
// - iterVar is the iteration variable name.
// - iterRange represents the expression that resolves to a list or map where the elements or
// keys (respectively) will be iterated over.
// - accuVar is the accumulation variable name, typically parser.AccumulatorName.
// - accuInit is the initial expression whose value will be set for the accuVar prior to
// folding.
// - condition is the expression to test to determine whether to continue folding.
// - step is the expression to evaluation at the conclusion of a single fold iteration.
// - result is the computation to evaluate at the conclusion of the fold.
//
// The accuVar should not shadow variable names that you would like to reference within the
// environment in the step and condition expressions. Presently, the name __result__ is commonly
// used by built-in macros but this may change in the future.
func (ah *adaptingHelper) Fold(iterVar string,
iterRange *exprpb.Expr,
accuVar string,
accuInit *exprpb.Expr,
condition *exprpb.Expr,
step *exprpb.Expr,
result *exprpb.Expr) *exprpb.Expr {
return mustAdaptToProto(
ah.modernHelper.NewComprehension(
mustAdaptToExpr(iterRange),
iterVar,
accuVar,
mustAdaptToExpr(accuInit),
mustAdaptToExpr(condition),
mustAdaptToExpr(step),
mustAdaptToExpr(result),
),
)
}
// Ident creates an identifier Expr value.
func (ah *adaptingHelper) Ident(name string) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewIdent(name))
}
// AccuIdent returns an accumulator identifier for use with comprehension results.
func (ah *adaptingHelper) AccuIdent() *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewAccuIdent())
}
// GlobalCall creates a function call Expr value for a global (free) function.
func (ah *adaptingHelper) GlobalCall(function string, args ...*exprpb.Expr) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewCall(function, mustAdaptToExprs(args)...))
}
// ReceiverCall creates a function call Expr value for a receiver-style function.
func (ah *adaptingHelper) ReceiverCall(function string, target *exprpb.Expr, args ...*exprpb.Expr) *exprpb.Expr {
return mustAdaptToProto(
ah.modernHelper.NewMemberCall(function, mustAdaptToExpr(target), mustAdaptToExprs(args)...))
}
// PresenceTest creates a Select TestOnly Expr value for modelling has() semantics.
func (ah *adaptingHelper) PresenceTest(operand *exprpb.Expr, field string) *exprpb.Expr {
op := mustAdaptToExpr(operand)
return mustAdaptToProto(ah.modernHelper.NewPresenceTest(op, field))
}
// Select create a field traversal Expr value.
func (ah *adaptingHelper) Select(operand *exprpb.Expr, field string) *exprpb.Expr {
op := mustAdaptToExpr(operand)
return mustAdaptToProto(ah.modernHelper.NewSelect(op, field))
}
// OffsetLocation returns the Location of the expression identifier.
func (ah *adaptingHelper) OffsetLocation(exprID int64) common.Location {
return ah.modernHelper.OffsetLocation(exprID)
}
// NewError associates an error message with a given expression id.
func (ah *adaptingHelper) NewError(exprID int64, message string) *Error {
return ah.modernHelper.NewError(exprID, message)
}
func mustAdaptToExprs(exprs []*exprpb.Expr) []ast.Expr {
adapted := make([]ast.Expr, len(exprs))
for i, e := range exprs {
adapted[i] = mustAdaptToExpr(e)
}
return adapted
}
func mustAdaptToExpr(e *exprpb.Expr) ast.Expr {
out, _ := adaptToExpr(e)
return out
}
func adaptToExpr(e *exprpb.Expr) (ast.Expr, *Error) {
if e == nil {
return nil, nil
}
out, err := ast.ProtoToExpr(e)
if err != nil {
return nil, wrapErr(e.GetId(), "proto conversion failure", err)
}
return out, nil
}
func mustAdaptToEntryExpr(e *exprpb.Expr_CreateStruct_Entry) ast.EntryExpr {
out, _ := ast.ProtoToEntryExpr(e)
return out
}
func mustAdaptToProto(e ast.Expr) *exprpb.Expr {
out, _ := adaptToProto(e)
return out
}
func adaptToProto(e ast.Expr) (*exprpb.Expr, *Error) {
if e == nil {
return nil, nil
}
out, err := ast.ExprToProto(e)
if err != nil {
return nil, wrapErr(e.ID(), "expr conversion failure", err)
}
return out, nil
}
func mustAdaptToProtoEntry(e ast.EntryExpr) *exprpb.Expr_CreateStruct_Entry {
out, _ := ast.EntryExprToProto(e)
return out
}
func toParserHelper(meh MacroExprHelper) (parser.ExprHelper, *Error) {
ah, ok := meh.(*adaptingHelper)
if !ok {
return nil, common.NewError(0,
fmt.Sprintf("unsupported macro helper: %v (%T)", meh, meh),
common.NoLocation)
}
return ah.modernHelper, nil
}

509
vendor/github.com/google/cel-go/cel/optimizer.go generated vendored Normal file
View File

@ -0,0 +1,509 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cel
import (
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)
// StaticOptimizer contains a sequence of ASTOptimizer instances which will be applied in order.
//
// The static optimizer normalizes expression ids and type-checking run between optimization
// passes to ensure that the final optimized output is a valid expression with metadata consistent
// with what would have been generated from a parsed and checked expression.
//
// Note: source position information is best-effort and likely wrong, but optimized expressions
// should be suitable for calls to parser.Unparse.
type StaticOptimizer struct {
optimizers []ASTOptimizer
}
// NewStaticOptimizer creates a StaticOptimizer with a sequence of ASTOptimizer's to be applied
// to a checked expression.
func NewStaticOptimizer(optimizers ...ASTOptimizer) *StaticOptimizer {
return &StaticOptimizer{
optimizers: optimizers,
}
}
// Optimize applies a sequence of optimizations to an Ast within a given environment.
//
// If issues are encountered, the Issues.Err() return value will be non-nil.
func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) {
// Make a copy of the AST to be optimized.
optimized := ast.Copy(a.impl)
ids := newIDGenerator(ast.MaxID(a.impl))
// Create the optimizer context, could be pooled in the future.
issues := NewIssues(common.NewErrors(a.Source()))
baseFac := ast.NewExprFactory()
exprFac := &optimizerExprFactory{
idGenerator: ids,
fac: baseFac,
sourceInfo: optimized.SourceInfo(),
}
ctx := &OptimizerContext{
optimizerExprFactory: exprFac,
Env: env,
Issues: issues,
}
// Apply the optimizations sequentially.
for _, o := range opt.optimizers {
optimized = o.Optimize(ctx, optimized)
if issues.Err() != nil {
return nil, issues
}
// Normalize expression id metadata including coordination with macro call metadata.
freshIDGen := newIDGenerator(0)
info := optimized.SourceInfo()
expr := optimized.Expr()
normalizeIDs(freshIDGen.renumberStable, expr, info)
cleanupMacroRefs(expr, info)
// Recheck the updated expression for any possible type-agreement or validation errors.
parsed := &Ast{
source: a.Source(),
impl: ast.NewAST(expr, info)}
checked, iss := ctx.Check(parsed)
if iss.Err() != nil {
return nil, iss
}
optimized = checked.impl
}
// Return the optimized result.
return &Ast{
source: a.Source(),
impl: optimized,
}, nil
}
// normalizeIDs ensures that the metadata present with an AST is reset in a manner such
// that the ids within the expression correspond to the ids within macros.
func normalizeIDs(idGen ast.IDGenerator, optimized ast.Expr, info *ast.SourceInfo) {
optimized.RenumberIDs(idGen)
if len(info.MacroCalls()) == 0 {
return
}
// First, update the macro call ids themselves.
callIDMap := map[int64]int64{}
for id := range info.MacroCalls() {
callIDMap[id] = idGen(id)
}
// Then update the macro call definitions which refer to these ids, but
// ensure that the updates don't collide and remove macro entries which haven't
// been visited / updated yet.
type macroUpdate struct {
id int64
call ast.Expr
}
macroUpdates := []macroUpdate{}
for oldID, newID := range callIDMap {
call, found := info.GetMacroCall(oldID)
if !found {
continue
}
call.RenumberIDs(idGen)
macroUpdates = append(macroUpdates, macroUpdate{id: newID, call: call})
info.ClearMacroCall(oldID)
}
for _, u := range macroUpdates {
info.SetMacroCall(u.id, u.call)
}
}
func cleanupMacroRefs(expr ast.Expr, info *ast.SourceInfo) {
if len(info.MacroCalls()) == 0 {
return
}
// Sanitize the macro call references once the optimized expression has been computed
// and the ids normalized between the expression and the macros.
exprRefMap := make(map[int64]struct{})
ast.PostOrderVisit(expr, ast.NewExprVisitor(func(e ast.Expr) {
if e.ID() == 0 {
return
}
exprRefMap[e.ID()] = struct{}{}
}))
// Update the macro call id references to ensure that macro pointers are
// updated consistently across macros.
for _, call := range info.MacroCalls() {
ast.PostOrderVisit(call, ast.NewExprVisitor(func(e ast.Expr) {
if e.ID() == 0 {
return
}
exprRefMap[e.ID()] = struct{}{}
}))
}
for id := range info.MacroCalls() {
if _, found := exprRefMap[id]; !found {
info.ClearMacroCall(id)
}
}
}
// newIDGenerator ensures that new ids are only created the first time they are encountered.
func newIDGenerator(seed int64) *idGenerator {
return &idGenerator{
idMap: make(map[int64]int64),
seed: seed,
}
}
type idGenerator struct {
idMap map[int64]int64
seed int64
}
func (gen *idGenerator) nextID() int64 {
gen.seed++
return gen.seed
}
func (gen *idGenerator) renumberStable(id int64) int64 {
if id == 0 {
return 0
}
if newID, found := gen.idMap[id]; found {
return newID
}
nextID := gen.nextID()
gen.idMap[id] = nextID
return nextID
}
// OptimizerContext embeds Env and Issues instances to make it easy to type-check and evaluate
// subexpressions and report any errors encountered along the way. The context also embeds the
// optimizerExprFactory which can be used to generate new sub-expressions with expression ids
// consistent with the expectations of a parsed expression.
type OptimizerContext struct {
*Env
*optimizerExprFactory
*Issues
}
// ASTOptimizer applies an optimization over an AST and returns the optimized result.
type ASTOptimizer interface {
// Optimize optimizes a type-checked AST within an Environment and accumulates any issues.
Optimize(*OptimizerContext, *ast.AST) *ast.AST
}
type optimizerExprFactory struct {
*idGenerator
fac ast.ExprFactory
sourceInfo *ast.SourceInfo
}
// NewAST creates an AST from the current expression using the tracked source info which
// is modified and managed by the OptimizerContext.
func (opt *optimizerExprFactory) NewAST(expr ast.Expr) *ast.AST {
return ast.NewAST(expr, opt.sourceInfo)
}
// CopyAST creates a renumbered copy of `Expr` and `SourceInfo` values of the input AST, where the
// renumbering uses the same scheme as the core optimizer logic ensuring there are no collisions
// between copies.
//
// Use this method before attempting to merge the expression from AST into another.
func (opt *optimizerExprFactory) CopyAST(a *ast.AST) (ast.Expr, *ast.SourceInfo) {
idGen := newIDGenerator(opt.nextID())
defer func() { opt.seed = idGen.nextID() }()
copyExpr := opt.fac.CopyExpr(a.Expr())
copyInfo := ast.CopySourceInfo(a.SourceInfo())
normalizeIDs(idGen.renumberStable, copyExpr, copyInfo)
return copyExpr, copyInfo
}
// CopyASTAndMetadata copies the input AST and propagates the macro metadata into the AST being
// optimized.
func (opt *optimizerExprFactory) CopyASTAndMetadata(a *ast.AST) ast.Expr {
copyExpr, copyInfo := opt.CopyAST(a)
for macroID, call := range copyInfo.MacroCalls() {
opt.SetMacroCall(macroID, call)
}
return copyExpr
}
// ClearMacroCall clears the macro at the given expression id.
func (opt *optimizerExprFactory) ClearMacroCall(id int64) {
opt.sourceInfo.ClearMacroCall(id)
}
// SetMacroCall sets the macro call metadata for the given macro id within the tracked source info
// metadata.
func (opt *optimizerExprFactory) SetMacroCall(id int64, expr ast.Expr) {
opt.sourceInfo.SetMacroCall(id, expr)
}
// NewBindMacro creates an AST expression representing the expanded bind() macro, and a macro expression
// representing the unexpanded call signature to be inserted into the source info macro call metadata.
func (opt *optimizerExprFactory) NewBindMacro(macroID int64, varName string, varInit, remaining ast.Expr) (astExpr, macroExpr ast.Expr) {
varID := opt.nextID()
remainingID := opt.nextID()
remaining = opt.fac.CopyExpr(remaining)
remaining.RenumberIDs(func(id int64) int64 {
if id == macroID {
return remainingID
}
return id
})
if call, exists := opt.sourceInfo.GetMacroCall(macroID); exists {
opt.SetMacroCall(remainingID, opt.fac.CopyExpr(call))
}
astExpr = opt.fac.NewComprehension(macroID,
opt.fac.NewList(opt.nextID(), []ast.Expr{}, []int32{}),
"#unused",
varName,
opt.fac.CopyExpr(varInit),
opt.fac.NewLiteral(opt.nextID(), types.False),
opt.fac.NewIdent(varID, varName),
remaining)
macroExpr = opt.fac.NewMemberCall(0, "bind",
opt.fac.NewIdent(opt.nextID(), "cel"),
opt.fac.NewIdent(varID, varName),
opt.fac.CopyExpr(varInit),
opt.fac.CopyExpr(remaining))
opt.sanitizeMacro(macroID, macroExpr)
return
}
// NewCall creates a global function call invocation expression.
//
// Example:
//
// countByField(list, fieldName)
// - function: countByField
// - args: [list, fieldName]
func (opt *optimizerExprFactory) NewCall(function string, args ...ast.Expr) ast.Expr {
return opt.fac.NewCall(opt.nextID(), function, args...)
}
// NewMemberCall creates a member function call invocation expression where 'target' is the receiver of the call.
//
// Example:
//
// list.countByField(fieldName)
// - function: countByField
// - target: list
// - args: [fieldName]
func (opt *optimizerExprFactory) NewMemberCall(function string, target ast.Expr, args ...ast.Expr) ast.Expr {
return opt.fac.NewMemberCall(opt.nextID(), function, target, args...)
}
// NewIdent creates a new identifier expression.
//
// Examples:
//
// - simple_var_name
// - qualified.subpackage.var_name
func (opt *optimizerExprFactory) NewIdent(name string) ast.Expr {
return opt.fac.NewIdent(opt.nextID(), name)
}
// NewLiteral creates a new literal expression value.
//
// The range of valid values for a literal generated during optimization is different than for expressions
// generated via parsing / type-checking, as the ref.Val may be _any_ CEL value so long as the value can
// be converted back to a literal-like form.
func (opt *optimizerExprFactory) NewLiteral(value ref.Val) ast.Expr {
return opt.fac.NewLiteral(opt.nextID(), value)
}
// NewList creates a list expression with a set of optional indices.
//
// Examples:
//
// [a, b]
// - elems: [a, b]
// - optIndices: []
//
// [a, ?b, ?c]
// - elems: [a, b, c]
// - optIndices: [1, 2]
func (opt *optimizerExprFactory) NewList(elems []ast.Expr, optIndices []int32) ast.Expr {
return opt.fac.NewList(opt.nextID(), elems, optIndices)
}
// NewMap creates a map from a set of entry expressions which contain a key and value expression.
func (opt *optimizerExprFactory) NewMap(entries []ast.EntryExpr) ast.Expr {
return opt.fac.NewMap(opt.nextID(), entries)
}
// NewMapEntry creates a map entry with a key and value expression and a flag to indicate whether the
// entry is optional.
//
// Examples:
//
// {a: b}
// - key: a
// - value: b
// - optional: false
//
// {?a: ?b}
// - key: a
// - value: b
// - optional: true
func (opt *optimizerExprFactory) NewMapEntry(key, value ast.Expr, isOptional bool) ast.EntryExpr {
return opt.fac.NewMapEntry(opt.nextID(), key, value, isOptional)
}
// NewHasMacro generates a test-only select expression to be included within an AST and an unexpanded
// has() macro call signature to be inserted into the source info macro call metadata.
func (opt *optimizerExprFactory) NewHasMacro(macroID int64, s ast.Expr) (astExpr, macroExpr ast.Expr) {
sel := s.AsSelect()
astExpr = opt.fac.NewPresenceTest(macroID, sel.Operand(), sel.FieldName())
macroExpr = opt.fac.NewCall(0, "has",
opt.NewSelect(opt.fac.CopyExpr(sel.Operand()), sel.FieldName()))
opt.sanitizeMacro(macroID, macroExpr)
return
}
// NewSelect creates a select expression where a field value is selected from an operand.
//
// Example:
//
// msg.field_name
// - operand: msg
// - field: field_name
func (opt *optimizerExprFactory) NewSelect(operand ast.Expr, field string) ast.Expr {
return opt.fac.NewSelect(opt.nextID(), operand, field)
}
// NewStruct creates a new typed struct value with an set of field initializations.
//
// Example:
//
// pkg.TypeName{field: value}
// - typeName: pkg.TypeName
// - fields: [{field: value}]
func (opt *optimizerExprFactory) NewStruct(typeName string, fields []ast.EntryExpr) ast.Expr {
return opt.fac.NewStruct(opt.nextID(), typeName, fields)
}
// NewStructField creates a struct field initialization.
//
// Examples:
//
// {count: 3u}
// - field: count
// - value: 3u
// - optional: false
//
// {?count: x}
// - field: count
// - value: x
// - optional: true
func (opt *optimizerExprFactory) NewStructField(field string, value ast.Expr, isOptional bool) ast.EntryExpr {
return opt.fac.NewStructField(opt.nextID(), field, value, isOptional)
}
// UpdateExpr updates the target expression with the updated content while preserving macro metadata.
//
// There are four scenarios during the update to consider:
// 1. target is not macro, updated is not macro
// 2. target is macro, updated is not macro
// 3. target is macro, updated is macro
// 4. target is not macro, updated is macro
//
// When the target is a macro already, it may either be updated to a new macro function
// body if the update is also a macro, or it may be removed altogether if the update is
// a macro.
//
// When the update is a macro, then the target references within other macros must be
// updated to point to the new updated macro. Otherwise, other macros which pointed to
// the target body must be replaced with copies of the updated expression body.
func (opt *optimizerExprFactory) UpdateExpr(target, updated ast.Expr) {
// Update the expression
target.SetKindCase(updated)
// Early return if there's no macros present sa the source info reflects the
// macro set from the target and updated expressions.
if len(opt.sourceInfo.MacroCalls()) == 0 {
return
}
// Determine whether the target expression was a macro.
_, targetIsMacro := opt.sourceInfo.GetMacroCall(target.ID())
// Determine whether the updated expression was a macro.
updatedMacro, updatedIsMacro := opt.sourceInfo.GetMacroCall(updated.ID())
if updatedIsMacro {
// If the updated call was a macro, then updated id maps to target id,
// and the updated macro moves into the target id slot.
opt.sourceInfo.ClearMacroCall(updated.ID())
opt.sourceInfo.SetMacroCall(target.ID(), updatedMacro)
} else if targetIsMacro {
// Otherwise if the target expr was a macro, but is no longer, clear
// the macro reference.
opt.sourceInfo.ClearMacroCall(target.ID())
}
// Punch holes in the updated value where macros references exist.
macroExpr := opt.fac.CopyExpr(target)
macroRefVisitor := ast.NewExprVisitor(func(e ast.Expr) {
if _, exists := opt.sourceInfo.GetMacroCall(e.ID()); exists {
e.SetKindCase(nil)
}
})
ast.PostOrderVisit(macroExpr, macroRefVisitor)
// Update any references to the expression within a macro
macroVisitor := ast.NewExprVisitor(func(call ast.Expr) {
// Update the target expression to point to the macro expression which
// will be empty if the updated expression was a macro.
if call.ID() == target.ID() {
call.SetKindCase(opt.fac.CopyExpr(macroExpr))
}
// Update the macro call expression if it refers to the updated expression
// id which has since been remapped to the target id.
if call.ID() == updated.ID() {
// Either ensure the expression is a macro reference or a populated with
// the relevant sub-expression if the updated expr was not a macro.
if updatedIsMacro {
call.SetKindCase(nil)
} else {
call.SetKindCase(opt.fac.CopyExpr(macroExpr))
}
// Since SetKindCase does not renumber the id, ensure the references to
// the old 'updated' id are mapped to the target id.
call.RenumberIDs(func(id int64) int64 {
if id == updated.ID() {
return target.ID()
}
return id
})
}
})
for _, call := range opt.sourceInfo.MacroCalls() {
ast.PostOrderVisit(call, macroVisitor)
}
}
func (opt *optimizerExprFactory) sanitizeMacro(macroID int64, macroExpr ast.Expr) {
macroRefVisitor := ast.NewExprVisitor(func(e ast.Expr) {
if _, exists := opt.sourceInfo.GetMacroCall(e.ID()); exists && e.ID() != macroID {
e.SetKindCase(nil)
}
})
ast.PostOrderVisit(macroExpr, macroRefVisitor)
}

View File

@ -448,6 +448,8 @@ const (
OptTrackCost EvalOption = 1 << iota
// OptCheckStringFormat enables compile-time checking of string.format calls for syntax/cardinality.
//
// Deprecated: use ext.StringsValidateFormatCalls() as this option is now a no-op.
OptCheckStringFormat EvalOption = 1 << iota
)

View File

@ -19,7 +19,6 @@ import (
"fmt"
"sync"
celast "github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter"
@ -152,7 +151,7 @@ func (p *prog) clone() *prog {
// ProgramOption values.
//
// If the program cannot be configured the prog will be nil, with a non-nil error response.
func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) {
// Build the dispatcher, interpreter, and default program value.
disp := interpreter.NewDispatcher()
@ -213,34 +212,6 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
if len(p.regexOptimizations) > 0 {
decorators = append(decorators, interpreter.CompileRegexConstants(p.regexOptimizations...))
}
// Enable compile-time checking of syntax/cardinality for string.format calls.
if p.evalOpts&OptCheckStringFormat == OptCheckStringFormat {
var isValidType func(id int64, validTypes ...ref.Type) (bool, error)
if ast.IsChecked() {
isValidType = func(id int64, validTypes ...ref.Type) (bool, error) {
t := ast.typeMap[id]
if t.Kind() == DynKind {
return true, nil
}
for _, vt := range validTypes {
k, err := typeValueToKind(vt)
if err != nil {
return false, err
}
if t.Kind() == k {
return true, nil
}
}
return false, nil
}
} else {
// if the AST isn't type-checked, short-circuit validation
isValidType = func(id int64, validTypes ...ref.Type) (bool, error) {
return true, nil
}
}
decorators = append(decorators, interpreter.InterpolateFormattedString(isValidType))
}
// Enable exhaustive eval, state tracking and cost tracking last since they require a factory.
if p.evalOpts&(OptExhaustiveEval|OptTrackState|OptTrackCost) != 0 {
@ -274,33 +245,16 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
decs = append(decs, interpreter.Observe(observers...))
}
return p.clone().initInterpretable(ast, decs)
return p.clone().initInterpretable(a, decs)
}
return newProgGen(factory)
}
return p.initInterpretable(ast, decorators)
return p.initInterpretable(a, decorators)
}
func (p *prog) initInterpretable(ast *Ast, decs []interpreter.InterpretableDecorator) (*prog, error) {
// Unchecked programs do not contain type and reference information and may be slower to execute.
if !ast.IsChecked() {
interpretable, err :=
p.interpreter.NewUncheckedInterpretable(ast.Expr(), decs...)
if err != nil {
return nil, err
}
p.interpretable = interpretable
return p, nil
}
// When the AST has been checked it contains metadata that can be used to speed up program execution.
checked := &celast.CheckedAST{
Expr: ast.Expr(),
SourceInfo: ast.SourceInfo(),
TypeMap: ast.typeMap,
ReferenceMap: ast.refMap,
}
interpretable, err := p.interpreter.NewInterpretable(checked, decs...)
func (p *prog) initInterpretable(a *Ast, decs []interpreter.InterpretableDecorator) (*prog, error) {
// When the AST has been exprAST it contains metadata that can be used to speed up program execution.
interpretable, err := p.interpreter.NewInterpretable(a.impl, decs...)
if err != nil {
return nil, err
}
@ -580,8 +534,6 @@ func (p *evalActivationPool) Put(value any) {
}
var (
emptyEvalState = interpreter.NewEvalState()
// activationPool is an internally managed pool of Activation values that wrap map[string]any inputs
activationPool = newEvalActivationPool()

View File

@ -21,8 +21,6 @@ import (
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/overloads"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
const (
@ -69,7 +67,7 @@ type ASTValidator interface {
//
// See individual validators for more information on their configuration keys and configuration
// properties.
Validate(*Env, ValidatorConfig, *ast.CheckedAST, *Issues)
Validate(*Env, ValidatorConfig, *ast.AST, *Issues)
}
// ValidatorConfig provides an accessor method for querying validator configuration state.
@ -180,7 +178,7 @@ func ValidateComprehensionNestingLimit(limit int) ASTValidator {
return nestingLimitValidator{limit: limit}
}
type argChecker func(env *Env, call, arg ast.NavigableExpr) error
type argChecker func(env *Env, call, arg ast.Expr) error
func newFormatValidator(funcName string, argNum int, check argChecker) formatValidator {
return formatValidator{
@ -203,8 +201,8 @@ func (v formatValidator) Name() string {
// Validate searches the AST for uses of a given function name with a constant argument and performs a check
// on whether the argument is a valid literal value.
func (v formatValidator) Validate(e *Env, _ ValidatorConfig, a *ast.CheckedAST, iss *Issues) {
root := ast.NavigateCheckedAST(a)
func (v formatValidator) Validate(e *Env, _ ValidatorConfig, a *ast.AST, iss *Issues) {
root := ast.NavigateAST(a)
funcCalls := ast.MatchDescendants(root, ast.FunctionMatcher(v.funcName))
for _, call := range funcCalls {
callArgs := call.AsCall().Args()
@ -221,8 +219,8 @@ func (v formatValidator) Validate(e *Env, _ ValidatorConfig, a *ast.CheckedAST,
}
}
func evalCall(env *Env, call, arg ast.NavigableExpr) error {
ast := ParsedExprToAst(&exprpb.ParsedExpr{Expr: call.ToExpr()})
func evalCall(env *Env, call, arg ast.Expr) error {
ast := &Ast{impl: ast.NewAST(call, ast.NewSourceInfo(nil))}
prg, err := env.Program(ast)
if err != nil {
return err
@ -231,7 +229,7 @@ func evalCall(env *Env, call, arg ast.NavigableExpr) error {
return err
}
func compileRegex(_ *Env, _, arg ast.NavigableExpr) error {
func compileRegex(_ *Env, _, arg ast.Expr) error {
pattern := arg.AsLiteral().Value().(string)
_, err := regexp.Compile(pattern)
return err
@ -244,25 +242,14 @@ func (homogeneousAggregateLiteralValidator) Name() string {
return homogeneousValidatorName
}
// Configure implements the ASTValidatorConfigurer interface and currently sets the list of standard
// and exempt functions from homogeneous aggregate literal checks.
//
// TODO: Move this call into the string.format() ASTValidator once ported.
func (homogeneousAggregateLiteralValidator) Configure(c MutableValidatorConfig) error {
emptyList := []string{}
exemptFunctions := c.GetOrDefault(HomogeneousAggregateLiteralExemptFunctions, emptyList).([]string)
exemptFunctions = append(exemptFunctions, "format")
return c.Set(HomogeneousAggregateLiteralExemptFunctions, exemptFunctions)
}
// Validate validates that all lists and map literals have homogeneous types, i.e. don't contain dyn types.
//
// This validator makes an exception for list and map literals which occur at any level of nesting within
// string format calls.
func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig, a *ast.CheckedAST, iss *Issues) {
func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig, a *ast.AST, iss *Issues) {
var exemptedFunctions []string
exemptedFunctions = c.GetOrDefault(HomogeneousAggregateLiteralExemptFunctions, exemptedFunctions).([]string)
root := ast.NavigateCheckedAST(a)
root := ast.NavigateAST(a)
listExprs := ast.MatchDescendants(root, ast.KindMatcher(ast.ListKind))
for _, listExpr := range listExprs {
if inExemptFunction(listExpr, exemptedFunctions) {
@ -273,7 +260,7 @@ func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig
optIndices := l.OptionalIndices()
var elemType *Type
for i, e := range elements {
et := e.Type()
et := a.GetType(e.ID())
if isOptionalIndex(i, optIndices) {
et = et.Parameters()[0]
}
@ -296,9 +283,10 @@ func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig
entries := m.Entries()
var keyType, valType *Type
for _, e := range entries {
key, val := e.Key(), e.Value()
kt, vt := key.Type(), val.Type()
if e.IsOptional() {
mapEntry := e.AsMapEntry()
key, val := mapEntry.Key(), mapEntry.Value()
kt, vt := a.GetType(key.ID()), a.GetType(val.ID())
if mapEntry.IsOptional() {
vt = vt.Parameters()[0]
}
if keyType == nil && valType == nil {
@ -316,7 +304,8 @@ func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig
}
func inExemptFunction(e ast.NavigableExpr, exemptFunctions []string) bool {
if parent, found := e.Parent(); found {
parent, found := e.Parent()
for found {
if parent.Kind() == ast.CallKind {
fnName := parent.AsCall().FunctionName()
for _, exempt := range exemptFunctions {
@ -325,9 +314,7 @@ func inExemptFunction(e ast.NavigableExpr, exemptFunctions []string) bool {
}
}
}
if parent.Kind() == ast.ListKind || parent.Kind() == ast.MapKind {
return inExemptFunction(parent, exemptFunctions)
}
parent, found = parent.Parent()
}
return false
}
@ -353,8 +340,8 @@ func (v nestingLimitValidator) Name() string {
return "cel.lib.std.validate.comprehension_nesting_limit"
}
func (v nestingLimitValidator) Validate(e *Env, _ ValidatorConfig, a *ast.CheckedAST, iss *Issues) {
root := ast.NavigateCheckedAST(a)
func (v nestingLimitValidator) Validate(e *Env, _ ValidatorConfig, a *ast.AST, iss *Issues) {
root := ast.NavigateAST(a)
comprehensions := ast.MatchDescendants(root, ast.KindMatcher(ast.ComprehensionKind))
if len(comprehensions) <= v.limit {
return