ceph-csi/vendor/github.com/google/cel-go/interpreter/prune.go

544 lines
15 KiB
Go
Raw Normal View History

// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package interpreter
import (
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
)
// astPruner carries the working state for a single PruneAst invocation.
type astPruner struct {
	// ExprFactory is embedded so the pruner can build replacement expression nodes
	// (NewLiteral, NewCall, NewList, ...) directly.
	ast.ExprFactory
	// expr is the root expression handed to PruneAst.
	expr ast.Expr
	// macroCalls maps an expression id to the macro call expression recorded for it.
	// Entries are deleted or replaced while pruning.
	macroCalls map[int64]ast.Expr
	// state holds the evaluation values consulted, and updated, during pruning.
	state EvalState
	// nextExprID is the id that the next call to nextID will hand out.
	nextExprID int64
}
// TODO Consider having a separate walk of the AST that finds common
// subexpressions. This can be called before or after constant folding to find
// common subexpressions.
// PruneAst prunes the given AST based on the given EvalState and generates a new AST.
// The given AST is copied on write and a new AST is returned.
//
// A couple of typical use cases for this function:
//
// A)
//  1. Evaluate expr with some unknowns.
//  2. If the result is unknown:
//     a) PruneAst
//     b) Goto 1
//
// Function call results which are known would be effectively cached across
// iterations.
//
// B)
//  1. Compile the expression (maybe via a service and maybe after checking
//     that a compiled expression does not exist in a local cache).
//  2. Prepare the environment and the interpreter. Activation might be empty.
//  3. Eval the expression. This might return unknown or error or a concrete
//     value.
//  4. PruneAst
//  5. Maybe cache the expression.
//
// This is effectively constant folding the expression. How the environment is
// prepared in step 2 is flexible. For example, if the caller caches the
// compiled and constant folded expressions, but is not willing to constant
// fold (and thus cache results of) some external calls, then they can prepare
// the overloads accordingly.
//
// The macroCalls map supplied by the caller is not modified: pruning works on a
// private copy, and the surviving entries are attached to the source info of the
// returned AST. Likewise, the supplied state is only read; values are copied into
// a fresh EvalState before pruning starts.
func PruneAst(expr ast.Expr, macroCalls map[int64]ast.Expr, state EvalState) *ast.AST {
	pruneState := NewEvalState()
	for _, id := range state.IDs() {
		// The ids come from state.IDs() itself, so the found flag is not consulted.
		v, _ := state.Value(id)
		pruneState.SetValue(id, v)
	}
	// prune deletes and replaces macro call entries as it goes; operate on a copy so
	// the caller's map (and whatever source info it belongs to) is left untouched.
	macros := make(map[int64]ast.Expr, len(macroCalls))
	for id, call := range macroCalls {
		macros[id] = call
	}
	pruner := &astPruner{
		ExprFactory: ast.NewExprFactory(),
		expr:        expr,
		macroCalls:  macros,
		state:       pruneState,
		nextExprID:  getMaxID(expr)}
	newExpr, _ := pruner.maybePrune(expr)
	newInfo := ast.NewSourceInfo(nil)
	for id, call := range pruner.macroCalls {
		newInfo.SetMacroCall(id, call)
	}
	return ast.NewAST(newExpr, newInfo)
}
// maybeCreateLiteral tries to express val as a literal expression carrying the given id.
//
// Primitive values become plain literals; durations and timestamps become a conversion
// call wrapped around their string form; lists and maps are rebuilt element by element,
// with freshly allocated ids for the nested expressions. Whenever a literal is produced,
// val is also recorded in the pruner state under id.
//
// It returns (nil, false) when val, or any value nested inside it, cannot be represented
// as a literal (for example unknowns, errors, or message values).
func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (ast.Expr, bool) {
	switch v := val.(type) {
	case types.Bool, types.Bytes, types.Double, types.Int, types.Null, types.String, types.Uint:
		p.state.SetValue(id, val)
		return p.NewLiteral(id, val), true
	case types.Duration:
		p.state.SetValue(id, val)
		durationString := v.ConvertToType(types.StringType).(types.String)
		return p.NewCall(id, overloads.TypeConvertDuration, p.NewLiteral(p.nextID(), durationString)), true
	case types.Timestamp:
		// Record the value under id just like every other successful branch does, so
		// later state lookups for this id observe the folded timestamp.
		p.state.SetValue(id, val)
		timestampString := v.ConvertToType(types.StringType).(types.String)
		return p.NewCall(id, overloads.TypeConvertTimestamp, p.NewLiteral(p.nextID(), timestampString)), true
	}

	// Attempt to build a list literal.
	if list, isList := val.(traits.Lister); isList {
		sz := list.Size().(types.Int)
		elemExprs := make([]ast.Expr, sz)
		for i := types.Int(0); i < sz; i++ {
			elem := list.Get(i)
			if types.IsUnknownOrError(elem) {
				return nil, false
			}
			elemExpr, ok := p.maybeCreateLiteral(p.nextID(), elem)
			if !ok {
				return nil, false
			}
			elemExprs[i] = elemExpr
		}
		p.state.SetValue(id, val)
		return p.NewList(id, elemExprs, []int32{}), true
	}

	// Create a map literal if possible.
	if mp, isMap := val.(traits.Mapper); isMap {
		it := mp.Iterator()
		entries := make([]ast.EntryExpr, mp.Size().(types.Int))
		i := 0
		for it.HasNext() != types.False {
			key := it.Next()
			// Note: this val shadows the parameter only within the loop body.
			val := mp.Get(key)
			if types.IsUnknownOrError(key) || types.IsUnknownOrError(val) {
				return nil, false
			}
			keyExpr, ok := p.maybeCreateLiteral(p.nextID(), key)
			if !ok {
				return nil, false
			}
			valExpr, ok := p.maybeCreateLiteral(p.nextID(), val)
			if !ok {
				return nil, false
			}
			entry := p.NewMapEntry(p.nextID(), keyExpr, valExpr, false)
			entries[i] = entry
			i++
		}
		p.state.SetValue(id, val)
		return p.NewMap(id, entries), true
	}

	// TODO(issues/377) To construct message literals, the type provider will need to support
	// the enumeration the fields for a given message.
	return nil, false
}
// maybePruneOptional handles a list element that was marked optional.
//
// When the recorded value for elem is an optional with no value, it returns (nil, true)
// so the caller can drop the element. When the optional holds a value that can be turned
// into a literal, that literal is returned with true. In every other case the element is
// returned unchanged with false.
func (p *astPruner) maybePruneOptional(elem ast.Expr) (ast.Expr, bool) {
	recorded, ok := p.value(elem.ID())
	if !ok || recorded.Type() != types.OptionalType {
		return elem, false
	}
	optional := recorded.(*types.Optional)
	if !optional.HasValue() {
		return nil, true
	}
	literal, created := p.maybeCreateLiteral(elem.ID(), optional.GetValue())
	if !created {
		return elem, false
	}
	return literal, true
}
// maybePruneIn folds `elem in container` to the literal false when the container operand
// has a known value whose size is zero. Otherwise it reports that nothing was pruned.
func (p *astPruner) maybePruneIn(node ast.Expr) (ast.Expr, bool) {
	// The second argument of the call is the container being searched.
	container := node.AsCall().Args()[1]
	known, ok := p.maybeValue(container.ID())
	if !ok {
		return nil, false
	}
	sizer, sized := known.(traits.Sizer)
	if !sized || sizer.Size() != types.IntZero {
		return nil, false
	}
	return p.maybeCreateLiteral(node.ID(), types.False)
}
// maybePruneLogicalNot folds a logical-not call to a boolean literal when its single
// operand has a known boolean value. Otherwise it reports that nothing was pruned.
func (p *astPruner) maybePruneLogicalNot(node ast.Expr) (ast.Expr, bool) {
	operand := node.AsCall().Args()[0]
	known, ok := p.maybeValue(operand.ID())
	if !ok {
		return nil, false
	}
	boolVal, isBool := known.(types.Bool)
	if !isBool {
		return nil, false
	}
	return p.maybeCreateLiteral(node.ID(), !boolVal)
}
// maybePruneOr simplifies a logical-or call using whichever operand has a known value.
//
// The operands are examined left to right. The first one with a known value decides the
// outcome: if it is true the whole call becomes the literal true, otherwise that operand
// is dropped and the other operand is returned in its place. If neither operand has a
// known value, nothing is pruned.
func (p *astPruner) maybePruneOr(node ast.Expr) (ast.Expr, bool) {
	args := node.AsCall().Args()
	// The call is expected to reach this point only when its overall result is not a
	// known value, so a known operand can either settle the result or be ignored.
	for side := 0; side < 2; side++ {
		known, ok := p.maybeValue(args[side].ID())
		if !ok {
			continue
		}
		if known == types.True {
			return p.maybeCreateLiteral(node.ID(), types.True)
		}
		return args[1-side], true
	}
	return nil, false
}
// maybePruneAnd simplifies a logical-and call using whichever operand has a known value.
//
// The operands are examined left to right. The first one with a known value decides the
// outcome: if it is false the whole call becomes the literal false, otherwise that
// operand is dropped and the other operand is returned in its place. If neither operand
// has a known value, nothing is pruned.
func (p *astPruner) maybePruneAnd(node ast.Expr) (ast.Expr, bool) {
	args := node.AsCall().Args()
	// The call is expected to reach this point only when its overall result is not a
	// known value, so a known operand can either settle the result or be ignored.
	for side := 0; side < 2; side++ {
		known, ok := p.maybeValue(args[side].ID())
		if !ok {
			continue
		}
		if known == types.False {
			return p.maybeCreateLiteral(node.ID(), types.False)
		}
		return args[1-side], true
	}
	return nil, false
}
// maybePruneConditional replaces a ternary call with the branch selected by its
// condition when the condition has a known boolean value: the second argument when the
// condition is true, the third when it is false.
//
// Nothing is pruned when the condition has no known value, or when its known value is not
// a bool. The latter is checked with a comma-ok assertion so that a non-boolean condition
// leaves the expression intact rather than panicking.
func (p *astPruner) maybePruneConditional(node ast.Expr) (ast.Expr, bool) {
	call := node.AsCall()
	cond, exists := p.maybeValue(call.Args()[0].ID())
	if !exists {
		return nil, false
	}
	truthy, isBool := cond.Value().(bool)
	if !isBool {
		return nil, false
	}
	if truthy {
		return call.Args()[1], true
	}
	return call.Args()[2], true
}
// maybePruneFunction dispatches a call node to the operator-specific pruning routine.
//
// Only calls that have some value recorded in the pruner state are considered, and only
// the logical-or, logical-and, conditional, in, and logical-not operators are handled.
// For anything else it reports that nothing was pruned.
func (p *astPruner) maybePruneFunction(node ast.Expr) (ast.Expr, bool) {
	_, recorded := p.value(node.ID())
	if !recorded {
		return nil, false
	}
	switch node.AsCall().FunctionName() {
	case operators.LogicalOr:
		return p.maybePruneOr(node)
	case operators.LogicalAnd:
		return p.maybePruneAnd(node)
	case operators.Conditional:
		return p.maybePruneConditional(node)
	case operators.In:
		return p.maybePruneIn(node)
	case operators.LogicalNot:
		return p.maybePruneLogicalNot(node)
	default:
		return nil, false
	}
}
// maybePrune is the entry point used when descending into child expressions; it simply
// forwards to prune and passes its results through.
func (p *astPruner) maybePrune(node ast.Expr) (ast.Expr, bool) {
	result, changed := p.prune(node)
	return result, changed
}
// prune rewrites node using the values recorded in the pruner state.
//
// If the node has a known value that can be written as a literal, the literal replaces
// the node. Otherwise the node's children are pruned recursively according to its kind,
// and logical operators are simplified via maybePruneFunction. The boolean result reports
// whether anything was rewritten; a nil node is returned as-is with false.
//
// For lists, the optional indices of the rebuilt list are emitted in ascending order so
// that the output does not depend on map iteration order.
func (p *astPruner) prune(node ast.Expr) (ast.Expr, bool) {
	if node == nil {
		return node, false
	}
	val, valueExists := p.maybeValue(node.ID())
	if valueExists {
		if newNode, ok := p.maybeCreateLiteral(node.ID(), val); ok {
			// The node collapsed into a literal, so a macro call recorded for it no
			// longer applies.
			delete(p.macroCalls, node.ID())
			return newNode, true
		}
	}
	if macro, found := p.macroCalls[node.ID()]; found {
		// Ensure that intermediate values for the comprehension are cleared during pruning
		if node.Kind() == ast.ComprehensionKind {
			compre := node.AsComprehension()
			visit(macro, clearIterVarVisitor(compre.IterVar(), p.state))
		}
		// prune the expression in terms of the macro call instead of the expanded form.
		if newMacro, pruned := p.prune(macro); pruned {
			p.macroCalls[node.ID()] = newMacro
		}
	}

	// We have either an unknown/error value, or something we don't want to
	// transform, or expression was not evaluated. If possible, drill down
	// more.
	switch node.Kind() {
	case ast.SelectKind:
		sel := node.AsSelect()
		if operand, isPruned := p.maybePrune(sel.Operand()); isPruned {
			if sel.IsTestOnly() {
				return p.NewPresenceTest(node.ID(), operand, sel.FieldName()), true
			}
			return p.NewSelect(node.ID(), operand, sel.FieldName()), true
		}
	case ast.CallKind:
		argsPruned := false
		call := node.AsCall()
		args := call.Args()
		newArgs := make([]ast.Expr, len(args))
		for i, a := range args {
			newArgs[i] = a
			if arg, isPruned := p.maybePrune(a); isPruned {
				argsPruned = true
				newArgs[i] = arg
			}
		}
		if !call.IsMemberFunction() {
			newCall := p.NewCall(node.ID(), call.FunctionName(), newArgs...)
			if prunedCall, isPruned := p.maybePruneFunction(newCall); isPruned {
				return prunedCall, true
			}
			return newCall, argsPruned
		}
		newTarget := call.Target()
		targetPruned := false
		if prunedTarget, isPruned := p.maybePrune(call.Target()); isPruned {
			targetPruned = true
			newTarget = prunedTarget
		}
		newCall := p.NewMemberCall(node.ID(), call.FunctionName(), newTarget, newArgs...)
		if prunedCall, isPruned := p.maybePruneFunction(newCall); isPruned {
			return prunedCall, true
		}
		return newCall, targetPruned || argsPruned
	case ast.ListKind:
		l := node.AsList()
		elems := l.Elements()
		optIndexMap := map[int32]bool{}
		for _, i := range l.OptionalIndices() {
			optIndexMap[i] = true
		}
		// prunedIdx only ever increases, and each position is recorded at most once,
		// so appending keeps the new optional indices unique and in ascending order.
		newOptIndices := make([]int32, 0, len(optIndexMap))
		newElems := make([]ast.Expr, 0, len(elems))
		var listPruned bool
		// prunedIdx is the position the current element will occupy in newElems.
		prunedIdx := 0
		for i, elem := range elems {
			_, isOpt := optIndexMap[int32(i)]
			if isOpt {
				newElem, pruned := p.maybePruneOptional(elem)
				if pruned {
					listPruned = true
					// A nil result means the optional was empty and the element is dropped.
					if newElem != nil {
						newElems = append(newElems, newElem)
						prunedIdx++
					}
					continue
				}
				// The element stays optional at its (possibly shifted) position.
				newOptIndices = append(newOptIndices, int32(prunedIdx))
			}
			if newElem, prunedElem := p.maybePrune(elem); prunedElem {
				newElems = append(newElems, newElem)
				listPruned = true
			} else {
				newElems = append(newElems, elem)
			}
			prunedIdx++
		}
		if listPruned {
			return p.NewList(node.ID(), newElems, newOptIndices), true
		}
	case ast.MapKind:
		var mapPruned bool
		m := node.AsMap()
		entries := m.Entries()
		newEntries := make([]ast.EntryExpr, len(entries))
		for i, entry := range entries {
			newEntries[i] = entry
			e := entry.AsMapEntry()
			newKey, keyPruned := p.maybePrune(e.Key())
			newValue, valuePruned := p.maybePrune(e.Value())
			if !keyPruned && !valuePruned {
				continue
			}
			mapPruned = true
			newEntry := p.NewMapEntry(entry.ID(), newKey, newValue, e.IsOptional())
			newEntries[i] = newEntry
		}
		if mapPruned {
			return p.NewMap(node.ID(), newEntries), true
		}
	case ast.StructKind:
		var structPruned bool
		obj := node.AsStruct()
		fields := obj.Fields()
		newFields := make([]ast.EntryExpr, len(fields))
		for i, field := range fields {
			newFields[i] = field
			f := field.AsStructField()
			newValue, prunedValue := p.maybePrune(f.Value())
			if !prunedValue {
				continue
			}
			structPruned = true
			newEntry := p.NewStructField(field.ID(), f.Name(), newValue, f.IsOptional())
			newFields[i] = newEntry
		}
		if structPruned {
			return p.NewStruct(node.ID(), obj.TypeName(), newFields), true
		}
	case ast.ComprehensionKind:
		compre := node.AsComprehension()
		// Only the range of the comprehension is pruned since the state tracking only records
		// the last iteration of the comprehension and not each step in the evaluation which
		// means that the any residuals computed in between might be inaccurate.
		if newRange, pruned := p.maybePrune(compre.IterRange()); pruned {
			return p.NewComprehension(
				node.ID(),
				newRange,
				compre.IterVar(),
				compre.AccuVar(),
				compre.AccuInit(),
				compre.LoopCondition(),
				compre.LoopStep(),
				compre.Result(),
			), true
		}
	}
	return node, false
}
// value looks up the value recorded for id in the pruner state. The boolean is true only
// when an entry exists and it is non-nil.
func (p *astPruner) value(id int64) (ref.Val, bool) {
	recorded, ok := p.state.Value(id)
	if !ok || recorded == nil {
		return recorded, false
	}
	return recorded, true
}
// maybeValue returns the value recorded for id only when it is present and is neither an
// unknown nor an error; otherwise it returns (nil, false).
func (p *astPruner) maybeValue(id int64) (ref.Val, bool) {
	recorded, ok := p.value(id)
	if ok && !types.IsUnknownOrError(recorded) {
		return recorded, true
	}
	return nil, false
}
// nextID hands out the current value of the id counter and advances the counter by one.
func (p *astPruner) nextID() int64 {
	id := p.nextExprID
	p.nextExprID = id + 1
	return id
}
// astVisitor bundles the optional callbacks invoked by visit while walking an expression
// tree. A nil callback is skipped.
type astVisitor struct {
	// visitExpr is called on every expr node, including those within a map/struct entry.
	visitExpr func(expr ast.Expr)
	// visitEntry is called before entering the key, value of a map/struct entry.
	visitEntry func(entry ast.EntryExpr)
}
// getMaxID walks expr and returns an id strictly greater than every expression and entry
// id encountered during the walk, and never less than 1.
func getMaxID(expr ast.Expr) int64 {
	var maxID int64 = 1
	tracker := maxIDVisitor(&maxID)
	visit(expr, tracker)
	return maxID
}
// clearIterVarVisitor builds a visitor that resets, to nil, the state value of every
// identifier expression whose name equals varName.
func clearIterVarVisitor(varName string, state EvalState) astVisitor {
	clearIfIterVar := func(e ast.Expr) {
		if e.Kind() != ast.IdentKind {
			return
		}
		if e.AsIdent() != varName {
			return
		}
		state.SetValue(e.ID(), nil)
	}
	return astVisitor{visitExpr: clearIfIterVar}
}
// maxIDVisitor builds a visitor that keeps *maxID one past the largest expression or
// entry id it has been shown.
func maxIDVisitor(maxID *int64) astVisitor {
	// raise bumps *maxID to id+1 whenever id has reached or passed the current value.
	raise := func(id int64) {
		if id >= *maxID {
			*maxID = id + 1
		}
	}
	return astVisitor{
		visitExpr:  func(e ast.Expr) { raise(e.ID()) },
		visitEntry: func(e ast.EntryExpr) { raise(e.ID()) },
	}
}
// visit walks the expression tree rooted at expr in breadth-first order using an explicit
// queue, calling visitor.visitExpr on each expression and visitor.visitEntry on each
// map entry or struct field before its key/value expressions are queued. Callbacks that
// are nil are skipped.
//
// Nil expressions, whether passed as the root or found as a child, are ignored so the
// walk does not panic on them.
func visit(expr ast.Expr, visitor astVisitor) {
	exprs := []ast.Expr{expr}
	for len(exprs) != 0 {
		e := exprs[0]
		exprs = exprs[1:]
		if e == nil {
			// Nothing to visit and no children to enqueue.
			continue
		}
		if visitor.visitExpr != nil {
			visitor.visitExpr(e)
		}
		switch e.Kind() {
		case ast.SelectKind:
			exprs = append(exprs, e.AsSelect().Operand())
		case ast.CallKind:
			call := e.AsCall()
			if call.Target() != nil {
				exprs = append(exprs, call.Target())
			}
			exprs = append(exprs, call.Args()...)
		case ast.ComprehensionKind:
			compre := e.AsComprehension()
			exprs = append(exprs,
				compre.IterRange(),
				compre.AccuInit(),
				compre.LoopCondition(),
				compre.LoopStep(),
				compre.Result())
		case ast.ListKind:
			list := e.AsList()
			exprs = append(exprs, list.Elements()...)
		case ast.MapKind:
			for _, entry := range e.AsMap().Entries() {
				mapEntry := entry.AsMapEntry()
				if visitor.visitEntry != nil {
					visitor.visitEntry(entry)
				}
				exprs = append(exprs, mapEntry.Key())
				exprs = append(exprs, mapEntry.Value())
			}
		case ast.StructKind:
			for _, entry := range e.AsStruct().Fields() {
				f := entry.AsStructField()
				if visitor.visitEntry != nil {
					visitor.visitEntry(entry)
				}
				exprs = append(exprs, f.Value())
			}
		}
	}
}