// Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package ast import ( "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" ) // NavigableExpr represents the base navigable expression value with methods to inspect the // parent and child expressions. type NavigableExpr interface { Expr // Type of the expression. // // If the expression is type-checked, the type check metadata is returned. If the expression // has not been type-checked, the types.DynType value is returned. Type() *types.Type // Parent returns the parent expression node, if one exists. Parent() (NavigableExpr, bool) // Children returns a list of child expression nodes. Children() []NavigableExpr // Depth indicates the depth in the expression tree. // // The root expression has depth 0. Depth() int } // NavigateAST converts an AST to a NavigableExpr func NavigateAST(ast *AST) NavigableExpr { return NavigateExpr(ast, ast.Expr()) } // NavigateExpr creates a NavigableExpr whose type information is backed by the input AST. // // If the expression is already a NavigableExpr, the parent and depth information will be // propagated on the new NavigableExpr value; otherwise, the expr value will be treated // as though it is the root of the expression graph with a depth of 0. func NavigateExpr(ast *AST, expr Expr) NavigableExpr { depth := 0 var parent NavigableExpr = nil if nav, ok := expr.(NavigableExpr); ok { depth = nav.Depth() parent, _ = nav.Parent() } return newNavigableExpr(ast, parent, expr, depth) } // ExprMatcher takes a NavigableExpr in and indicates whether the value is a match. // // This function type should be use with the `Match` and `MatchList` calls. type ExprMatcher func(NavigableExpr) bool // ConstantValueMatcher returns an ExprMatcher which will return true if the input NavigableExpr // is comprised of all constant values, such as a simple literal or even list and map literal. func ConstantValueMatcher() ExprMatcher { return matchIsConstantValue } // KindMatcher returns an ExprMatcher which will return true if the input NavigableExpr.Kind() matches // the specified `kind`. func KindMatcher(kind ExprKind) ExprMatcher { return func(e NavigableExpr) bool { return e.Kind() == kind } } // FunctionMatcher returns an ExprMatcher which will match NavigableExpr nodes of CallKind type whose // function name is equal to `funcName`. func FunctionMatcher(funcName string) ExprMatcher { return func(e NavigableExpr) bool { if e.Kind() != CallKind { return false } return e.AsCall().FunctionName() == funcName } } // AllMatcher returns true for all descendants of a NavigableExpr, effectively flattening them into a list. // // Such a result would work well with subsequent MatchList calls. func AllMatcher() ExprMatcher { return func(NavigableExpr) bool { return true } } // MatchDescendants takes a NavigableExpr and ExprMatcher and produces a list of NavigableExpr values // matching the input criteria in post-order (bottom up). func MatchDescendants(expr NavigableExpr, matcher ExprMatcher) []NavigableExpr { matches := []NavigableExpr{} navVisitor := &baseVisitor{ visitExpr: func(e Expr) { nav := e.(NavigableExpr) if matcher(nav) { matches = append(matches, nav) } }, } visit(expr, navVisitor, postOrder, 0, 0) return matches } // MatchSubset applies an ExprMatcher to a list of NavigableExpr values and their descendants, producing a // subset of NavigableExpr values which match. func MatchSubset(exprs []NavigableExpr, matcher ExprMatcher) []NavigableExpr { matches := []NavigableExpr{} navVisitor := &baseVisitor{ visitExpr: func(e Expr) { nav := e.(NavigableExpr) if matcher(nav) { matches = append(matches, nav) } }, } for _, expr := range exprs { visit(expr, navVisitor, postOrder, 0, 1) } return matches } // Visitor defines an object for visiting Expr and EntryExpr nodes within an expression graph. type Visitor interface { // VisitExpr visits the input expression. VisitExpr(Expr) // VisitEntryExpr visits the input entry expression, i.e. a struct field or map entry. VisitEntryExpr(EntryExpr) } type baseVisitor struct { visitExpr func(Expr) visitEntryExpr func(EntryExpr) } // VisitExpr visits the Expr if the internal expr visitor has been configured. func (v *baseVisitor) VisitExpr(e Expr) { if v.visitExpr != nil { v.visitExpr(e) } } // VisitEntryExpr visits the entry if the internal expr entry visitor has been configured. func (v *baseVisitor) VisitEntryExpr(e EntryExpr) { if v.visitEntryExpr != nil { v.visitEntryExpr(e) } } // NewExprVisitor creates a visitor which only visits expression nodes. func NewExprVisitor(v func(Expr)) Visitor { return &baseVisitor{ visitExpr: v, visitEntryExpr: nil, } } // PostOrderVisit walks the expression graph and calls the visitor in post-order (bottom-up). func PostOrderVisit(expr Expr, visitor Visitor) { visit(expr, visitor, postOrder, 0, 0) } // PreOrderVisit walks the expression graph and calls the visitor in pre-order (top-down). func PreOrderVisit(expr Expr, visitor Visitor) { visit(expr, visitor, preOrder, 0, 0) } type visitOrder int const ( preOrder = iota + 1 postOrder ) // TODO: consider exposing a way to configure a limit for the max visit depth. // It's possible that we could want to configure this on the NewExprVisitor() // and through MatchDescendents() / MaxID(). func visit(expr Expr, visitor Visitor, order visitOrder, depth, maxDepth int) { if maxDepth > 0 && depth == maxDepth { return } if order == preOrder { visitor.VisitExpr(expr) } switch expr.Kind() { case CallKind: c := expr.AsCall() if c.IsMemberFunction() { visit(c.Target(), visitor, order, depth+1, maxDepth) } for _, arg := range c.Args() { visit(arg, visitor, order, depth+1, maxDepth) } case ComprehensionKind: c := expr.AsComprehension() visit(c.IterRange(), visitor, order, depth+1, maxDepth) visit(c.AccuInit(), visitor, order, depth+1, maxDepth) visit(c.LoopCondition(), visitor, order, depth+1, maxDepth) visit(c.LoopStep(), visitor, order, depth+1, maxDepth) visit(c.Result(), visitor, order, depth+1, maxDepth) case ListKind: l := expr.AsList() for _, elem := range l.Elements() { visit(elem, visitor, order, depth+1, maxDepth) } case MapKind: m := expr.AsMap() for _, e := range m.Entries() { if order == preOrder { visitor.VisitEntryExpr(e) } entry := e.AsMapEntry() visit(entry.Key(), visitor, order, depth+1, maxDepth) visit(entry.Value(), visitor, order, depth+1, maxDepth) if order == postOrder { visitor.VisitEntryExpr(e) } } case SelectKind: visit(expr.AsSelect().Operand(), visitor, order, depth+1, maxDepth) case StructKind: s := expr.AsStruct() for _, f := range s.Fields() { visitor.VisitEntryExpr(f) visit(f.AsStructField().Value(), visitor, order, depth+1, maxDepth) } } if order == postOrder { visitor.VisitExpr(expr) } } func matchIsConstantValue(e NavigableExpr) bool { if e.Kind() == LiteralKind { return true } if e.Kind() == StructKind || e.Kind() == MapKind || e.Kind() == ListKind { for _, child := range e.Children() { if !matchIsConstantValue(child) { return false } } return true } return false } func newNavigableExpr(ast *AST, parent NavigableExpr, expr Expr, depth int) NavigableExpr { // Reduce navigable expression nesting by unwrapping the embedded Expr value. if nav, ok := expr.(*navigableExprImpl); ok { expr = nav.Expr } nav := &navigableExprImpl{ Expr: expr, depth: depth, ast: ast, parent: parent, createChildren: getChildFactory(expr), } return nav } type navigableExprImpl struct { Expr depth int ast *AST parent NavigableExpr createChildren childFactory } func (nav *navigableExprImpl) Parent() (NavigableExpr, bool) { if nav.parent != nil { return nav.parent, true } return nil, false } func (nav *navigableExprImpl) ID() int64 { return nav.Expr.ID() } func (nav *navigableExprImpl) Kind() ExprKind { return nav.Expr.Kind() } func (nav *navigableExprImpl) Type() *types.Type { return nav.ast.GetType(nav.ID()) } func (nav *navigableExprImpl) Children() []NavigableExpr { return nav.createChildren(nav) } func (nav *navigableExprImpl) Depth() int { return nav.depth } func (nav *navigableExprImpl) AsCall() CallExpr { return navigableCallImpl{navigableExprImpl: nav} } func (nav *navigableExprImpl) AsComprehension() ComprehensionExpr { return navigableComprehensionImpl{navigableExprImpl: nav} } func (nav *navigableExprImpl) AsIdent() string { return nav.Expr.AsIdent() } func (nav *navigableExprImpl) AsList() ListExpr { return navigableListImpl{navigableExprImpl: nav} } func (nav *navigableExprImpl) AsLiteral() ref.Val { return nav.Expr.AsLiteral() } func (nav *navigableExprImpl) AsMap() MapExpr { return navigableMapImpl{navigableExprImpl: nav} } func (nav *navigableExprImpl) AsSelect() SelectExpr { return navigableSelectImpl{navigableExprImpl: nav} } func (nav *navigableExprImpl) AsStruct() StructExpr { return navigableStructImpl{navigableExprImpl: nav} } func (nav *navigableExprImpl) createChild(e Expr) NavigableExpr { return newNavigableExpr(nav.ast, nav, e, nav.depth+1) } func (nav *navigableExprImpl) isExpr() {} type navigableCallImpl struct { *navigableExprImpl } func (call navigableCallImpl) FunctionName() string { return call.Expr.AsCall().FunctionName() } func (call navigableCallImpl) IsMemberFunction() bool { return call.Expr.AsCall().IsMemberFunction() } func (call navigableCallImpl) Target() Expr { t := call.Expr.AsCall().Target() if t != nil { return call.createChild(t) } return nil } func (call navigableCallImpl) Args() []Expr { args := call.Expr.AsCall().Args() navArgs := make([]Expr, len(args)) for i, a := range args { navArgs[i] = call.createChild(a) } return navArgs } type navigableComprehensionImpl struct { *navigableExprImpl } func (comp navigableComprehensionImpl) IterRange() Expr { return comp.createChild(comp.Expr.AsComprehension().IterRange()) } func (comp navigableComprehensionImpl) IterVar() string { return comp.Expr.AsComprehension().IterVar() } func (comp navigableComprehensionImpl) AccuVar() string { return comp.Expr.AsComprehension().AccuVar() } func (comp navigableComprehensionImpl) AccuInit() Expr { return comp.createChild(comp.Expr.AsComprehension().AccuInit()) } func (comp navigableComprehensionImpl) LoopCondition() Expr { return comp.createChild(comp.Expr.AsComprehension().LoopCondition()) } func (comp navigableComprehensionImpl) LoopStep() Expr { return comp.createChild(comp.Expr.AsComprehension().LoopStep()) } func (comp navigableComprehensionImpl) Result() Expr { return comp.createChild(comp.Expr.AsComprehension().Result()) } type navigableListImpl struct { *navigableExprImpl } func (l navigableListImpl) Elements() []Expr { pbElems := l.Expr.AsList().Elements() elems := make([]Expr, len(pbElems)) for i := 0; i < len(pbElems); i++ { elems[i] = l.createChild(pbElems[i]) } return elems } func (l navigableListImpl) IsOptional(index int32) bool { return l.Expr.AsList().IsOptional(index) } func (l navigableListImpl) OptionalIndices() []int32 { return l.Expr.AsList().OptionalIndices() } func (l navigableListImpl) Size() int { return l.Expr.AsList().Size() } type navigableMapImpl struct { *navigableExprImpl } func (m navigableMapImpl) Entries() []EntryExpr { mapExpr := m.Expr.AsMap() entries := make([]EntryExpr, len(mapExpr.Entries())) for i, e := range mapExpr.Entries() { entry := e.AsMapEntry() entries[i] = &entryExpr{ id: e.ID(), entryExprKindCase: navigableEntryImpl{ key: m.createChild(entry.Key()), val: m.createChild(entry.Value()), isOpt: entry.IsOptional(), }, } } return entries } func (m navigableMapImpl) Size() int { return m.Expr.AsMap().Size() } type navigableEntryImpl struct { key NavigableExpr val NavigableExpr isOpt bool } func (e navigableEntryImpl) Kind() EntryExprKind { return MapEntryKind } func (e navigableEntryImpl) Key() Expr { return e.key } func (e navigableEntryImpl) Value() Expr { return e.val } func (e navigableEntryImpl) IsOptional() bool { return e.isOpt } func (e navigableEntryImpl) renumberIDs(IDGenerator) {} func (e navigableEntryImpl) isEntryExpr() {} type navigableSelectImpl struct { *navigableExprImpl } func (sel navigableSelectImpl) FieldName() string { return sel.Expr.AsSelect().FieldName() } func (sel navigableSelectImpl) IsTestOnly() bool { return sel.Expr.AsSelect().IsTestOnly() } func (sel navigableSelectImpl) Operand() Expr { return sel.createChild(sel.Expr.AsSelect().Operand()) } type navigableStructImpl struct { *navigableExprImpl } func (s navigableStructImpl) TypeName() string { return s.Expr.AsStruct().TypeName() } func (s navigableStructImpl) Fields() []EntryExpr { fieldInits := s.Expr.AsStruct().Fields() fields := make([]EntryExpr, len(fieldInits)) for i, f := range fieldInits { field := f.AsStructField() fields[i] = &entryExpr{ id: f.ID(), entryExprKindCase: navigableFieldImpl{ name: field.Name(), val: s.createChild(field.Value()), isOpt: field.IsOptional(), }, } } return fields } type navigableFieldImpl struct { name string val NavigableExpr isOpt bool } func (f navigableFieldImpl) Kind() EntryExprKind { return StructFieldKind } func (f navigableFieldImpl) Name() string { return f.name } func (f navigableFieldImpl) Value() Expr { return f.val } func (f navigableFieldImpl) IsOptional() bool { return f.isOpt } func (f navigableFieldImpl) renumberIDs(IDGenerator) {} func (f navigableFieldImpl) isEntryExpr() {} func getChildFactory(expr Expr) childFactory { if expr == nil { return noopFactory } switch expr.Kind() { case LiteralKind: return noopFactory case IdentKind: return noopFactory case SelectKind: return selectFactory case CallKind: return callArgFactory case ListKind: return listElemFactory case MapKind: return mapEntryFactory case StructKind: return structEntryFactory case ComprehensionKind: return comprehensionFactory default: return noopFactory } } type childFactory func(*navigableExprImpl) []NavigableExpr func noopFactory(*navigableExprImpl) []NavigableExpr { return nil } func selectFactory(nav *navigableExprImpl) []NavigableExpr { return []NavigableExpr{nav.createChild(nav.AsSelect().Operand())} } func callArgFactory(nav *navigableExprImpl) []NavigableExpr { call := nav.Expr.AsCall() argCount := len(call.Args()) if call.IsMemberFunction() { argCount++ } navExprs := make([]NavigableExpr, argCount) i := 0 if call.IsMemberFunction() { navExprs[i] = nav.createChild(call.Target()) i++ } for _, arg := range call.Args() { navExprs[i] = nav.createChild(arg) i++ } return navExprs } func listElemFactory(nav *navigableExprImpl) []NavigableExpr { l := nav.Expr.AsList() navExprs := make([]NavigableExpr, len(l.Elements())) for i, e := range l.Elements() { navExprs[i] = nav.createChild(e) } return navExprs } func structEntryFactory(nav *navigableExprImpl) []NavigableExpr { s := nav.Expr.AsStruct() entries := make([]NavigableExpr, len(s.Fields())) for i, e := range s.Fields() { f := e.AsStructField() entries[i] = nav.createChild(f.Value()) } return entries } func mapEntryFactory(nav *navigableExprImpl) []NavigableExpr { m := nav.Expr.AsMap() entries := make([]NavigableExpr, len(m.Entries())*2) j := 0 for _, e := range m.Entries() { mapEntry := e.AsMapEntry() entries[j] = nav.createChild(mapEntry.Key()) entries[j+1] = nav.createChild(mapEntry.Value()) j += 2 } return entries } func comprehensionFactory(nav *navigableExprImpl) []NavigableExpr { compre := nav.Expr.AsComprehension() return []NavigableExpr{ nav.createChild(compre.IterRange()), nav.createChild(compre.AccuInit()), nav.createChild(compre.LoopCondition()), nav.createChild(compre.LoopStep()), nav.createChild(compre.Result()), } }