// Copyright 2018 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package parser import ( "sync" antlr "github.com/antlr4-go/antlr/v4" "github.com/google/cel-go/common" "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" ) type parserHelper struct { exprFactory ast.ExprFactory source common.Source sourceInfo *ast.SourceInfo nextID int64 } func newParserHelper(source common.Source, fac ast.ExprFactory) *parserHelper { return &parserHelper{ exprFactory: fac, source: source, sourceInfo: ast.NewSourceInfo(source), nextID: 1, } } func (p *parserHelper) getSourceInfo() *ast.SourceInfo { return p.sourceInfo } func (p *parserHelper) newLiteral(ctx any, value ref.Val) ast.Expr { return p.exprFactory.NewLiteral(p.newID(ctx), value) } func (p *parserHelper) newLiteralBool(ctx any, value bool) ast.Expr { return p.newLiteral(ctx, types.Bool(value)) } func (p *parserHelper) newLiteralString(ctx any, value string) ast.Expr { return p.newLiteral(ctx, types.String(value)) } func (p *parserHelper) newLiteralBytes(ctx any, value []byte) ast.Expr { return p.newLiteral(ctx, types.Bytes(value)) } func (p *parserHelper) newLiteralInt(ctx any, value int64) ast.Expr { return p.newLiteral(ctx, types.Int(value)) } func (p *parserHelper) newLiteralUint(ctx any, value uint64) ast.Expr { return p.newLiteral(ctx, types.Uint(value)) } func (p *parserHelper) newLiteralDouble(ctx any, value float64) ast.Expr { return 
p.newLiteral(ctx, types.Double(value)) } func (p *parserHelper) newIdent(ctx any, name string) ast.Expr { return p.exprFactory.NewIdent(p.newID(ctx), name) } func (p *parserHelper) newSelect(ctx any, operand ast.Expr, field string) ast.Expr { return p.exprFactory.NewSelect(p.newID(ctx), operand, field) } func (p *parserHelper) newPresenceTest(ctx any, operand ast.Expr, field string) ast.Expr { return p.exprFactory.NewPresenceTest(p.newID(ctx), operand, field) } func (p *parserHelper) newGlobalCall(ctx any, function string, args ...ast.Expr) ast.Expr { return p.exprFactory.NewCall(p.newID(ctx), function, args...) } func (p *parserHelper) newReceiverCall(ctx any, function string, target ast.Expr, args ...ast.Expr) ast.Expr { return p.exprFactory.NewMemberCall(p.newID(ctx), function, target, args...) } func (p *parserHelper) newList(ctx any, elements []ast.Expr, optionals ...int32) ast.Expr { return p.exprFactory.NewList(p.newID(ctx), elements, optionals) } func (p *parserHelper) newMap(ctx any, entries ...ast.EntryExpr) ast.Expr { return p.exprFactory.NewMap(p.newID(ctx), entries) } func (p *parserHelper) newMapEntry(entryID int64, key ast.Expr, value ast.Expr, optional bool) ast.EntryExpr { return p.exprFactory.NewMapEntry(entryID, key, value, optional) } func (p *parserHelper) newObject(ctx any, typeName string, fields ...ast.EntryExpr) ast.Expr { return p.exprFactory.NewStruct(p.newID(ctx), typeName, fields) } func (p *parserHelper) newObjectField(fieldID int64, field string, value ast.Expr, optional bool) ast.EntryExpr { return p.exprFactory.NewStructField(fieldID, field, value, optional) } func (p *parserHelper) newComprehension(ctx any, iterRange ast.Expr, iterVar string, accuVar string, accuInit ast.Expr, condition ast.Expr, step ast.Expr, result ast.Expr) ast.Expr { return p.exprFactory.NewComprehension( p.newID(ctx), iterRange, iterVar, accuVar, accuInit, condition, step, result) } func (p *parserHelper) newID(ctx any) int64 { if id, isID := ctx.(int64); 
isID { return id } return p.id(ctx) } func (p *parserHelper) newExpr(ctx any) ast.Expr { return p.exprFactory.NewUnspecifiedExpr(p.newID(ctx)) } func (p *parserHelper) id(ctx any) int64 { var offset ast.OffsetRange switch c := ctx.(type) { case antlr.ParserRuleContext: start, stop := c.GetStart(), c.GetStop() if stop == nil { stop = start } offset.Start = p.sourceInfo.ComputeOffset(int32(start.GetLine()), int32(start.GetColumn())) offset.Stop = p.sourceInfo.ComputeOffset(int32(stop.GetLine()), int32(stop.GetColumn())) case antlr.Token: offset.Start = p.sourceInfo.ComputeOffset(int32(c.GetLine()), int32(c.GetColumn())) offset.Stop = offset.Start case common.Location: offset.Start = p.sourceInfo.ComputeOffset(int32(c.Line()), int32(c.Column())) offset.Stop = offset.Start case ast.OffsetRange: offset = c default: // This should only happen if the ctx is nil return -1 } id := p.nextID p.sourceInfo.SetOffsetRange(id, offset) p.nextID++ return id } func (p *parserHelper) getLocation(id int64) common.Location { return p.sourceInfo.GetStartLocation(id) } // buildMacroCallArg iterates the expression and returns a new expression // where all macros have been replaced by their IDs in MacroCalls func (p *parserHelper) buildMacroCallArg(expr ast.Expr) ast.Expr { if _, found := p.sourceInfo.GetMacroCall(expr.ID()); found { return p.exprFactory.NewUnspecifiedExpr(expr.ID()) } switch expr.Kind() { case ast.CallKind: // Iterate the AST from `expr` recursively looking for macros. Because we are at most // starting from the top level macro, this recursion is bounded by the size of the AST. This // means that the depth check on the AST during parsing will catch recursion overflows // before we get to here. call := expr.AsCall() macroArgs := make([]ast.Expr, len(call.Args())) for index, arg := range call.Args() { macroArgs[index] = p.buildMacroCallArg(arg) } if !call.IsMemberFunction() { return p.exprFactory.NewCall(expr.ID(), call.FunctionName(), macroArgs...) 
} macroTarget := p.buildMacroCallArg(call.Target()) return p.exprFactory.NewMemberCall(expr.ID(), call.FunctionName(), macroTarget, macroArgs...) case ast.ListKind: list := expr.AsList() macroListArgs := make([]ast.Expr, list.Size()) for i, elem := range list.Elements() { macroListArgs[i] = p.buildMacroCallArg(elem) } return p.exprFactory.NewList(expr.ID(), macroListArgs, list.OptionalIndices()) } return expr } // addMacroCall adds the macro the the MacroCalls map in source info. If a macro has args/subargs/target // that are macros, their ID will be stored instead for later self-lookups. func (p *parserHelper) addMacroCall(exprID int64, function string, target ast.Expr, args ...ast.Expr) { macroArgs := make([]ast.Expr, len(args)) for index, arg := range args { macroArgs[index] = p.buildMacroCallArg(arg) } if target == nil { p.sourceInfo.SetMacroCall(exprID, p.exprFactory.NewCall(0, function, macroArgs...)) return } macroTarget := target if _, found := p.sourceInfo.GetMacroCall(target.ID()); found { macroTarget = p.exprFactory.NewUnspecifiedExpr(target.ID()) } else { macroTarget = p.buildMacroCallArg(target) } p.sourceInfo.SetMacroCall(exprID, p.exprFactory.NewMemberCall(0, function, macroTarget, macroArgs...)) } // logicManager compacts logical trees into a more efficient structure which is semantically // equivalent with how the logic graph is constructed by the ANTLR parser. // // The purpose of the logicManager is to ensure a compact serialization format for the logical &&, || // operators which have a tendency to create long DAGs which are skewed in one direction. Since the // operators are commutative re-ordering the terms *must not* affect the evaluation result. // // The logic manager will either render the terms to N-chained && / || operators as a single logical // call with N-terms, or will rebalance the tree. 
Rebalancing the terms is a safe, if somewhat // controversial choice as it alters the traditional order of execution assumptions present in most // expressions. type logicManager struct { exprFactory ast.ExprFactory function string terms []ast.Expr ops []int64 variadicASTs bool } // newVariadicLogicManager creates a logic manager instance bound to a specific function and its first term. func newVariadicLogicManager(fac ast.ExprFactory, function string, term ast.Expr) *logicManager { return &logicManager{ exprFactory: fac, function: function, terms: []ast.Expr{term}, ops: []int64{}, variadicASTs: true, } } // newBalancingLogicManager creates a logic manager instance bound to a specific function and its first term. func newBalancingLogicManager(fac ast.ExprFactory, function string, term ast.Expr) *logicManager { return &logicManager{ exprFactory: fac, function: function, terms: []ast.Expr{term}, ops: []int64{}, variadicASTs: false, } } // addTerm adds an operation identifier and term to the set of terms to be balanced. func (l *logicManager) addTerm(op int64, term ast.Expr) { l.terms = append(l.terms, term) l.ops = append(l.ops, op) } // toExpr renders the logic graph into an Expr value, either balancing a tree of logical // operations or creating a variadic representation of the logical operator. func (l *logicManager) toExpr() ast.Expr { if len(l.terms) == 1 { return l.terms[0] } if l.variadicASTs { return l.exprFactory.NewCall(l.ops[0], l.function, l.terms...) } return l.balancedTree(0, len(l.ops)-1) } // balancedTree recursively balances the terms provided to a commutative operator. 
func (l *logicManager) balancedTree(lo, hi int) ast.Expr { mid := (lo + hi + 1) / 2 var left ast.Expr if mid == lo { left = l.terms[mid] } else { left = l.balancedTree(lo, mid-1) } var right ast.Expr if mid == hi { right = l.terms[mid+1] } else { right = l.balancedTree(mid+1, hi) } return l.exprFactory.NewCall(l.ops[mid], l.function, left, right) } type exprHelper struct { *parserHelper id int64 } func (e *exprHelper) nextMacroID() int64 { return e.parserHelper.id(e.parserHelper.getLocation(e.id)) } // Copy implements the ExprHelper interface method by producing a copy of the input Expr value // with a fresh set of numeric identifiers the Expr and all its descendants. func (e *exprHelper) Copy(expr ast.Expr) ast.Expr { offsetRange, _ := e.parserHelper.sourceInfo.GetOffsetRange(expr.ID()) copyID := e.parserHelper.newID(offsetRange) switch expr.Kind() { case ast.LiteralKind: return e.exprFactory.NewLiteral(copyID, expr.AsLiteral()) case ast.IdentKind: return e.exprFactory.NewIdent(copyID, expr.AsIdent()) case ast.SelectKind: sel := expr.AsSelect() op := e.Copy(sel.Operand()) if sel.IsTestOnly() { return e.exprFactory.NewPresenceTest(copyID, op, sel.FieldName()) } return e.exprFactory.NewSelect(copyID, op, sel.FieldName()) case ast.CallKind: call := expr.AsCall() args := call.Args() argsCopy := make([]ast.Expr, len(args)) for i, arg := range args { argsCopy[i] = e.Copy(arg) } if !call.IsMemberFunction() { return e.exprFactory.NewCall(copyID, call.FunctionName(), argsCopy...) } return e.exprFactory.NewMemberCall(copyID, call.FunctionName(), e.Copy(call.Target()), argsCopy...) 
case ast.ListKind: list := expr.AsList() elems := list.Elements() elemsCopy := make([]ast.Expr, len(elems)) for i, elem := range elems { elemsCopy[i] = e.Copy(elem) } return e.exprFactory.NewList(copyID, elemsCopy, list.OptionalIndices()) case ast.MapKind: m := expr.AsMap() entries := m.Entries() entriesCopy := make([]ast.EntryExpr, len(entries)) for i, en := range entries { entry := en.AsMapEntry() entryID := e.nextMacroID() entriesCopy[i] = e.exprFactory.NewMapEntry(entryID, e.Copy(entry.Key()), e.Copy(entry.Value()), entry.IsOptional()) } return e.exprFactory.NewMap(copyID, entriesCopy) case ast.StructKind: s := expr.AsStruct() fields := s.Fields() fieldsCopy := make([]ast.EntryExpr, len(fields)) for i, f := range fields { field := f.AsStructField() fieldID := e.nextMacroID() fieldsCopy[i] = e.exprFactory.NewStructField(fieldID, field.Name(), e.Copy(field.Value()), field.IsOptional()) } return e.exprFactory.NewStruct(copyID, s.TypeName(), fieldsCopy) case ast.ComprehensionKind: compre := expr.AsComprehension() iterRange := e.Copy(compre.IterRange()) accuInit := e.Copy(compre.AccuInit()) cond := e.Copy(compre.LoopCondition()) step := e.Copy(compre.LoopStep()) result := e.Copy(compre.Result()) return e.exprFactory.NewComprehension(copyID, iterRange, compre.IterVar(), compre.AccuVar(), accuInit, cond, step, result) } return e.exprFactory.NewUnspecifiedExpr(copyID) } // NewLiteral implements the ExprHelper interface method. func (e *exprHelper) NewLiteral(value ref.Val) ast.Expr { return e.exprFactory.NewLiteral(e.nextMacroID(), value) } // NewList implements the ExprHelper interface method. func (e *exprHelper) NewList(elems ...ast.Expr) ast.Expr { return e.exprFactory.NewList(e.nextMacroID(), elems, []int32{}) } // NewMap implements the ExprHelper interface method. func (e *exprHelper) NewMap(entries ...ast.EntryExpr) ast.Expr { return e.exprFactory.NewMap(e.nextMacroID(), entries) } // NewMapEntry implements the ExprHelper interface method. 
func (e *exprHelper) NewMapEntry(key ast.Expr, val ast.Expr, optional bool) ast.EntryExpr {
	return e.exprFactory.NewMapEntry(e.nextMacroID(), key, val, optional)
}

// NewStruct implements the ExprHelper interface method.
func (e *exprHelper) NewStruct(typeName string, fieldInits ...ast.EntryExpr) ast.Expr {
	return e.exprFactory.NewStruct(e.nextMacroID(), typeName, fieldInits)
}

// NewStructField implements the ExprHelper interface method.
func (e *exprHelper) NewStructField(field string, init ast.Expr, optional bool) ast.EntryExpr {
	return e.exprFactory.NewStructField(e.nextMacroID(), field, init, optional)
}

// NewComprehension implements the ExprHelper interface method.
func (e *exprHelper) NewComprehension(
	iterRange ast.Expr,
	iterVar string,
	accuVar string,
	accuInit ast.Expr,
	condition ast.Expr,
	step ast.Expr,
	result ast.Expr) ast.Expr {
	return e.exprFactory.NewComprehension(
		e.nextMacroID(), iterRange, iterVar, accuVar, accuInit, condition, step, result)
}

// NewIdent implements the ExprHelper interface method.
func (e *exprHelper) NewIdent(name string) ast.Expr {
	return e.exprFactory.NewIdent(e.nextMacroID(), name)
}

// NewAccuIdent implements the ExprHelper interface method.
func (e *exprHelper) NewAccuIdent() ast.Expr {
	return e.exprFactory.NewAccuIdent(e.nextMacroID())
}

// NewCall implements the ExprHelper interface method by creating a call to a
// non-member (global) function.
func (e *exprHelper) NewCall(function string, args ...ast.Expr) ast.Expr {
	return e.exprFactory.NewCall(e.nextMacroID(), function, args...)
}

// NewMemberCall implements the ExprHelper interface method by creating a call
// of function on target.
func (e *exprHelper) NewMemberCall(function string, target ast.Expr, args ...ast.Expr) ast.Expr {
	return e.exprFactory.NewMemberCall(e.nextMacroID(), function, target, args...)
}

// NewPresenceTest implements the ExprHelper interface method.
func (e *exprHelper) NewPresenceTest(operand ast.Expr, field string) ast.Expr {
	id := e.nextMacroID()
	return e.exprFactory.NewPresenceTest(id, operand, field)
}

// NewSelect implements the ExprHelper interface method.
func (e *exprHelper) NewSelect(operand ast.Expr, field string) ast.Expr {
	id := e.nextMacroID()
	return e.exprFactory.NewSelect(id, operand, field)
}

// OffsetLocation implements the ExprHelper interface method. It reports the
// recorded start location of exprID.
func (e *exprHelper) OffsetLocation(exprID int64) common.Location {
	info := e.parserHelper.sourceInfo
	return info.GetStartLocation(exprID)
}

// NewError implements the ExprHelper interface method. It builds an error for
// exprID that is positioned at the expression's recorded start location.
func (e *exprHelper) NewError(exprID int64, message string) *common.Error {
	loc := e.OffsetLocation(exprID)
	return common.NewError(exprID, message, loc)
}

// exprHelperPool recycles exprHelper values to reduce allocation overhead.
// sync.Pool is safe for concurrent use by multiple goroutines.
var exprHelperPool = &sync.Pool{
	New: func() any {
		return &exprHelper{}
	},
}