// Copyright 2018 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package interpreter import ( "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/traits" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" structpb "google.golang.org/protobuf/types/known/structpb" ) type astPruner struct { expr *exprpb.Expr macroCalls map[int64]*exprpb.Expr state EvalState nextExprID int64 } // TODO Consider having a separate walk of the AST that finds common // subexpressions. This can be called before or after constant folding to find // common subexpressions. // PruneAst prunes the given AST based on the given EvalState and generates a new AST. // Given AST is copied on write and a new AST is returned. // Couple of typical use cases this interface would be: // // A) // 1) Evaluate expr with some unknowns, // 2) If result is unknown: // // a) PruneAst // b) Goto 1 // // Functional call results which are known would be effectively cached across // iterations. // // B) // 1) Compile the expression (maybe via a service and maybe after checking a // // compiled expression does not exists in local cache) // // 2) Prepare the environment and the interpreter. Activation might be empty. // 3) Eval the expression. This might return unknown or error or a concrete // // value. // // 4) PruneAst // 4) Maybe cache the expression // This is effectively constant folding the expression. How the environment is // prepared in step 2 is flexible. For example, If the caller caches the // compiled and constant folded expressions, but is not willing to constant // fold(and thus cache results of) some external calls, then they can prepare // the overloads accordingly. func PruneAst(expr *exprpb.Expr, macroCalls map[int64]*exprpb.Expr, state EvalState) *exprpb.ParsedExpr { pruneState := NewEvalState() for _, id := range state.IDs() { v, _ := state.Value(id) pruneState.SetValue(id, v) } pruner := &astPruner{ expr: expr, macroCalls: macroCalls, state: pruneState, nextExprID: getMaxID(expr)} newExpr, _ := pruner.maybePrune(expr) return &exprpb.ParsedExpr{ Expr: newExpr, SourceInfo: &exprpb.SourceInfo{MacroCalls: pruner.macroCalls}, } } func (p *astPruner) createLiteral(id int64, val *exprpb.Constant) *exprpb.Expr { return &exprpb.Expr{ Id: id, ExprKind: &exprpb.Expr_ConstExpr{ ConstExpr: val, }, } } func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, bool) { switch v := val.(type) { case types.Bool: p.state.SetValue(id, val) return p.createLiteral(id, &exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: bool(v)}}), true case types.Bytes: p.state.SetValue(id, val) return p.createLiteral(id, &exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: []byte(v)}}), true case types.Double: p.state.SetValue(id, val) return p.createLiteral(id, &exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: float64(v)}}), true case types.Duration: p.state.SetValue(id, val) durationString := string(v.ConvertToType(types.StringType).(types.String)) return &exprpb.Expr{ Id: id, ExprKind: &exprpb.Expr_CallExpr{ CallExpr: &exprpb.Expr_Call{ Function: overloads.TypeConvertDuration, Args: []*exprpb.Expr{ p.createLiteral(p.nextID(), &exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: durationString}}), }, }, }, }, true case types.Int: p.state.SetValue(id, val) return p.createLiteral(id, &exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: int64(v)}}), true case types.Uint: p.state.SetValue(id, val) return p.createLiteral(id, &exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: uint64(v)}}), true case types.String: p.state.SetValue(id, val) return p.createLiteral(id, &exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: string(v)}}), true case types.Null: p.state.SetValue(id, val) return p.createLiteral(id, &exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: v.Value().(structpb.NullValue)}}), true } // Attempt to build a list literal. if list, isList := val.(traits.Lister); isList { sz := list.Size().(types.Int) elemExprs := make([]*exprpb.Expr, sz) for i := types.Int(0); i < sz; i++ { elem := list.Get(i) if types.IsUnknownOrError(elem) { return nil, false } elemExpr, ok := p.maybeCreateLiteral(p.nextID(), elem) if !ok { return nil, false } elemExprs[i] = elemExpr } p.state.SetValue(id, val) return &exprpb.Expr{ Id: id, ExprKind: &exprpb.Expr_ListExpr{ ListExpr: &exprpb.Expr_CreateList{ Elements: elemExprs, }, }, }, true } // Create a map literal if possible. if mp, isMap := val.(traits.Mapper); isMap { it := mp.Iterator() entries := make([]*exprpb.Expr_CreateStruct_Entry, mp.Size().(types.Int)) i := 0 for it.HasNext() != types.False { key := it.Next() val := mp.Get(key) if types.IsUnknownOrError(key) || types.IsUnknownOrError(val) { return nil, false } keyExpr, ok := p.maybeCreateLiteral(p.nextID(), key) if !ok { return nil, false } valExpr, ok := p.maybeCreateLiteral(p.nextID(), val) if !ok { return nil, false } entry := &exprpb.Expr_CreateStruct_Entry{ Id: p.nextID(), KeyKind: &exprpb.Expr_CreateStruct_Entry_MapKey{ MapKey: keyExpr, }, Value: valExpr, } entries[i] = entry i++ } p.state.SetValue(id, val) return &exprpb.Expr{ Id: id, ExprKind: &exprpb.Expr_StructExpr{ StructExpr: &exprpb.Expr_CreateStruct{ Entries: entries, }, }, }, true } // TODO(issues/377) To construct message literals, the type provider will need to support // the enumeration the fields for a given message. return nil, false } func (p *astPruner) maybePruneOptional(elem *exprpb.Expr) (*exprpb.Expr, bool) { elemVal, found := p.value(elem.GetId()) if found && elemVal.Type() == types.OptionalType { opt := elemVal.(*types.Optional) if !opt.HasValue() { return nil, true } if newElem, pruned := p.maybeCreateLiteral(elem.GetId(), opt.GetValue()); pruned { return newElem, true } } return elem, false } func (p *astPruner) maybePruneIn(node *exprpb.Expr) (*exprpb.Expr, bool) { // elem in list call := node.GetCallExpr() val, exists := p.maybeValue(call.GetArgs()[1].GetId()) if !exists { return nil, false } if sz, ok := val.(traits.Sizer); ok && sz.Size() == types.IntZero { return p.maybeCreateLiteral(node.GetId(), types.False) } return nil, false } func (p *astPruner) maybePruneLogicalNot(node *exprpb.Expr) (*exprpb.Expr, bool) { call := node.GetCallExpr() arg := call.GetArgs()[0] val, exists := p.maybeValue(arg.GetId()) if !exists { return nil, false } if b, ok := val.(types.Bool); ok { return p.maybeCreateLiteral(node.GetId(), !b) } return nil, false } func (p *astPruner) maybePruneOr(node *exprpb.Expr) (*exprpb.Expr, bool) { call := node.GetCallExpr() // We know result is unknown, so we have at least one unknown arg // and if one side is a known value, we know we can ignore it. if v, exists := p.maybeValue(call.GetArgs()[0].GetId()); exists { if v == types.True { return p.maybeCreateLiteral(node.GetId(), types.True) } return call.GetArgs()[1], true } if v, exists := p.maybeValue(call.GetArgs()[1].GetId()); exists { if v == types.True { return p.maybeCreateLiteral(node.GetId(), types.True) } return call.GetArgs()[0], true } return nil, false } func (p *astPruner) maybePruneAnd(node *exprpb.Expr) (*exprpb.Expr, bool) { call := node.GetCallExpr() // We know result is unknown, so we have at least one unknown arg // and if one side is a known value, we know we can ignore it. if v, exists := p.maybeValue(call.GetArgs()[0].GetId()); exists { if v == types.False { return p.maybeCreateLiteral(node.GetId(), types.False) } return call.GetArgs()[1], true } if v, exists := p.maybeValue(call.GetArgs()[1].GetId()); exists { if v == types.False { return p.maybeCreateLiteral(node.GetId(), types.False) } return call.GetArgs()[0], true } return nil, false } func (p *astPruner) maybePruneConditional(node *exprpb.Expr) (*exprpb.Expr, bool) { call := node.GetCallExpr() cond, exists := p.maybeValue(call.GetArgs()[0].GetId()) if !exists { return nil, false } if cond.Value().(bool) { return call.GetArgs()[1], true } return call.GetArgs()[2], true } func (p *astPruner) maybePruneFunction(node *exprpb.Expr) (*exprpb.Expr, bool) { if _, exists := p.value(node.GetId()); !exists { return nil, false } call := node.GetCallExpr() if call.Function == operators.LogicalOr { return p.maybePruneOr(node) } if call.Function == operators.LogicalAnd { return p.maybePruneAnd(node) } if call.Function == operators.Conditional { return p.maybePruneConditional(node) } if call.Function == operators.In { return p.maybePruneIn(node) } if call.Function == operators.LogicalNot { return p.maybePruneLogicalNot(node) } return nil, false } func (p *astPruner) maybePrune(node *exprpb.Expr) (*exprpb.Expr, bool) { return p.prune(node) } func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) { if node == nil { return node, false } val, valueExists := p.maybeValue(node.GetId()) if valueExists { if newNode, ok := p.maybeCreateLiteral(node.GetId(), val); ok { delete(p.macroCalls, node.GetId()) return newNode, true } } if macro, found := p.macroCalls[node.GetId()]; found { // Ensure that intermediate values for the comprehension are cleared during pruning compre := node.GetComprehensionExpr() if compre != nil { visit(macro, clearIterVarVisitor(compre.IterVar, p.state)) } // prune the expression in terms of the macro call instead of the expanded form. if newMacro, pruned := p.prune(macro); pruned { p.macroCalls[node.GetId()] = newMacro } } // We have either an unknown/error value, or something we don't want to // transform, or expression was not evaluated. If possible, drill down // more. switch node.GetExprKind().(type) { case *exprpb.Expr_SelectExpr: if operand, pruned := p.maybePrune(node.GetSelectExpr().GetOperand()); pruned { return &exprpb.Expr{ Id: node.GetId(), ExprKind: &exprpb.Expr_SelectExpr{ SelectExpr: &exprpb.Expr_Select{ Operand: operand, Field: node.GetSelectExpr().GetField(), TestOnly: node.GetSelectExpr().GetTestOnly(), }, }, }, true } case *exprpb.Expr_CallExpr: var prunedCall bool call := node.GetCallExpr() args := call.GetArgs() newArgs := make([]*exprpb.Expr, len(args)) newCall := &exprpb.Expr_Call{ Function: call.GetFunction(), Target: call.GetTarget(), Args: newArgs, } for i, arg := range args { newArgs[i] = arg if newArg, prunedArg := p.maybePrune(arg); prunedArg { prunedCall = true newArgs[i] = newArg } } if newTarget, prunedTarget := p.maybePrune(call.GetTarget()); prunedTarget { prunedCall = true newCall.Target = newTarget } newNode := &exprpb.Expr{ Id: node.GetId(), ExprKind: &exprpb.Expr_CallExpr{ CallExpr: newCall, }, } if newExpr, pruned := p.maybePruneFunction(newNode); pruned { newExpr, _ = p.maybePrune(newExpr) return newExpr, true } if prunedCall { return newNode, true } case *exprpb.Expr_ListExpr: elems := node.GetListExpr().GetElements() optIndices := node.GetListExpr().GetOptionalIndices() optIndexMap := map[int32]bool{} for _, i := range optIndices { optIndexMap[i] = true } newOptIndexMap := make(map[int32]bool, len(optIndexMap)) newElems := make([]*exprpb.Expr, 0, len(elems)) var prunedList bool prunedIdx := 0 for i, elem := range elems { _, isOpt := optIndexMap[int32(i)] if isOpt { newElem, pruned := p.maybePruneOptional(elem) if pruned { prunedList = true if newElem != nil { newElems = append(newElems, newElem) prunedIdx++ } continue } newOptIndexMap[int32(prunedIdx)] = true } if newElem, prunedElem := p.maybePrune(elem); prunedElem { newElems = append(newElems, newElem) prunedList = true } else { newElems = append(newElems, elem) } prunedIdx++ } optIndices = make([]int32, len(newOptIndexMap)) idx := 0 for i := range newOptIndexMap { optIndices[idx] = i idx++ } if prunedList { return &exprpb.Expr{ Id: node.GetId(), ExprKind: &exprpb.Expr_ListExpr{ ListExpr: &exprpb.Expr_CreateList{ Elements: newElems, OptionalIndices: optIndices, }, }, }, true } case *exprpb.Expr_StructExpr: var prunedStruct bool entries := node.GetStructExpr().GetEntries() messageType := node.GetStructExpr().GetMessageName() newEntries := make([]*exprpb.Expr_CreateStruct_Entry, len(entries)) for i, entry := range entries { newEntries[i] = entry newKey, prunedKey := p.maybePrune(entry.GetMapKey()) newValue, prunedValue := p.maybePrune(entry.GetValue()) if !prunedKey && !prunedValue { continue } prunedStruct = true newEntry := &exprpb.Expr_CreateStruct_Entry{ Value: newValue, } if messageType != "" { newEntry.KeyKind = &exprpb.Expr_CreateStruct_Entry_FieldKey{ FieldKey: entry.GetFieldKey(), } } else { newEntry.KeyKind = &exprpb.Expr_CreateStruct_Entry_MapKey{ MapKey: newKey, } } newEntry.OptionalEntry = entry.GetOptionalEntry() newEntries[i] = newEntry } if prunedStruct { return &exprpb.Expr{ Id: node.GetId(), ExprKind: &exprpb.Expr_StructExpr{ StructExpr: &exprpb.Expr_CreateStruct{ MessageName: messageType, Entries: newEntries, }, }, }, true } case *exprpb.Expr_ComprehensionExpr: compre := node.GetComprehensionExpr() // Only the range of the comprehension is pruned since the state tracking only records // the last iteration of the comprehension and not each step in the evaluation which // means that the any residuals computed in between might be inaccurate. if newRange, pruned := p.maybePrune(compre.GetIterRange()); pruned { return &exprpb.Expr{ Id: node.GetId(), ExprKind: &exprpb.Expr_ComprehensionExpr{ ComprehensionExpr: &exprpb.Expr_Comprehension{ IterVar: compre.GetIterVar(), IterRange: newRange, AccuVar: compre.GetAccuVar(), AccuInit: compre.GetAccuInit(), LoopCondition: compre.GetLoopCondition(), LoopStep: compre.GetLoopStep(), Result: compre.GetResult(), }, }, }, true } } return node, false } func (p *astPruner) value(id int64) (ref.Val, bool) { val, found := p.state.Value(id) return val, (found && val != nil) } func (p *astPruner) maybeValue(id int64) (ref.Val, bool) { val, found := p.value(id) if !found || types.IsUnknownOrError(val) { return nil, false } return val, true } func (p *astPruner) nextID() int64 { next := p.nextExprID p.nextExprID++ return next } type astVisitor struct { // visitEntry is called on every expr node, including those within a map/struct entry. visitExpr func(expr *exprpb.Expr) // visitEntry is called before entering the key, value of a map/struct entry. visitEntry func(entry *exprpb.Expr_CreateStruct_Entry) } func getMaxID(expr *exprpb.Expr) int64 { maxID := int64(1) visit(expr, maxIDVisitor(&maxID)) return maxID } func clearIterVarVisitor(varName string, state EvalState) astVisitor { return astVisitor{ visitExpr: func(e *exprpb.Expr) { ident := e.GetIdentExpr() if ident != nil && ident.GetName() == varName { state.SetValue(e.GetId(), nil) } }, } } func maxIDVisitor(maxID *int64) astVisitor { return astVisitor{ visitExpr: func(e *exprpb.Expr) { if e.GetId() >= *maxID { *maxID = e.GetId() + 1 } }, visitEntry: func(e *exprpb.Expr_CreateStruct_Entry) { if e.GetId() >= *maxID { *maxID = e.GetId() + 1 } }, } } func visit(expr *exprpb.Expr, visitor astVisitor) { exprs := []*exprpb.Expr{expr} for len(exprs) != 0 { e := exprs[0] if visitor.visitExpr != nil { visitor.visitExpr(e) } exprs = exprs[1:] switch e.GetExprKind().(type) { case *exprpb.Expr_SelectExpr: exprs = append(exprs, e.GetSelectExpr().GetOperand()) case *exprpb.Expr_CallExpr: call := e.GetCallExpr() if call.GetTarget() != nil { exprs = append(exprs, call.GetTarget()) } exprs = append(exprs, call.GetArgs()...) case *exprpb.Expr_ComprehensionExpr: compre := e.GetComprehensionExpr() exprs = append(exprs, compre.GetIterRange(), compre.GetAccuInit(), compre.GetLoopCondition(), compre.GetLoopStep(), compre.GetResult()) case *exprpb.Expr_ListExpr: list := e.GetListExpr() exprs = append(exprs, list.GetElements()...) case *exprpb.Expr_StructExpr: for _, entry := range e.GetStructExpr().GetEntries() { if visitor.visitEntry != nil { visitor.visitEntry(entry) } if entry.GetMapKey() != nil { exprs = append(exprs, entry.GetMapKey()) } exprs = append(exprs, entry.GetValue()) } } } }