// Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package ast import ( "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) // ExprKind represents the expression node kind. type ExprKind int const ( // UnspecifiedKind represents an unset expression with no specified properties. UnspecifiedKind ExprKind = iota // LiteralKind represents a primitive scalar literal. LiteralKind // IdentKind represents a simple variable, constant, or type identifier. IdentKind // SelectKind represents a field selection expression. SelectKind // CallKind represents a function call. CallKind // ListKind represents a list literal expression. ListKind // MapKind represents a map literal expression. MapKind // StructKind represents a struct literal expression. StructKind // ComprehensionKind represents a comprehension expression generated by a macro. ComprehensionKind ) // NavigateCheckedAST converts a CheckedAST to a NavigableExpr func NavigateCheckedAST(ast *CheckedAST) NavigableExpr { return newNavigableExpr(nil, ast.Expr, ast.TypeMap) } // ExprMatcher takes a NavigableExpr in and indicates whether the value is a match. // // This function type should be use with the `Match` and `MatchList` calls. type ExprMatcher func(NavigableExpr) bool // ConstantValueMatcher returns an ExprMatcher which will return true if the input NavigableExpr // is comprised of all constant values, such as a simple literal or even list and map literal. func ConstantValueMatcher() ExprMatcher { return matchIsConstantValue } // KindMatcher returns an ExprMatcher which will return true if the input NavigableExpr.Kind() matches // the specified `kind`. func KindMatcher(kind ExprKind) ExprMatcher { return func(e NavigableExpr) bool { return e.Kind() == kind } } // FunctionMatcher returns an ExprMatcher which will match NavigableExpr nodes of CallKind type whose // function name is equal to `funcName`. func FunctionMatcher(funcName string) ExprMatcher { return func(e NavigableExpr) bool { if e.Kind() != CallKind { return false } return e.AsCall().FunctionName() == funcName } } // AllMatcher returns true for all descendants of a NavigableExpr, effectively flattening them into a list. // // Such a result would work well with subsequent MatchList calls. func AllMatcher() ExprMatcher { return func(NavigableExpr) bool { return true } } // MatchDescendants takes a NavigableExpr and ExprMatcher and produces a list of NavigableExpr values of the // descendants which match. func MatchDescendants(expr NavigableExpr, matcher ExprMatcher) []NavigableExpr { return matchListInternal([]NavigableExpr{expr}, matcher, true) } // MatchSubset applies an ExprMatcher to a list of NavigableExpr values and their descendants, producing a // subset of NavigableExpr values which match. func MatchSubset(exprs []NavigableExpr, matcher ExprMatcher) []NavigableExpr { visit := make([]NavigableExpr, len(exprs)) copy(visit, exprs) return matchListInternal(visit, matcher, false) } func matchListInternal(visit []NavigableExpr, matcher ExprMatcher, visitDescendants bool) []NavigableExpr { var matched []NavigableExpr for len(visit) != 0 { e := visit[0] if matcher(e) { matched = append(matched, e) } if visitDescendants { visit = append(visit[1:], e.Children()...) } else { visit = visit[1:] } } return matched } func matchIsConstantValue(e NavigableExpr) bool { if e.Kind() == LiteralKind { return true } if e.Kind() == StructKind || e.Kind() == MapKind || e.Kind() == ListKind { for _, child := range e.Children() { if !matchIsConstantValue(child) { return false } } return true } return false } // NavigableExpr represents the base navigable expression value. // // Depending on the `Kind()` value, the NavigableExpr may be converted to a concrete expression types // as indicated by the `As` methods. // // NavigableExpr values and their concrete expression types should be nil-safe. Conversion of an expr // to the wrong kind should produce a nil value. type NavigableExpr interface { // ID of the expression as it appears in the AST ID() int64 // Kind of the expression node. See ExprKind for the valid enum values. Kind() ExprKind // Type of the expression node. Type() *types.Type // Parent returns the parent expression node, if one exists. Parent() (NavigableExpr, bool) // Children returns a list of child expression nodes. Children() []NavigableExpr // ToExpr adapts this NavigableExpr to a protobuf representation. ToExpr() *exprpb.Expr // AsCall adapts the expr into a NavigableCallExpr // // The Kind() must be equal to a CallKind for the conversion to be well-defined. AsCall() NavigableCallExpr // AsComprehension adapts the expr into a NavigableComprehensionExpr. // // The Kind() must be equal to a ComprehensionKind for the conversion to be well-defined. AsComprehension() NavigableComprehensionExpr // AsIdent adapts the expr into an identifier string. // // The Kind() must be equal to an IdentKind for the conversion to be well-defined. AsIdent() string // AsLiteral adapts the expr into a constant ref.Val. // // The Kind() must be equal to a LiteralKind for the conversion to be well-defined. AsLiteral() ref.Val // AsList adapts the expr into a NavigableListExpr. // // The Kind() must be equal to a ListKind for the conversion to be well-defined. AsList() NavigableListExpr // AsMap adapts the expr into a NavigableMapExpr. // // The Kind() must be equal to a MapKind for the conversion to be well-defined. AsMap() NavigableMapExpr // AsSelect adapts the expr into a NavigableSelectExpr. // // The Kind() must be equal to a SelectKind for the conversion to be well-defined. AsSelect() NavigableSelectExpr // AsStruct adapts the expr into a NavigableStructExpr. // // The Kind() must be equal to a StructKind for the conversion to be well-defined. AsStruct() NavigableStructExpr // marker interface method isNavigable() } // NavigableCallExpr defines an interface for inspecting a function call and its arugments. type NavigableCallExpr interface { // FunctionName returns the name of the function. FunctionName() string // Target returns the target of the expression if one is present. Target() NavigableExpr // Args returns the list of call arguments, excluding the target. Args() []NavigableExpr // ReturnType returns the result type of the call. ReturnType() *types.Type // marker interface method isNavigable() } // NavigableListExpr defines an interface for inspecting a list literal expression. type NavigableListExpr interface { // Elements returns the list elements as navigable expressions. Elements() []NavigableExpr // OptionalIndicies returns the list of optional indices in the list literal. OptionalIndices() []int32 // Size returns the number of elements in the list. Size() int // marker interface method isNavigable() } // NavigableSelectExpr defines an interface for inspecting a select expression. type NavigableSelectExpr interface { // Operand returns the selection operand expression. Operand() NavigableExpr // FieldName returns the field name being selected from the operand. FieldName() string // IsTestOnly indicates whether the select expression is a presence test generated by a macro. IsTestOnly() bool // marker interface method isNavigable() } // NavigableMapExpr defines an interface for inspecting a map expression. type NavigableMapExpr interface { // Entries returns the map key value pairs as NavigableEntry values. Entries() []NavigableEntry // Size returns the number of entries in the map. Size() int // marker interface method isNavigable() } // NavigableEntry defines an interface for inspecting a map entry. type NavigableEntry interface { // Key returns the map entry key expression. Key() NavigableExpr // Value returns the map entry value expression. Value() NavigableExpr // IsOptional returns whether the entry is optional. IsOptional() bool // marker interface method isNavigable() } // NavigableStructExpr defines an interfaces for inspecting a struct and its field initializers. type NavigableStructExpr interface { // TypeName returns the struct type name. TypeName() string // Fields returns the set of field initializers in the struct expression as NavigableField values. Fields() []NavigableField // marker interface method isNavigable() } // NavigableField defines an interface for inspecting a struct field initialization. type NavigableField interface { // FieldName returns the name of the field. FieldName() string // Value returns the field initialization expression. Value() NavigableExpr // IsOptional returns whether the field is optional. IsOptional() bool // marker interface method isNavigable() } // NavigableComprehensionExpr defines an interface for inspecting a comprehension expression. type NavigableComprehensionExpr interface { // IterRange returns the iteration range expression. IterRange() NavigableExpr // IterVar returns the iteration variable name. IterVar() string // AccuVar returns the accumulation variable name. AccuVar() string // AccuInit returns the accumulation variable initialization expression. AccuInit() NavigableExpr // LoopCondition returns the loop condition expression. LoopCondition() NavigableExpr // LoopStep returns the loop step expression. LoopStep() NavigableExpr // Result returns the comprehension result expression. Result() NavigableExpr // marker interface method isNavigable() } func newNavigableExpr(parent NavigableExpr, expr *exprpb.Expr, typeMap map[int64]*types.Type) NavigableExpr { kind, factory := kindOf(expr) nav := &navigableExprImpl{ parent: parent, kind: kind, expr: expr, typeMap: typeMap, createChildren: factory, } return nav } type navigableExprImpl struct { parent NavigableExpr kind ExprKind expr *exprpb.Expr typeMap map[int64]*types.Type createChildren childFactory } func (nav *navigableExprImpl) ID() int64 { return nav.ToExpr().GetId() } func (nav *navigableExprImpl) Kind() ExprKind { return nav.kind } func (nav *navigableExprImpl) Type() *types.Type { if t, found := nav.typeMap[nav.ID()]; found { return t } return types.DynType } func (nav *navigableExprImpl) Parent() (NavigableExpr, bool) { if nav.parent != nil { return nav.parent, true } return nil, false } func (nav *navigableExprImpl) Children() []NavigableExpr { return nav.createChildren(nav) } func (nav *navigableExprImpl) ToExpr() *exprpb.Expr { return nav.expr } func (nav *navigableExprImpl) AsCall() NavigableCallExpr { return navigableCallImpl{navigableExprImpl: nav} } func (nav *navigableExprImpl) AsComprehension() NavigableComprehensionExpr { return navigableComprehensionImpl{navigableExprImpl: nav} } func (nav *navigableExprImpl) AsIdent() string { return nav.ToExpr().GetIdentExpr().GetName() } func (nav *navigableExprImpl) AsLiteral() ref.Val { if nav.Kind() != LiteralKind { return nil } val, err := ConstantToVal(nav.ToExpr().GetConstExpr()) if err != nil { panic(err) } return val } func (nav *navigableExprImpl) AsList() NavigableListExpr { return navigableListImpl{navigableExprImpl: nav} } func (nav *navigableExprImpl) AsMap() NavigableMapExpr { return navigableMapImpl{navigableExprImpl: nav} } func (nav *navigableExprImpl) AsSelect() NavigableSelectExpr { return navigableSelectImpl{navigableExprImpl: nav} } func (nav *navigableExprImpl) AsStruct() NavigableStructExpr { return navigableStructImpl{navigableExprImpl: nav} } func (nav *navigableExprImpl) createChild(e *exprpb.Expr) NavigableExpr { return newNavigableExpr(nav, e, nav.typeMap) } func (nav *navigableExprImpl) isNavigable() {} type navigableCallImpl struct { *navigableExprImpl } func (call navigableCallImpl) FunctionName() string { return call.ToExpr().GetCallExpr().GetFunction() } func (call navigableCallImpl) Target() NavigableExpr { t := call.ToExpr().GetCallExpr().GetTarget() if t != nil { return call.createChild(t) } return nil } func (call navigableCallImpl) Args() []NavigableExpr { args := call.ToExpr().GetCallExpr().GetArgs() navArgs := make([]NavigableExpr, len(args)) for i, a := range args { navArgs[i] = call.createChild(a) } return navArgs } func (call navigableCallImpl) ReturnType() *types.Type { return call.Type() } type navigableComprehensionImpl struct { *navigableExprImpl } func (comp navigableComprehensionImpl) IterRange() NavigableExpr { return comp.createChild(comp.ToExpr().GetComprehensionExpr().GetIterRange()) } func (comp navigableComprehensionImpl) IterVar() string { return comp.ToExpr().GetComprehensionExpr().GetIterVar() } func (comp navigableComprehensionImpl) AccuVar() string { return comp.ToExpr().GetComprehensionExpr().GetAccuVar() } func (comp navigableComprehensionImpl) AccuInit() NavigableExpr { return comp.createChild(comp.ToExpr().GetComprehensionExpr().GetAccuInit()) } func (comp navigableComprehensionImpl) LoopCondition() NavigableExpr { return comp.createChild(comp.ToExpr().GetComprehensionExpr().GetLoopCondition()) } func (comp navigableComprehensionImpl) LoopStep() NavigableExpr { return comp.createChild(comp.ToExpr().GetComprehensionExpr().GetLoopStep()) } func (comp navigableComprehensionImpl) Result() NavigableExpr { return comp.createChild(comp.ToExpr().GetComprehensionExpr().GetResult()) } type navigableListImpl struct { *navigableExprImpl } func (l navigableListImpl) Elements() []NavigableExpr { return l.Children() } func (l navigableListImpl) OptionalIndices() []int32 { return l.ToExpr().GetListExpr().GetOptionalIndices() } func (l navigableListImpl) Size() int { return len(l.ToExpr().GetListExpr().GetElements()) } type navigableMapImpl struct { *navigableExprImpl } func (m navigableMapImpl) Entries() []NavigableEntry { mapExpr := m.ToExpr().GetStructExpr() entries := make([]NavigableEntry, len(mapExpr.GetEntries())) for i, e := range mapExpr.GetEntries() { entries[i] = navigableEntryImpl{ key: m.createChild(e.GetMapKey()), val: m.createChild(e.GetValue()), isOpt: e.GetOptionalEntry(), } } return entries } func (m navigableMapImpl) Size() int { return len(m.ToExpr().GetStructExpr().GetEntries()) } type navigableEntryImpl struct { key NavigableExpr val NavigableExpr isOpt bool } func (e navigableEntryImpl) Key() NavigableExpr { return e.key } func (e navigableEntryImpl) Value() NavigableExpr { return e.val } func (e navigableEntryImpl) IsOptional() bool { return e.isOpt } func (e navigableEntryImpl) isNavigable() {} type navigableSelectImpl struct { *navigableExprImpl } func (sel navigableSelectImpl) FieldName() string { return sel.ToExpr().GetSelectExpr().GetField() } func (sel navigableSelectImpl) IsTestOnly() bool { return sel.ToExpr().GetSelectExpr().GetTestOnly() } func (sel navigableSelectImpl) Operand() NavigableExpr { return sel.createChild(sel.ToExpr().GetSelectExpr().GetOperand()) } type navigableStructImpl struct { *navigableExprImpl } func (s navigableStructImpl) TypeName() string { return s.ToExpr().GetStructExpr().GetMessageName() } func (s navigableStructImpl) Fields() []NavigableField { fieldInits := s.ToExpr().GetStructExpr().GetEntries() fields := make([]NavigableField, len(fieldInits)) for i, f := range fieldInits { fields[i] = navigableFieldImpl{ name: f.GetFieldKey(), val: s.createChild(f.GetValue()), isOpt: f.GetOptionalEntry(), } } return fields } type navigableFieldImpl struct { name string val NavigableExpr isOpt bool } func (f navigableFieldImpl) FieldName() string { return f.name } func (f navigableFieldImpl) Value() NavigableExpr { return f.val } func (f navigableFieldImpl) IsOptional() bool { return f.isOpt } func (f navigableFieldImpl) isNavigable() {} func kindOf(expr *exprpb.Expr) (ExprKind, childFactory) { switch expr.GetExprKind().(type) { case *exprpb.Expr_ConstExpr: return LiteralKind, noopFactory case *exprpb.Expr_IdentExpr: return IdentKind, noopFactory case *exprpb.Expr_SelectExpr: return SelectKind, selectFactory case *exprpb.Expr_CallExpr: return CallKind, callArgFactory case *exprpb.Expr_ListExpr: return ListKind, listElemFactory case *exprpb.Expr_StructExpr: if expr.GetStructExpr().GetMessageName() != "" { return StructKind, structEntryFactory } return MapKind, mapEntryFactory case *exprpb.Expr_ComprehensionExpr: return ComprehensionKind, comprehensionFactory default: return UnspecifiedKind, noopFactory } } type childFactory func(*navigableExprImpl) []NavigableExpr func noopFactory(*navigableExprImpl) []NavigableExpr { return nil } func selectFactory(nav *navigableExprImpl) []NavigableExpr { return []NavigableExpr{ nav.createChild(nav.ToExpr().GetSelectExpr().GetOperand()), } } func callArgFactory(nav *navigableExprImpl) []NavigableExpr { call := nav.ToExpr().GetCallExpr() argCount := len(call.GetArgs()) if call.GetTarget() != nil { argCount++ } navExprs := make([]NavigableExpr, argCount) i := 0 if call.GetTarget() != nil { navExprs[i] = nav.createChild(call.GetTarget()) i++ } for _, arg := range call.GetArgs() { navExprs[i] = nav.createChild(arg) i++ } return navExprs } func listElemFactory(nav *navigableExprImpl) []NavigableExpr { l := nav.ToExpr().GetListExpr() navExprs := make([]NavigableExpr, len(l.GetElements())) for i, e := range l.GetElements() { navExprs[i] = nav.createChild(e) } return navExprs } func structEntryFactory(nav *navigableExprImpl) []NavigableExpr { s := nav.ToExpr().GetStructExpr() entries := make([]NavigableExpr, len(s.GetEntries())) for i, e := range s.GetEntries() { entries[i] = nav.createChild(e.GetValue()) } return entries } func mapEntryFactory(nav *navigableExprImpl) []NavigableExpr { s := nav.ToExpr().GetStructExpr() entries := make([]NavigableExpr, len(s.GetEntries())*2) j := 0 for _, e := range s.GetEntries() { entries[j] = nav.createChild(e.GetMapKey()) entries[j+1] = nav.createChild(e.GetValue()) j += 2 } return entries } func comprehensionFactory(nav *navigableExprImpl) []NavigableExpr { compre := nav.ToExpr().GetComprehensionExpr() return []NavigableExpr{ nav.createChild(compre.GetIterRange()), nav.createChild(compre.GetAccuInit()), nav.createChild(compre.GetLoopCondition()), nav.createChild(compre.GetLoopStep()), nav.createChild(compre.GetResult()), } }