// Copyright 2018 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package interpreter import ( "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/traits" ) type astPruner struct { ast.ExprFactory expr ast.Expr macroCalls map[int64]ast.Expr state EvalState nextExprID int64 } // TODO Consider having a separate walk of the AST that finds common // subexpressions. This can be called before or after constant folding to find // common subexpressions. // PruneAst prunes the given AST based on the given EvalState and generates a new AST. // Given AST is copied on write and a new AST is returned. // Couple of typical use cases this interface would be: // // A) // 1) Evaluate expr with some unknowns, // 2) If result is unknown: // // a) PruneAst // b) Goto 1 // // Functional call results which are known would be effectively cached across // iterations. // // B) // 1) Compile the expression (maybe via a service and maybe after checking a // // compiled expression does not exists in local cache) // // 2) Prepare the environment and the interpreter. Activation might be empty. // 3) Eval the expression. This might return unknown or error or a concrete // // value. // // 4) PruneAst // 4) Maybe cache the expression // This is effectively constant folding the expression. How the environment is // prepared in step 2 is flexible. For example, If the caller caches the // compiled and constant folded expressions, but is not willing to constant // fold(and thus cache results of) some external calls, then they can prepare // the overloads accordingly. func PruneAst(expr ast.Expr, macroCalls map[int64]ast.Expr, state EvalState) *ast.AST { pruneState := NewEvalState() for _, id := range state.IDs() { v, _ := state.Value(id) pruneState.SetValue(id, v) } pruner := &astPruner{ ExprFactory: ast.NewExprFactory(), expr: expr, macroCalls: macroCalls, state: pruneState, nextExprID: getMaxID(expr)} newExpr, _ := pruner.maybePrune(expr) newInfo := ast.NewSourceInfo(nil) for id, call := range pruner.macroCalls { newInfo.SetMacroCall(id, call) } return ast.NewAST(newExpr, newInfo) } func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (ast.Expr, bool) { switch v := val.(type) { case types.Bool, types.Bytes, types.Double, types.Int, types.Null, types.String, types.Uint: p.state.SetValue(id, val) return p.NewLiteral(id, val), true case types.Duration: p.state.SetValue(id, val) durationString := v.ConvertToType(types.StringType).(types.String) return p.NewCall(id, overloads.TypeConvertDuration, p.NewLiteral(p.nextID(), durationString)), true case types.Timestamp: timestampString := v.ConvertToType(types.StringType).(types.String) return p.NewCall(id, overloads.TypeConvertTimestamp, p.NewLiteral(p.nextID(), timestampString)), true } // Attempt to build a list literal. if list, isList := val.(traits.Lister); isList { sz := list.Size().(types.Int) elemExprs := make([]ast.Expr, sz) for i := types.Int(0); i < sz; i++ { elem := list.Get(i) if types.IsUnknownOrError(elem) { return nil, false } elemExpr, ok := p.maybeCreateLiteral(p.nextID(), elem) if !ok { return nil, false } elemExprs[i] = elemExpr } p.state.SetValue(id, val) return p.NewList(id, elemExprs, []int32{}), true } // Create a map literal if possible. if mp, isMap := val.(traits.Mapper); isMap { it := mp.Iterator() entries := make([]ast.EntryExpr, mp.Size().(types.Int)) i := 0 for it.HasNext() != types.False { key := it.Next() val := mp.Get(key) if types.IsUnknownOrError(key) || types.IsUnknownOrError(val) { return nil, false } keyExpr, ok := p.maybeCreateLiteral(p.nextID(), key) if !ok { return nil, false } valExpr, ok := p.maybeCreateLiteral(p.nextID(), val) if !ok { return nil, false } entry := p.NewMapEntry(p.nextID(), keyExpr, valExpr, false) entries[i] = entry i++ } p.state.SetValue(id, val) return p.NewMap(id, entries), true } // TODO(issues/377) To construct message literals, the type provider will need to support // the enumeration the fields for a given message. return nil, false } func (p *astPruner) maybePruneOptional(elem ast.Expr) (ast.Expr, bool) { elemVal, found := p.value(elem.ID()) if found && elemVal.Type() == types.OptionalType { opt := elemVal.(*types.Optional) if !opt.HasValue() { return nil, true } if newElem, pruned := p.maybeCreateLiteral(elem.ID(), opt.GetValue()); pruned { return newElem, true } } return elem, false } func (p *astPruner) maybePruneIn(node ast.Expr) (ast.Expr, bool) { // elem in list call := node.AsCall() val, exists := p.maybeValue(call.Args()[1].ID()) if !exists { return nil, false } if sz, ok := val.(traits.Sizer); ok && sz.Size() == types.IntZero { return p.maybeCreateLiteral(node.ID(), types.False) } return nil, false } func (p *astPruner) maybePruneLogicalNot(node ast.Expr) (ast.Expr, bool) { call := node.AsCall() arg := call.Args()[0] val, exists := p.maybeValue(arg.ID()) if !exists { return nil, false } if b, ok := val.(types.Bool); ok { return p.maybeCreateLiteral(node.ID(), !b) } return nil, false } func (p *astPruner) maybePruneOr(node ast.Expr) (ast.Expr, bool) { call := node.AsCall() // We know result is unknown, so we have at least one unknown arg // and if one side is a known value, we know we can ignore it. if v, exists := p.maybeValue(call.Args()[0].ID()); exists { if v == types.True { return p.maybeCreateLiteral(node.ID(), types.True) } return call.Args()[1], true } if v, exists := p.maybeValue(call.Args()[1].ID()); exists { if v == types.True { return p.maybeCreateLiteral(node.ID(), types.True) } return call.Args()[0], true } return nil, false } func (p *astPruner) maybePruneAnd(node ast.Expr) (ast.Expr, bool) { call := node.AsCall() // We know result is unknown, so we have at least one unknown arg // and if one side is a known value, we know we can ignore it. if v, exists := p.maybeValue(call.Args()[0].ID()); exists { if v == types.False { return p.maybeCreateLiteral(node.ID(), types.False) } return call.Args()[1], true } if v, exists := p.maybeValue(call.Args()[1].ID()); exists { if v == types.False { return p.maybeCreateLiteral(node.ID(), types.False) } return call.Args()[0], true } return nil, false } func (p *astPruner) maybePruneConditional(node ast.Expr) (ast.Expr, bool) { call := node.AsCall() cond, exists := p.maybeValue(call.Args()[0].ID()) if !exists { return nil, false } if cond.Value().(bool) { return call.Args()[1], true } return call.Args()[2], true } func (p *astPruner) maybePruneFunction(node ast.Expr) (ast.Expr, bool) { if _, exists := p.value(node.ID()); !exists { return nil, false } call := node.AsCall() if call.FunctionName() == operators.LogicalOr { return p.maybePruneOr(node) } if call.FunctionName() == operators.LogicalAnd { return p.maybePruneAnd(node) } if call.FunctionName() == operators.Conditional { return p.maybePruneConditional(node) } if call.FunctionName() == operators.In { return p.maybePruneIn(node) } if call.FunctionName() == operators.LogicalNot { return p.maybePruneLogicalNot(node) } return nil, false } func (p *astPruner) maybePrune(node ast.Expr) (ast.Expr, bool) { return p.prune(node) } func (p *astPruner) prune(node ast.Expr) (ast.Expr, bool) { if node == nil { return node, false } val, valueExists := p.maybeValue(node.ID()) if valueExists { if newNode, ok := p.maybeCreateLiteral(node.ID(), val); ok { delete(p.macroCalls, node.ID()) return newNode, true } } if macro, found := p.macroCalls[node.ID()]; found { // Ensure that intermediate values for the comprehension are cleared during pruning if node.Kind() == ast.ComprehensionKind { compre := node.AsComprehension() visit(macro, clearIterVarVisitor(compre.IterVar(), p.state)) } // prune the expression in terms of the macro call instead of the expanded form. if newMacro, pruned := p.prune(macro); pruned { p.macroCalls[node.ID()] = newMacro } } // We have either an unknown/error value, or something we don't want to // transform, or expression was not evaluated. If possible, drill down // more. switch node.Kind() { case ast.SelectKind: sel := node.AsSelect() if operand, isPruned := p.maybePrune(sel.Operand()); isPruned { if sel.IsTestOnly() { return p.NewPresenceTest(node.ID(), operand, sel.FieldName()), true } return p.NewSelect(node.ID(), operand, sel.FieldName()), true } case ast.CallKind: argsPruned := false call := node.AsCall() args := call.Args() newArgs := make([]ast.Expr, len(args)) for i, a := range args { newArgs[i] = a if arg, isPruned := p.maybePrune(a); isPruned { argsPruned = true newArgs[i] = arg } } if !call.IsMemberFunction() { newCall := p.NewCall(node.ID(), call.FunctionName(), newArgs...) if prunedCall, isPruned := p.maybePruneFunction(newCall); isPruned { return prunedCall, true } return newCall, argsPruned } newTarget := call.Target() targetPruned := false if prunedTarget, isPruned := p.maybePrune(call.Target()); isPruned { targetPruned = true newTarget = prunedTarget } newCall := p.NewMemberCall(node.ID(), call.FunctionName(), newTarget, newArgs...) if prunedCall, isPruned := p.maybePruneFunction(newCall); isPruned { return prunedCall, true } return newCall, targetPruned || argsPruned case ast.ListKind: l := node.AsList() elems := l.Elements() optIndices := l.OptionalIndices() optIndexMap := map[int32]bool{} for _, i := range optIndices { optIndexMap[i] = true } newOptIndexMap := make(map[int32]bool, len(optIndexMap)) newElems := make([]ast.Expr, 0, len(elems)) var listPruned bool prunedIdx := 0 for i, elem := range elems { _, isOpt := optIndexMap[int32(i)] if isOpt { newElem, pruned := p.maybePruneOptional(elem) if pruned { listPruned = true if newElem != nil { newElems = append(newElems, newElem) prunedIdx++ } continue } newOptIndexMap[int32(prunedIdx)] = true } if newElem, prunedElem := p.maybePrune(elem); prunedElem { newElems = append(newElems, newElem) listPruned = true } else { newElems = append(newElems, elem) } prunedIdx++ } optIndices = make([]int32, len(newOptIndexMap)) idx := 0 for i := range newOptIndexMap { optIndices[idx] = i idx++ } if listPruned { return p.NewList(node.ID(), newElems, optIndices), true } case ast.MapKind: var mapPruned bool m := node.AsMap() entries := m.Entries() newEntries := make([]ast.EntryExpr, len(entries)) for i, entry := range entries { newEntries[i] = entry e := entry.AsMapEntry() newKey, keyPruned := p.maybePrune(e.Key()) newValue, valuePruned := p.maybePrune(e.Value()) if !keyPruned && !valuePruned { continue } mapPruned = true newEntry := p.NewMapEntry(entry.ID(), newKey, newValue, e.IsOptional()) newEntries[i] = newEntry } if mapPruned { return p.NewMap(node.ID(), newEntries), true } case ast.StructKind: var structPruned bool obj := node.AsStruct() fields := obj.Fields() newFields := make([]ast.EntryExpr, len(fields)) for i, field := range fields { newFields[i] = field f := field.AsStructField() newValue, prunedValue := p.maybePrune(f.Value()) if !prunedValue { continue } structPruned = true newEntry := p.NewStructField(field.ID(), f.Name(), newValue, f.IsOptional()) newFields[i] = newEntry } if structPruned { return p.NewStruct(node.ID(), obj.TypeName(), newFields), true } case ast.ComprehensionKind: compre := node.AsComprehension() // Only the range of the comprehension is pruned since the state tracking only records // the last iteration of the comprehension and not each step in the evaluation which // means that the any residuals computed in between might be inaccurate. if newRange, pruned := p.maybePrune(compre.IterRange()); pruned { return p.NewComprehension( node.ID(), newRange, compre.IterVar(), compre.AccuVar(), compre.AccuInit(), compre.LoopCondition(), compre.LoopStep(), compre.Result(), ), true } } return node, false } func (p *astPruner) value(id int64) (ref.Val, bool) { val, found := p.state.Value(id) return val, (found && val != nil) } func (p *astPruner) maybeValue(id int64) (ref.Val, bool) { val, found := p.value(id) if !found || types.IsUnknownOrError(val) { return nil, false } return val, true } func (p *astPruner) nextID() int64 { next := p.nextExprID p.nextExprID++ return next } type astVisitor struct { // visitEntry is called on every expr node, including those within a map/struct entry. visitExpr func(expr ast.Expr) // visitEntry is called before entering the key, value of a map/struct entry. visitEntry func(entry ast.EntryExpr) } func getMaxID(expr ast.Expr) int64 { maxID := int64(1) visit(expr, maxIDVisitor(&maxID)) return maxID } func clearIterVarVisitor(varName string, state EvalState) astVisitor { return astVisitor{ visitExpr: func(e ast.Expr) { if e.Kind() == ast.IdentKind && e.AsIdent() == varName { state.SetValue(e.ID(), nil) } }, } } func maxIDVisitor(maxID *int64) astVisitor { return astVisitor{ visitExpr: func(e ast.Expr) { if e.ID() >= *maxID { *maxID = e.ID() + 1 } }, visitEntry: func(e ast.EntryExpr) { if e.ID() >= *maxID { *maxID = e.ID() + 1 } }, } } func visit(expr ast.Expr, visitor astVisitor) { exprs := []ast.Expr{expr} for len(exprs) != 0 { e := exprs[0] if visitor.visitExpr != nil { visitor.visitExpr(e) } exprs = exprs[1:] switch e.Kind() { case ast.SelectKind: exprs = append(exprs, e.AsSelect().Operand()) case ast.CallKind: call := e.AsCall() if call.Target() != nil { exprs = append(exprs, call.Target()) } exprs = append(exprs, call.Args()...) case ast.ComprehensionKind: compre := e.AsComprehension() exprs = append(exprs, compre.IterRange(), compre.AccuInit(), compre.LoopCondition(), compre.LoopStep(), compre.Result()) case ast.ListKind: list := e.AsList() exprs = append(exprs, list.Elements()...) case ast.MapKind: for _, entry := range e.AsMap().Entries() { e := entry.AsMapEntry() if visitor.visitEntry != nil { visitor.visitEntry(entry) } exprs = append(exprs, e.Key()) exprs = append(exprs, e.Value()) } case ast.StructKind: for _, entry := range e.AsStruct().Fields() { f := entry.AsStructField() if visitor.visitEntry != nil { visitor.visitEntry(entry) } exprs = append(exprs, f.Value()) } } } }