ceph-csi/vendor/github.com/google/cel-go/cel/folding.go

560 lines
17 KiB
Go
Raw Permalink Normal View History

// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cel
import (
"fmt"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
)
// ConstantFoldingOption defines a functional option for configuring constant folding.
type ConstantFoldingOption func(opt *constantFoldingOptimizer) (*constantFoldingOptimizer, error)
// MaxConstantFoldIterations limits the number of times literals may be folding during optimization.
//
// Defaults to 100 if not set.
func MaxConstantFoldIterations(limit int) ConstantFoldingOption {
return func(opt *constantFoldingOptimizer) (*constantFoldingOptimizer, error) {
opt.maxFoldIterations = limit
return opt, nil
}
}
// NewConstantFoldingOptimizer creates an optimizer which inlines constant scalar an aggregate
// literal values within function calls and select statements with their evaluated result.
func NewConstantFoldingOptimizer(opts ...ConstantFoldingOption) (ASTOptimizer, error) {
folder := &constantFoldingOptimizer{
maxFoldIterations: defaultMaxConstantFoldIterations,
}
var err error
for _, o := range opts {
folder, err = o(folder)
if err != nil {
return nil, err
}
}
return folder, nil
}
type constantFoldingOptimizer struct {
maxFoldIterations int
}
// Optimize queries the expression graph for scalar and aggregate literal expressions within call and
// select statements and then evaluates them and replaces the call site with the literal result.
//
// Note: only values which can be represented as literals in CEL syntax are supported.
func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.AST {
root := ast.NavigateAST(a)
// Walk the list of foldable expression and continue to fold until there are no more folds left.
// All of the fold candidates returned by the constantExprMatcher should succeed unless there's
// a logic bug with the selection of expressions.
foldableExprs := ast.MatchDescendants(root, constantExprMatcher)
foldCount := 0
for len(foldableExprs) != 0 && foldCount < opt.maxFoldIterations {
for _, fold := range foldableExprs {
// If the expression could be folded because it's a non-strict call, and the
// branches are pruned, continue to the next fold.
if fold.Kind() == ast.CallKind && maybePruneBranches(ctx, fold) {
continue
}
// Otherwise, assume all context is needed to evaluate the expression.
err := tryFold(ctx, a, fold)
if err != nil {
ctx.ReportErrorAtID(fold.ID(), "constant-folding evaluation failed: %v", err.Error())
return a
}
}
foldCount++
foldableExprs = ast.MatchDescendants(root, constantExprMatcher)
}
// Once all of the constants have been folded, try to run through the remaining comprehensions
// one last time. In this case, there's no guarantee they'll run, so we only update the
// target comprehension node with the literal value if the evaluation succeeds.
for _, compre := range ast.MatchDescendants(root, ast.KindMatcher(ast.ComprehensionKind)) {
tryFold(ctx, a, compre)
}
// If the output is a list, map, or struct which contains optional entries, then prune it
// to make sure that the optionals, if resolved, do not surface in the output literal.
pruneOptionalElements(ctx, root)
// Ensure that all intermediate values in the folded expression can be represented as valid
// CEL literals within the AST structure. Use `PostOrderVisit` rather than `MatchDescendents`
// to avoid extra allocations during this final pass through the AST.
ast.PostOrderVisit(root, ast.NewExprVisitor(func(e ast.Expr) {
if e.Kind() != ast.LiteralKind {
return
}
val := e.AsLiteral()
adapted, err := adaptLiteral(ctx, val)
if err != nil {
ctx.ReportErrorAtID(root.ID(), "constant-folding evaluation failed: %v", err.Error())
return
}
ctx.UpdateExpr(e, adapted)
}))
return a
}
// tryFold attempts to evaluate a sub-expression to a literal.
//
// If the evaluation succeeds, the input expr value will be modified to become a literal, otherwise
// the method will return an error.
func tryFold(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) error {
// Assume all context is needed to evaluate the expression.
subAST := &Ast{
impl: ast.NewCheckedAST(ast.NewAST(expr, a.SourceInfo()), a.TypeMap(), a.ReferenceMap()),
}
prg, err := ctx.Program(subAST)
if err != nil {
return err
}
out, _, err := prg.Eval(NoVars())
if err != nil {
return err
}
// Update the fold expression to be a literal.
ctx.UpdateExpr(expr, ctx.NewLiteral(out))
return nil
}
// maybePruneBranches inspects the non-strict call expression to determine whether
// a branch can be removed. Evaluation will naturally prune logical and / or calls,
// but conditional will not be pruned cleanly, so this is one small area where the
// constant folding step reimplements a portion of the evaluator.
func maybePruneBranches(ctx *OptimizerContext, expr ast.NavigableExpr) bool {
call := expr.AsCall()
args := call.Args()
switch call.FunctionName() {
case operators.LogicalAnd, operators.LogicalOr:
return maybeShortcircuitLogic(ctx, call.FunctionName(), args, expr)
case operators.Conditional:
cond := args[0]
truthy := args[1]
falsy := args[2]
if cond.Kind() != ast.LiteralKind {
return false
}
if cond.AsLiteral() == types.True {
ctx.UpdateExpr(expr, truthy)
} else {
ctx.UpdateExpr(expr, falsy)
}
return true
case operators.In:
haystack := args[1]
if haystack.Kind() == ast.ListKind && haystack.AsList().Size() == 0 {
ctx.UpdateExpr(expr, ctx.NewLiteral(types.False))
return true
}
needle := args[0]
if needle.Kind() == ast.LiteralKind && haystack.Kind() == ast.ListKind {
needleValue := needle.AsLiteral()
list := haystack.AsList()
for _, e := range list.Elements() {
if e.Kind() == ast.LiteralKind && e.AsLiteral().Equal(needleValue) == types.True {
ctx.UpdateExpr(expr, ctx.NewLiteral(types.True))
return true
}
}
}
}
return false
}
func maybeShortcircuitLogic(ctx *OptimizerContext, function string, args []ast.Expr, expr ast.NavigableExpr) bool {
shortcircuit := types.False
skip := types.True
if function == operators.LogicalOr {
shortcircuit = types.True
skip = types.False
}
newArgs := []ast.Expr{}
for _, arg := range args {
if arg.Kind() != ast.LiteralKind {
newArgs = append(newArgs, arg)
continue
}
if arg.AsLiteral() == skip {
continue
}
if arg.AsLiteral() == shortcircuit {
ctx.UpdateExpr(expr, arg)
return true
}
}
if len(newArgs) == 0 {
newArgs = append(newArgs, args[0])
ctx.UpdateExpr(expr, newArgs[0])
return true
}
if len(newArgs) == 1 {
ctx.UpdateExpr(expr, newArgs[0])
return true
}
ctx.UpdateExpr(expr, ctx.NewCall(function, newArgs...))
return true
}
// pruneOptionalElements works from the bottom up to resolve optional elements within
// aggregate literals.
//
// Note, many aggregate literals will be resolved as arguments to functions or select
// statements, so this method exists to handle the case where the literal could not be
// fully resolved or exists outside of a call, select, or comprehension context.
func pruneOptionalElements(ctx *OptimizerContext, root ast.NavigableExpr) {
aggregateLiterals := ast.MatchDescendants(root, aggregateLiteralMatcher)
for _, lit := range aggregateLiterals {
switch lit.Kind() {
case ast.ListKind:
pruneOptionalListElements(ctx, lit)
case ast.MapKind:
pruneOptionalMapEntries(ctx, lit)
case ast.StructKind:
pruneOptionalStructFields(ctx, lit)
}
}
}
func pruneOptionalListElements(ctx *OptimizerContext, e ast.Expr) {
l := e.AsList()
elems := l.Elements()
optIndices := l.OptionalIndices()
if len(optIndices) == 0 {
return
}
updatedElems := []ast.Expr{}
updatedIndices := []int32{}
newOptIndex := -1
for _, e := range elems {
newOptIndex++
if !l.IsOptional(int32(newOptIndex)) {
updatedElems = append(updatedElems, e)
continue
}
if e.Kind() != ast.LiteralKind {
updatedElems = append(updatedElems, e)
updatedIndices = append(updatedIndices, int32(newOptIndex))
continue
}
optElemVal, ok := e.AsLiteral().(*types.Optional)
if !ok {
updatedElems = append(updatedElems, e)
updatedIndices = append(updatedIndices, int32(newOptIndex))
continue
}
if !optElemVal.HasValue() {
newOptIndex-- // Skipping causes the list to get smaller.
continue
}
ctx.UpdateExpr(e, ctx.NewLiteral(optElemVal.GetValue()))
updatedElems = append(updatedElems, e)
}
ctx.UpdateExpr(e, ctx.NewList(updatedElems, updatedIndices))
}
func pruneOptionalMapEntries(ctx *OptimizerContext, e ast.Expr) {
m := e.AsMap()
entries := m.Entries()
updatedEntries := []ast.EntryExpr{}
modified := false
for _, e := range entries {
entry := e.AsMapEntry()
key := entry.Key()
val := entry.Value()
// If the entry is not optional, or the value-side of the optional hasn't
// been resolved to a literal, then preserve the entry as-is.
if !entry.IsOptional() || val.Kind() != ast.LiteralKind {
updatedEntries = append(updatedEntries, e)
continue
}
optElemVal, ok := val.AsLiteral().(*types.Optional)
if !ok {
updatedEntries = append(updatedEntries, e)
continue
}
// When the key is not a literal, but the value is, then it needs to be
// restored to an optional value.
if key.Kind() != ast.LiteralKind {
undoOptVal, err := adaptLiteral(ctx, optElemVal)
if err != nil {
ctx.ReportErrorAtID(val.ID(), "invalid map value literal %v: %v", optElemVal, err)
}
ctx.UpdateExpr(val, undoOptVal)
updatedEntries = append(updatedEntries, e)
continue
}
modified = true
if !optElemVal.HasValue() {
continue
}
ctx.UpdateExpr(val, ctx.NewLiteral(optElemVal.GetValue()))
updatedEntry := ctx.NewMapEntry(key, val, false)
updatedEntries = append(updatedEntries, updatedEntry)
}
if modified {
ctx.UpdateExpr(e, ctx.NewMap(updatedEntries))
}
}
func pruneOptionalStructFields(ctx *OptimizerContext, e ast.Expr) {
s := e.AsStruct()
fields := s.Fields()
updatedFields := []ast.EntryExpr{}
modified := false
for _, f := range fields {
field := f.AsStructField()
val := field.Value()
if !field.IsOptional() || val.Kind() != ast.LiteralKind {
updatedFields = append(updatedFields, f)
continue
}
optElemVal, ok := val.AsLiteral().(*types.Optional)
if !ok {
updatedFields = append(updatedFields, f)
continue
}
modified = true
if !optElemVal.HasValue() {
continue
}
ctx.UpdateExpr(val, ctx.NewLiteral(optElemVal.GetValue()))
updatedField := ctx.NewStructField(field.Name(), val, false)
updatedFields = append(updatedFields, updatedField)
}
if modified {
ctx.UpdateExpr(e, ctx.NewStruct(s.TypeName(), updatedFields))
}
}
// adaptLiteral converts a runtime CEL value to its equivalent literal expression.
//
// For strongly typed values, the type-provider will be used to reconstruct the fields
// which are present in the literal and their equivalent initialization values.
func adaptLiteral(ctx *OptimizerContext, val ref.Val) (ast.Expr, error) {
switch t := val.Type().(type) {
case *types.Type:
switch t {
case types.BoolType, types.BytesType, types.DoubleType, types.IntType,
types.NullType, types.StringType, types.UintType:
return ctx.NewLiteral(val), nil
case types.DurationType:
return ctx.NewCall(
overloads.TypeConvertDuration,
ctx.NewLiteral(val.ConvertToType(types.StringType)),
), nil
case types.TimestampType:
return ctx.NewCall(
overloads.TypeConvertTimestamp,
ctx.NewLiteral(val.ConvertToType(types.StringType)),
), nil
case types.OptionalType:
opt := val.(*types.Optional)
if !opt.HasValue() {
return ctx.NewCall("optional.none"), nil
}
target, err := adaptLiteral(ctx, opt.GetValue())
if err != nil {
return nil, err
}
return ctx.NewCall("optional.of", target), nil
case types.TypeType:
return ctx.NewIdent(val.(*types.Type).TypeName()), nil
case types.ListType:
l, ok := val.(traits.Lister)
if !ok {
return nil, fmt.Errorf("failed to adapt %v to literal", val)
}
elems := make([]ast.Expr, l.Size().(types.Int))
idx := 0
it := l.Iterator()
for it.HasNext() == types.True {
elemVal := it.Next()
elemExpr, err := adaptLiteral(ctx, elemVal)
if err != nil {
return nil, err
}
elems[idx] = elemExpr
idx++
}
return ctx.NewList(elems, []int32{}), nil
case types.MapType:
m, ok := val.(traits.Mapper)
if !ok {
return nil, fmt.Errorf("failed to adapt %v to literal", val)
}
entries := make([]ast.EntryExpr, m.Size().(types.Int))
idx := 0
it := m.Iterator()
for it.HasNext() == types.True {
keyVal := it.Next()
keyExpr, err := adaptLiteral(ctx, keyVal)
if err != nil {
return nil, err
}
valVal := m.Get(keyVal)
valExpr, err := adaptLiteral(ctx, valVal)
if err != nil {
return nil, err
}
entries[idx] = ctx.NewMapEntry(keyExpr, valExpr, false)
idx++
}
return ctx.NewMap(entries), nil
default:
provider := ctx.CELTypeProvider()
fields, found := provider.FindStructFieldNames(t.TypeName())
if !found {
return nil, fmt.Errorf("failed to adapt %v to literal", val)
}
tester := val.(traits.FieldTester)
indexer := val.(traits.Indexer)
fieldInits := []ast.EntryExpr{}
for _, f := range fields {
field := types.String(f)
if tester.IsSet(field) != types.True {
continue
}
fieldVal := indexer.Get(field)
fieldExpr, err := adaptLiteral(ctx, fieldVal)
if err != nil {
return nil, err
}
fieldInits = append(fieldInits, ctx.NewStructField(f, fieldExpr, false))
}
return ctx.NewStruct(t.TypeName(), fieldInits), nil
}
}
return nil, fmt.Errorf("failed to adapt %v to literal", val)
}
// constantExprMatcher matches calls, select statements, and comprehensions whose arguments
// are all constant scalar or aggregate literal values.
//
// Only comprehensions which are not nested are included as possible constant folds, and only
// if all variables referenced in the comprehension stack exist are only iteration or
// accumulation variables.
func constantExprMatcher(e ast.NavigableExpr) bool {
switch e.Kind() {
case ast.CallKind:
return constantCallMatcher(e)
case ast.SelectKind:
sel := e.AsSelect() // guaranteed to be a navigable value
return constantMatcher(sel.Operand().(ast.NavigableExpr))
case ast.ComprehensionKind:
if isNestedComprehension(e) {
return false
}
vars := map[string]bool{}
constantExprs := true
visitor := ast.NewExprVisitor(func(e ast.Expr) {
if e.Kind() == ast.ComprehensionKind {
nested := e.AsComprehension()
vars[nested.AccuVar()] = true
vars[nested.IterVar()] = true
}
if e.Kind() == ast.IdentKind && !vars[e.AsIdent()] {
constantExprs = false
}
})
ast.PreOrderVisit(e, visitor)
return constantExprs
default:
return false
}
}
// constantCallMatcher identifies strict and non-strict calls which can be folded.
func constantCallMatcher(e ast.NavigableExpr) bool {
call := e.AsCall()
children := e.Children()
fnName := call.FunctionName()
if fnName == operators.LogicalAnd {
for _, child := range children {
if child.Kind() == ast.LiteralKind {
return true
}
}
}
if fnName == operators.LogicalOr {
for _, child := range children {
if child.Kind() == ast.LiteralKind {
return true
}
}
}
if fnName == operators.Conditional {
cond := children[0]
if cond.Kind() == ast.LiteralKind && cond.AsLiteral().Type() == types.BoolType {
return true
}
}
if fnName == operators.In {
haystack := children[1]
if haystack.Kind() == ast.ListKind && haystack.AsList().Size() == 0 {
return true
}
needle := children[0]
if needle.Kind() == ast.LiteralKind && haystack.Kind() == ast.ListKind {
needleValue := needle.AsLiteral()
list := haystack.AsList()
for _, e := range list.Elements() {
if e.Kind() == ast.LiteralKind && e.AsLiteral().Equal(needleValue) == types.True {
return true
}
}
}
}
// convert all other calls with constant arguments
for _, child := range children {
if !constantMatcher(child) {
return false
}
}
return true
}
func isNestedComprehension(e ast.NavigableExpr) bool {
parent, found := e.Parent()
for found {
if parent.Kind() == ast.ComprehensionKind {
return true
}
parent, found = parent.Parent()
}
return false
}
func aggregateLiteralMatcher(e ast.NavigableExpr) bool {
return e.Kind() == ast.ListKind || e.Kind() == ast.MapKind || e.Kind() == ast.StructKind
}
var (
constantMatcher = ast.ConstantValueMatcher()
)
const (
defaultMaxConstantFoldIterations = 100
)