rebase: update kubernetes to latest

updating the kubernetes release to the
latest in main go.mod

Signed-off-by: Madhu Rajanna <madhupr007@gmail.com>
This commit is contained in:
Madhu Rajanna
2024-08-19 10:01:33 +02:00
committed by mergify[bot]
parent 63c4c05b35
commit 5a66991bb3
2173 changed files with 98906 additions and 61334 deletions

View File

@ -10,9 +10,12 @@ go_library(
"cel.go",
"decls.go",
"env.go",
"folding.go",
"io.go",
"inlining.go",
"library.go",
"macro.go",
"optimizer.go",
"options.go",
"program.go",
"validator.go",
@ -56,7 +59,11 @@ go_test(
"cel_test.go",
"decls_test.go",
"env_test.go",
"folding_test.go",
"io_test.go",
"inlining_test.go",
"optimizer_test.go",
"validator_test.go",
],
data = [
"//cel/testdata:gen_test_fds",
@ -70,6 +77,7 @@ go_test(
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"//ext:go_default_library",
"//test:go_default_library",
"//test/proto2pb:go_default_library",
"//test/proto3pb:go_default_library",

View File

@ -353,43 +353,3 @@ func ExprDeclToDeclaration(d *exprpb.Decl) (EnvOption, error) {
return nil, fmt.Errorf("unsupported decl: %v", d)
}
}
func typeValueToKind(tv ref.Type) (Kind, error) {
switch tv {
case types.BoolType:
return BoolKind, nil
case types.DoubleType:
return DoubleKind, nil
case types.IntType:
return IntKind, nil
case types.UintType:
return UintKind, nil
case types.ListType:
return ListKind, nil
case types.MapType:
return MapKind, nil
case types.StringType:
return StringKind, nil
case types.BytesType:
return BytesKind, nil
case types.DurationType:
return DurationKind, nil
case types.TimestampType:
return TimestampKind, nil
case types.NullType:
return NullTypeKind, nil
case types.TypeType:
return TypeKind, nil
default:
switch tv.TypeName() {
case "dyn":
return DynKind, nil
case "google.protobuf.Any":
return AnyKind, nil
case "optional":
return OpaqueKind, nil
default:
return 0, fmt.Errorf("no known conversion for type of %s", tv.TypeName())
}
}
}

View File

@ -38,26 +38,42 @@ type Source = common.Source
// Ast representing the checked or unchecked expression, its source, and related metadata such as
// source position information.
type Ast struct {
expr *exprpb.Expr
info *exprpb.SourceInfo
source Source
refMap map[int64]*celast.ReferenceInfo
typeMap map[int64]*types.Type
source Source
impl *celast.AST
}
// NativeRep converts the AST to a Go-native representation.
func (ast *Ast) NativeRep() *celast.AST {
return ast.impl
}
// Expr returns the proto serializable instance of the parsed/checked expression.
//
// Deprecated: prefer cel.AstToCheckedExpr() or cel.AstToParsedExpr() and call GetExpr()
// the result instead.
func (ast *Ast) Expr() *exprpb.Expr {
return ast.expr
if ast == nil {
return nil
}
pbExpr, _ := celast.ExprToProto(ast.impl.Expr())
return pbExpr
}
// IsChecked returns whether the Ast value has been successfully type-checked.
func (ast *Ast) IsChecked() bool {
return ast.typeMap != nil && len(ast.typeMap) > 0
if ast == nil {
return false
}
return ast.impl.IsChecked()
}
// SourceInfo returns character offset and newline position information about expression elements.
func (ast *Ast) SourceInfo() *exprpb.SourceInfo {
return ast.info
if ast == nil {
return nil
}
pbInfo, _ := celast.SourceInfoToProto(ast.impl.SourceInfo())
return pbInfo
}
// ResultType returns the output type of the expression if the Ast has been type-checked, else
@ -65,9 +81,6 @@ func (ast *Ast) SourceInfo() *exprpb.SourceInfo {
//
// Deprecated: use OutputType
func (ast *Ast) ResultType() *exprpb.Type {
if !ast.IsChecked() {
return chkdecls.Dyn
}
out := ast.OutputType()
t, err := TypeToExprType(out)
if err != nil {
@ -79,16 +92,18 @@ func (ast *Ast) ResultType() *exprpb.Type {
// OutputType returns the output type of the expression if the Ast has been type-checked, else
// returns cel.DynType as the parse step cannot infer types.
func (ast *Ast) OutputType() *Type {
t, found := ast.typeMap[ast.expr.GetId()]
if !found {
return DynType
if ast == nil {
return types.ErrorType
}
return t
return ast.impl.GetType(ast.impl.Expr().ID())
}
// Source returns a view of the input used to create the Ast. This source may be complete or
// constructed from the SourceInfo.
func (ast *Ast) Source() Source {
if ast == nil {
return nil
}
return ast.source
}
@ -198,29 +213,28 @@ func NewCustomEnv(opts ...EnvOption) (*Env, error) {
// It is possible to have both non-nil Ast and Issues values returned from this call: however,
// the mere presence of an Ast does not imply that it is valid for use.
func (e *Env) Check(ast *Ast) (*Ast, *Issues) {
// Note, errors aren't currently possible on the Ast to ParsedExpr conversion.
pe, _ := AstToParsedExpr(ast)
// Construct the internal checker env, erroring if there is an issue adding the declarations.
chk, err := e.initChecker()
if err != nil {
errs := common.NewErrors(ast.Source())
errs.ReportError(common.NoLocation, err.Error())
return nil, NewIssuesWithSourceInfo(errs, ast.SourceInfo())
return nil, NewIssuesWithSourceInfo(errs, ast.impl.SourceInfo())
}
res, errs := checker.Check(pe, ast.Source(), chk)
checked, errs := checker.Check(ast.impl, ast.Source(), chk)
if len(errs.GetErrors()) > 0 {
return nil, NewIssuesWithSourceInfo(errs, ast.SourceInfo())
return nil, NewIssuesWithSourceInfo(errs, ast.impl.SourceInfo())
}
// Manually create the Ast to ensure that the Ast source information (which may be more
// detailed than the information provided by Check), is returned to the caller.
ast = &Ast{
source: ast.Source(),
expr: res.Expr,
info: res.SourceInfo,
refMap: res.ReferenceMap,
typeMap: res.TypeMap}
source: ast.Source(),
impl: checked}
// Avoid creating a validator config if it's not needed.
if len(e.validators) == 0 {
return ast, nil
}
// Generate a validator configuration from the set of configured validators.
vConfig := newValidatorConfig()
@ -230,9 +244,9 @@ func (e *Env) Check(ast *Ast) (*Ast, *Issues) {
}
}
// Apply additional validators on the type-checked result.
iss := NewIssuesWithSourceInfo(errs, ast.SourceInfo())
iss := NewIssuesWithSourceInfo(errs, ast.impl.SourceInfo())
for _, v := range e.validators {
v.Validate(e, vConfig, res, iss)
v.Validate(e, vConfig, checked, iss)
}
if iss.Err() != nil {
return nil, iss
@ -429,16 +443,11 @@ func (e *Env) Parse(txt string) (*Ast, *Issues) {
// It is possible to have both non-nil Ast and Issues values returned from this call; however,
// the mere presence of an Ast does not imply that it is valid for use.
func (e *Env) ParseSource(src Source) (*Ast, *Issues) {
res, errs := e.prsr.Parse(src)
parsed, errs := e.prsr.Parse(src)
if len(errs.GetErrors()) > 0 {
return nil, &Issues{errs: errs}
}
// Manually create the Ast to ensure that the text source information is propagated on
// subsequent calls to Check.
return &Ast{
source: src,
expr: res.GetExpr(),
info: res.GetSourceInfo()}, nil
return &Ast{source: src, impl: parsed}, nil
}
// Program generates an evaluable instance of the Ast within the environment (Env).
@ -534,8 +543,9 @@ func (e *Env) PartialVars(vars any) (interpreter.PartialActivation, error) {
// TODO: Consider adding an option to generate a Program.Residual to avoid round-tripping to an
// Ast format and then Program again.
func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) {
pruned := interpreter.PruneAst(a.Expr(), a.SourceInfo().GetMacroCalls(), details.State())
expr, err := AstToString(ParsedExprToAst(pruned))
pruned := interpreter.PruneAst(a.impl.Expr(), a.impl.SourceInfo().MacroCalls(), details.State())
newAST := &Ast{source: a.Source(), impl: pruned}
expr, err := AstToString(newAST)
if err != nil {
return nil, err
}
@ -556,16 +566,10 @@ func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) {
// EstimateCost estimates the cost of a type checked CEL expression using the length estimates of input data and
// extension functions provided by estimator.
func (e *Env) EstimateCost(ast *Ast, estimator checker.CostEstimator, opts ...checker.CostOption) (checker.CostEstimate, error) {
checked := &celast.CheckedAST{
Expr: ast.Expr(),
SourceInfo: ast.SourceInfo(),
TypeMap: ast.typeMap,
ReferenceMap: ast.refMap,
}
extendedOpts := make([]checker.CostOption, 0, len(e.costOptions))
extendedOpts = append(extendedOpts, opts...)
extendedOpts = append(extendedOpts, e.costOptions...)
return checker.Cost(checked, estimator, extendedOpts...)
return checker.Cost(ast.impl, estimator, extendedOpts...)
}
// configure applies a series of EnvOptions to the current environment.
@ -707,7 +711,7 @@ type Error = common.Error
// Note: in the future, non-fatal warnings and notices may be inspectable via the Issues struct.
type Issues struct {
errs *common.Errors
info *exprpb.SourceInfo
info *celast.SourceInfo
}
// NewIssues returns an Issues struct from a common.Errors object.
@ -718,7 +722,7 @@ func NewIssues(errs *common.Errors) *Issues {
// NewIssuesWithSourceInfo returns an Issues struct from a common.Errors object with SourceInfo metatata
// which can be used with the `ReportErrorAtID` method for additional error reports within the context
// information that's inferred from an expression id.
func NewIssuesWithSourceInfo(errs *common.Errors, info *exprpb.SourceInfo) *Issues {
func NewIssuesWithSourceInfo(errs *common.Errors, info *celast.SourceInfo) *Issues {
return &Issues{
errs: errs,
info: info,
@ -768,30 +772,7 @@ func (i *Issues) String() string {
// The source metadata for the expression at `id`, if present, is attached to the error report.
// To ensure that source metadata is attached to error reports, use NewIssuesWithSourceInfo.
func (i *Issues) ReportErrorAtID(id int64, message string, args ...any) {
i.errs.ReportErrorAtID(id, locationByID(id, i.info), message, args...)
}
// locationByID returns a common.Location given an expression id.
//
// TODO: move this functionality into the native SourceInfo and an overhaul of the common.Source
// as this implementation relies on the abstractions present in the protobuf SourceInfo object,
// and is replicated in the checker.
func locationByID(id int64, sourceInfo *exprpb.SourceInfo) common.Location {
positions := sourceInfo.GetPositions()
var line = 1
if offset, found := positions[id]; found {
col := int(offset)
for _, lineOffset := range sourceInfo.GetLineOffsets() {
if lineOffset < offset {
line++
col = int(offset - lineOffset)
} else {
break
}
}
return common.NewLocation(line, col)
}
return common.NoLocation
i.errs.ReportErrorAtID(id, i.info.GetStartLocation(id), message, args...)
}
// getStdEnv lazy initializes the CEL standard environment.
@ -822,6 +803,13 @@ func (p *interopCELTypeProvider) FindStructType(typeName string) (*types.Type, b
return nil, false
}
// FindStructFieldNames returns an empty set of field for the interop provider.
//
// To inspect the field names, migrate to a `types.Provider` implementation.
func (p *interopCELTypeProvider) FindStructFieldNames(typeName string) ([]string, bool) {
return []string{}, false
}
// FindStructFieldType returns a types.FieldType instance for the given fully-qualified typeName and field
// name, if one exists.
//

559
vendor/github.com/google/cel-go/cel/folding.go generated vendored Normal file
View File

@ -0,0 +1,559 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cel
import (
"fmt"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
)
// ConstantFoldingOption defines a functional option for configuring constant folding.
type ConstantFoldingOption func(opt *constantFoldingOptimizer) (*constantFoldingOptimizer, error)
// MaxConstantFoldIterations limits the number of times literals may be folding during optimization.
//
// Defaults to 100 if not set.
func MaxConstantFoldIterations(limit int) ConstantFoldingOption {
return func(opt *constantFoldingOptimizer) (*constantFoldingOptimizer, error) {
opt.maxFoldIterations = limit
return opt, nil
}
}
// NewConstantFoldingOptimizer creates an optimizer which inlines constant scalar an aggregate
// literal values within function calls and select statements with their evaluated result.
func NewConstantFoldingOptimizer(opts ...ConstantFoldingOption) (ASTOptimizer, error) {
folder := &constantFoldingOptimizer{
maxFoldIterations: defaultMaxConstantFoldIterations,
}
var err error
for _, o := range opts {
folder, err = o(folder)
if err != nil {
return nil, err
}
}
return folder, nil
}
type constantFoldingOptimizer struct {
maxFoldIterations int
}
// Optimize queries the expression graph for scalar and aggregate literal expressions within call and
// select statements and then evaluates them and replaces the call site with the literal result.
//
// Note: only values which can be represented as literals in CEL syntax are supported.
func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.AST {
root := ast.NavigateAST(a)
// Walk the list of foldable expression and continue to fold until there are no more folds left.
// All of the fold candidates returned by the constantExprMatcher should succeed unless there's
// a logic bug with the selection of expressions.
foldableExprs := ast.MatchDescendants(root, constantExprMatcher)
foldCount := 0
for len(foldableExprs) != 0 && foldCount < opt.maxFoldIterations {
for _, fold := range foldableExprs {
// If the expression could be folded because it's a non-strict call, and the
// branches are pruned, continue to the next fold.
if fold.Kind() == ast.CallKind && maybePruneBranches(ctx, fold) {
continue
}
// Otherwise, assume all context is needed to evaluate the expression.
err := tryFold(ctx, a, fold)
if err != nil {
ctx.ReportErrorAtID(fold.ID(), "constant-folding evaluation failed: %v", err.Error())
return a
}
}
foldCount++
foldableExprs = ast.MatchDescendants(root, constantExprMatcher)
}
// Once all of the constants have been folded, try to run through the remaining comprehensions
// one last time. In this case, there's no guarantee they'll run, so we only update the
// target comprehension node with the literal value if the evaluation succeeds.
for _, compre := range ast.MatchDescendants(root, ast.KindMatcher(ast.ComprehensionKind)) {
tryFold(ctx, a, compre)
}
// If the output is a list, map, or struct which contains optional entries, then prune it
// to make sure that the optionals, if resolved, do not surface in the output literal.
pruneOptionalElements(ctx, root)
// Ensure that all intermediate values in the folded expression can be represented as valid
// CEL literals within the AST structure. Use `PostOrderVisit` rather than `MatchDescendents`
// to avoid extra allocations during this final pass through the AST.
ast.PostOrderVisit(root, ast.NewExprVisitor(func(e ast.Expr) {
if e.Kind() != ast.LiteralKind {
return
}
val := e.AsLiteral()
adapted, err := adaptLiteral(ctx, val)
if err != nil {
ctx.ReportErrorAtID(root.ID(), "constant-folding evaluation failed: %v", err.Error())
return
}
ctx.UpdateExpr(e, adapted)
}))
return a
}
// tryFold attempts to evaluate a sub-expression to a literal.
//
// If the evaluation succeeds, the input expr value will be modified to become a literal, otherwise
// the method will return an error.
func tryFold(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) error {
// Assume all context is needed to evaluate the expression.
subAST := &Ast{
impl: ast.NewCheckedAST(ast.NewAST(expr, a.SourceInfo()), a.TypeMap(), a.ReferenceMap()),
}
prg, err := ctx.Program(subAST)
if err != nil {
return err
}
out, _, err := prg.Eval(NoVars())
if err != nil {
return err
}
// Update the fold expression to be a literal.
ctx.UpdateExpr(expr, ctx.NewLiteral(out))
return nil
}
// maybePruneBranches inspects the non-strict call expression to determine whether
// a branch can be removed. Evaluation will naturally prune logical and / or calls,
// but conditional will not be pruned cleanly, so this is one small area where the
// constant folding step reimplements a portion of the evaluator.
func maybePruneBranches(ctx *OptimizerContext, expr ast.NavigableExpr) bool {
call := expr.AsCall()
args := call.Args()
switch call.FunctionName() {
case operators.LogicalAnd, operators.LogicalOr:
return maybeShortcircuitLogic(ctx, call.FunctionName(), args, expr)
case operators.Conditional:
cond := args[0]
truthy := args[1]
falsy := args[2]
if cond.Kind() != ast.LiteralKind {
return false
}
if cond.AsLiteral() == types.True {
ctx.UpdateExpr(expr, truthy)
} else {
ctx.UpdateExpr(expr, falsy)
}
return true
case operators.In:
haystack := args[1]
if haystack.Kind() == ast.ListKind && haystack.AsList().Size() == 0 {
ctx.UpdateExpr(expr, ctx.NewLiteral(types.False))
return true
}
needle := args[0]
if needle.Kind() == ast.LiteralKind && haystack.Kind() == ast.ListKind {
needleValue := needle.AsLiteral()
list := haystack.AsList()
for _, e := range list.Elements() {
if e.Kind() == ast.LiteralKind && e.AsLiteral().Equal(needleValue) == types.True {
ctx.UpdateExpr(expr, ctx.NewLiteral(types.True))
return true
}
}
}
}
return false
}
func maybeShortcircuitLogic(ctx *OptimizerContext, function string, args []ast.Expr, expr ast.NavigableExpr) bool {
shortcircuit := types.False
skip := types.True
if function == operators.LogicalOr {
shortcircuit = types.True
skip = types.False
}
newArgs := []ast.Expr{}
for _, arg := range args {
if arg.Kind() != ast.LiteralKind {
newArgs = append(newArgs, arg)
continue
}
if arg.AsLiteral() == skip {
continue
}
if arg.AsLiteral() == shortcircuit {
ctx.UpdateExpr(expr, arg)
return true
}
}
if len(newArgs) == 0 {
newArgs = append(newArgs, args[0])
ctx.UpdateExpr(expr, newArgs[0])
return true
}
if len(newArgs) == 1 {
ctx.UpdateExpr(expr, newArgs[0])
return true
}
ctx.UpdateExpr(expr, ctx.NewCall(function, newArgs...))
return true
}
// pruneOptionalElements works from the bottom up to resolve optional elements within
// aggregate literals.
//
// Note, many aggregate literals will be resolved as arguments to functions or select
// statements, so this method exists to handle the case where the literal could not be
// fully resolved or exists outside of a call, select, or comprehension context.
func pruneOptionalElements(ctx *OptimizerContext, root ast.NavigableExpr) {
aggregateLiterals := ast.MatchDescendants(root, aggregateLiteralMatcher)
for _, lit := range aggregateLiterals {
switch lit.Kind() {
case ast.ListKind:
pruneOptionalListElements(ctx, lit)
case ast.MapKind:
pruneOptionalMapEntries(ctx, lit)
case ast.StructKind:
pruneOptionalStructFields(ctx, lit)
}
}
}
func pruneOptionalListElements(ctx *OptimizerContext, e ast.Expr) {
l := e.AsList()
elems := l.Elements()
optIndices := l.OptionalIndices()
if len(optIndices) == 0 {
return
}
updatedElems := []ast.Expr{}
updatedIndices := []int32{}
newOptIndex := -1
for _, e := range elems {
newOptIndex++
if !l.IsOptional(int32(newOptIndex)) {
updatedElems = append(updatedElems, e)
continue
}
if e.Kind() != ast.LiteralKind {
updatedElems = append(updatedElems, e)
updatedIndices = append(updatedIndices, int32(newOptIndex))
continue
}
optElemVal, ok := e.AsLiteral().(*types.Optional)
if !ok {
updatedElems = append(updatedElems, e)
updatedIndices = append(updatedIndices, int32(newOptIndex))
continue
}
if !optElemVal.HasValue() {
newOptIndex-- // Skipping causes the list to get smaller.
continue
}
ctx.UpdateExpr(e, ctx.NewLiteral(optElemVal.GetValue()))
updatedElems = append(updatedElems, e)
}
ctx.UpdateExpr(e, ctx.NewList(updatedElems, updatedIndices))
}
func pruneOptionalMapEntries(ctx *OptimizerContext, e ast.Expr) {
m := e.AsMap()
entries := m.Entries()
updatedEntries := []ast.EntryExpr{}
modified := false
for _, e := range entries {
entry := e.AsMapEntry()
key := entry.Key()
val := entry.Value()
// If the entry is not optional, or the value-side of the optional hasn't
// been resolved to a literal, then preserve the entry as-is.
if !entry.IsOptional() || val.Kind() != ast.LiteralKind {
updatedEntries = append(updatedEntries, e)
continue
}
optElemVal, ok := val.AsLiteral().(*types.Optional)
if !ok {
updatedEntries = append(updatedEntries, e)
continue
}
// When the key is not a literal, but the value is, then it needs to be
// restored to an optional value.
if key.Kind() != ast.LiteralKind {
undoOptVal, err := adaptLiteral(ctx, optElemVal)
if err != nil {
ctx.ReportErrorAtID(val.ID(), "invalid map value literal %v: %v", optElemVal, err)
}
ctx.UpdateExpr(val, undoOptVal)
updatedEntries = append(updatedEntries, e)
continue
}
modified = true
if !optElemVal.HasValue() {
continue
}
ctx.UpdateExpr(val, ctx.NewLiteral(optElemVal.GetValue()))
updatedEntry := ctx.NewMapEntry(key, val, false)
updatedEntries = append(updatedEntries, updatedEntry)
}
if modified {
ctx.UpdateExpr(e, ctx.NewMap(updatedEntries))
}
}
func pruneOptionalStructFields(ctx *OptimizerContext, e ast.Expr) {
s := e.AsStruct()
fields := s.Fields()
updatedFields := []ast.EntryExpr{}
modified := false
for _, f := range fields {
field := f.AsStructField()
val := field.Value()
if !field.IsOptional() || val.Kind() != ast.LiteralKind {
updatedFields = append(updatedFields, f)
continue
}
optElemVal, ok := val.AsLiteral().(*types.Optional)
if !ok {
updatedFields = append(updatedFields, f)
continue
}
modified = true
if !optElemVal.HasValue() {
continue
}
ctx.UpdateExpr(val, ctx.NewLiteral(optElemVal.GetValue()))
updatedField := ctx.NewStructField(field.Name(), val, false)
updatedFields = append(updatedFields, updatedField)
}
if modified {
ctx.UpdateExpr(e, ctx.NewStruct(s.TypeName(), updatedFields))
}
}
// adaptLiteral converts a runtime CEL value to its equivalent literal expression.
//
// For strongly typed values, the type-provider will be used to reconstruct the fields
// which are present in the literal and their equivalent initialization values.
func adaptLiteral(ctx *OptimizerContext, val ref.Val) (ast.Expr, error) {
switch t := val.Type().(type) {
case *types.Type:
switch t {
case types.BoolType, types.BytesType, types.DoubleType, types.IntType,
types.NullType, types.StringType, types.UintType:
return ctx.NewLiteral(val), nil
case types.DurationType:
return ctx.NewCall(
overloads.TypeConvertDuration,
ctx.NewLiteral(val.ConvertToType(types.StringType)),
), nil
case types.TimestampType:
return ctx.NewCall(
overloads.TypeConvertTimestamp,
ctx.NewLiteral(val.ConvertToType(types.StringType)),
), nil
case types.OptionalType:
opt := val.(*types.Optional)
if !opt.HasValue() {
return ctx.NewCall("optional.none"), nil
}
target, err := adaptLiteral(ctx, opt.GetValue())
if err != nil {
return nil, err
}
return ctx.NewCall("optional.of", target), nil
case types.TypeType:
return ctx.NewIdent(val.(*types.Type).TypeName()), nil
case types.ListType:
l, ok := val.(traits.Lister)
if !ok {
return nil, fmt.Errorf("failed to adapt %v to literal", val)
}
elems := make([]ast.Expr, l.Size().(types.Int))
idx := 0
it := l.Iterator()
for it.HasNext() == types.True {
elemVal := it.Next()
elemExpr, err := adaptLiteral(ctx, elemVal)
if err != nil {
return nil, err
}
elems[idx] = elemExpr
idx++
}
return ctx.NewList(elems, []int32{}), nil
case types.MapType:
m, ok := val.(traits.Mapper)
if !ok {
return nil, fmt.Errorf("failed to adapt %v to literal", val)
}
entries := make([]ast.EntryExpr, m.Size().(types.Int))
idx := 0
it := m.Iterator()
for it.HasNext() == types.True {
keyVal := it.Next()
keyExpr, err := adaptLiteral(ctx, keyVal)
if err != nil {
return nil, err
}
valVal := m.Get(keyVal)
valExpr, err := adaptLiteral(ctx, valVal)
if err != nil {
return nil, err
}
entries[idx] = ctx.NewMapEntry(keyExpr, valExpr, false)
idx++
}
return ctx.NewMap(entries), nil
default:
provider := ctx.CELTypeProvider()
fields, found := provider.FindStructFieldNames(t.TypeName())
if !found {
return nil, fmt.Errorf("failed to adapt %v to literal", val)
}
tester := val.(traits.FieldTester)
indexer := val.(traits.Indexer)
fieldInits := []ast.EntryExpr{}
for _, f := range fields {
field := types.String(f)
if tester.IsSet(field) != types.True {
continue
}
fieldVal := indexer.Get(field)
fieldExpr, err := adaptLiteral(ctx, fieldVal)
if err != nil {
return nil, err
}
fieldInits = append(fieldInits, ctx.NewStructField(f, fieldExpr, false))
}
return ctx.NewStruct(t.TypeName(), fieldInits), nil
}
}
return nil, fmt.Errorf("failed to adapt %v to literal", val)
}
// constantExprMatcher matches calls, select statements, and comprehensions whose arguments
// are all constant scalar or aggregate literal values.
//
// Only comprehensions which are not nested are included as possible constant folds, and only
// if all variables referenced in the comprehension stack exist are only iteration or
// accumulation variables.
func constantExprMatcher(e ast.NavigableExpr) bool {
switch e.Kind() {
case ast.CallKind:
return constantCallMatcher(e)
case ast.SelectKind:
sel := e.AsSelect() // guaranteed to be a navigable value
return constantMatcher(sel.Operand().(ast.NavigableExpr))
case ast.ComprehensionKind:
if isNestedComprehension(e) {
return false
}
vars := map[string]bool{}
constantExprs := true
visitor := ast.NewExprVisitor(func(e ast.Expr) {
if e.Kind() == ast.ComprehensionKind {
nested := e.AsComprehension()
vars[nested.AccuVar()] = true
vars[nested.IterVar()] = true
}
if e.Kind() == ast.IdentKind && !vars[e.AsIdent()] {
constantExprs = false
}
})
ast.PreOrderVisit(e, visitor)
return constantExprs
default:
return false
}
}
// constantCallMatcher identifies strict and non-strict calls which can be folded.
func constantCallMatcher(e ast.NavigableExpr) bool {
call := e.AsCall()
children := e.Children()
fnName := call.FunctionName()
if fnName == operators.LogicalAnd {
for _, child := range children {
if child.Kind() == ast.LiteralKind {
return true
}
}
}
if fnName == operators.LogicalOr {
for _, child := range children {
if child.Kind() == ast.LiteralKind {
return true
}
}
}
if fnName == operators.Conditional {
cond := children[0]
if cond.Kind() == ast.LiteralKind && cond.AsLiteral().Type() == types.BoolType {
return true
}
}
if fnName == operators.In {
haystack := children[1]
if haystack.Kind() == ast.ListKind && haystack.AsList().Size() == 0 {
return true
}
needle := children[0]
if needle.Kind() == ast.LiteralKind && haystack.Kind() == ast.ListKind {
needleValue := needle.AsLiteral()
list := haystack.AsList()
for _, e := range list.Elements() {
if e.Kind() == ast.LiteralKind && e.AsLiteral().Equal(needleValue) == types.True {
return true
}
}
}
}
// convert all other calls with constant arguments
for _, child := range children {
if !constantMatcher(child) {
return false
}
}
return true
}
func isNestedComprehension(e ast.NavigableExpr) bool {
parent, found := e.Parent()
for found {
if parent.Kind() == ast.ComprehensionKind {
return true
}
parent, found = parent.Parent()
}
return false
}
func aggregateLiteralMatcher(e ast.NavigableExpr) bool {
return e.Kind() == ast.ListKind || e.Kind() == ast.MapKind || e.Kind() == ast.StructKind
}
var (
constantMatcher = ast.ConstantValueMatcher()
)
const (
defaultMaxConstantFoldIterations = 100
)

228
vendor/github.com/google/cel-go/cel/inlining.go generated vendored Normal file
View File

@ -0,0 +1,228 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cel
import (
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/traits"
)
// InlineVariable holds a variable name to be matched and an AST representing
// the expression graph which should be used to replace it.
type InlineVariable struct {
name string
alias string
def *ast.AST
}
// Name returns the qualified variable or field selection to replace.
func (v *InlineVariable) Name() string {
return v.name
}
// Alias returns the alias to use when performing cel.bind() calls during inlining.
func (v *InlineVariable) Alias() string {
return v.alias
}
// Expr returns the inlined expression value.
func (v *InlineVariable) Expr() ast.Expr {
return v.def.Expr()
}
// Type indicates the inlined expression type.
func (v *InlineVariable) Type() *Type {
return v.def.GetType(v.def.Expr().ID())
}
// NewInlineVariable declares a variable name to be replaced by a checked expression.
func NewInlineVariable(name string, definition *Ast) *InlineVariable {
return NewInlineVariableWithAlias(name, name, definition)
}
// NewInlineVariableWithAlias declares a variable name to be replaced by a checked expression.
// If the variable occurs more than once, the provided alias will be used to replace the expressions
// where the variable name occurs.
func NewInlineVariableWithAlias(name, alias string, definition *Ast) *InlineVariable {
return &InlineVariable{name: name, alias: alias, def: definition.impl}
}
// NewInliningOptimizer creates and optimizer which replaces variables with expression definitions.
//
// If a variable occurs one time, the variable is replaced by the inline definition. If the
// variable occurs more than once, the variable occurences are replaced by a cel.bind() call.
func NewInliningOptimizer(inlineVars ...*InlineVariable) ASTOptimizer {
return &inliningOptimizer{variables: inlineVars}
}
type inliningOptimizer struct {
variables []*InlineVariable
}
func (opt *inliningOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.AST {
root := ast.NavigateAST(a)
for _, inlineVar := range opt.variables {
matches := ast.MatchDescendants(root, opt.matchVariable(inlineVar.Name()))
// Skip cases where the variable isn't in the expression graph
if len(matches) == 0 {
continue
}
// For a single match, do a direct replacement of the expression sub-graph.
if len(matches) == 1 || !isBindable(matches, inlineVar.Expr(), inlineVar.Type()) {
for _, match := range matches {
// Copy the inlined AST expr and source info.
copyExpr := ctx.CopyASTAndMetadata(inlineVar.def)
opt.inlineExpr(ctx, match, copyExpr, inlineVar.Type())
}
continue
}
// For multiple matches, find the least common ancestor (lca) and insert the
// variable as a cel.bind() macro.
var lca ast.NavigableExpr = root
lcaAncestorCount := 0
ancestors := map[int64]int{}
for _, match := range matches {
// Update the identifier matches with the provided alias.
parent, found := match, true
for found {
ancestorCount, hasAncestor := ancestors[parent.ID()]
if !hasAncestor {
ancestors[parent.ID()] = 1
parent, found = parent.Parent()
continue
}
if lcaAncestorCount < ancestorCount || (lcaAncestorCount == ancestorCount && lca.Depth() < parent.Depth()) {
lca = parent
lcaAncestorCount = ancestorCount
}
ancestors[parent.ID()] = ancestorCount + 1
parent, found = parent.Parent()
}
aliasExpr := ctx.NewIdent(inlineVar.Alias())
opt.inlineExpr(ctx, match, aliasExpr, inlineVar.Type())
}
// Copy the inlined AST expr and source info.
copyExpr := ctx.CopyASTAndMetadata(inlineVar.def)
// Update the least common ancestor by inserting a cel.bind() call to the alias.
inlined, bindMacro := ctx.NewBindMacro(lca.ID(), inlineVar.Alias(), copyExpr, lca)
opt.inlineExpr(ctx, lca, inlined, inlineVar.Type())
ctx.SetMacroCall(lca.ID(), bindMacro)
}
return a
}
// inlineExpr replaces the current expression with the inlined one, unless the location of the inlining
// happens within a presence test, e.g. has(a.b.c) -> inline alpha for a.b.c in which case an attempt is
// made to determine whether the inlined value can be presence or existence tested.
func (opt *inliningOptimizer) inlineExpr(ctx *OptimizerContext, prev ast.NavigableExpr, inlined ast.Expr, inlinedType *Type) {
switch prev.Kind() {
case ast.SelectKind:
sel := prev.AsSelect()
if !sel.IsTestOnly() {
ctx.UpdateExpr(prev, inlined)
return
}
opt.rewritePresenceExpr(ctx, prev, inlined, inlinedType)
default:
ctx.UpdateExpr(prev, inlined)
}
}
// rewritePresenceExpr converts the inlined expression, when it occurs within a has() macro, to type-safe
// expression appropriate for the inlined type, if possible.
//
// If the rewrite is not possible an error is reported at the inline expression site.
func (opt *inliningOptimizer) rewritePresenceExpr(ctx *OptimizerContext, prev, inlined ast.Expr, inlinedType *Type) {
// If the input inlined expression is not a select expression it won't work with the has()
// macro. Attempt to rewrite the presence test in terms of the typed input, otherwise error.
if inlined.Kind() == ast.SelectKind {
presenceTest, hasMacro := ctx.NewHasMacro(prev.ID(), inlined)
ctx.UpdateExpr(prev, presenceTest)
ctx.SetMacroCall(prev.ID(), hasMacro)
return
}
ctx.ClearMacroCall(prev.ID())
if inlinedType.IsAssignableType(NullType) {
ctx.UpdateExpr(prev,
ctx.NewCall(operators.NotEquals,
inlined,
ctx.NewLiteral(types.NullValue),
))
return
}
if inlinedType.HasTrait(traits.SizerType) {
ctx.UpdateExpr(prev,
ctx.NewCall(operators.NotEquals,
ctx.NewMemberCall(overloads.Size, inlined),
ctx.NewLiteral(types.IntZero),
))
return
}
ctx.ReportErrorAtID(prev.ID(), "unable to inline expression type %v into presence test", inlinedType)
}
// isBindable indicates whether the inlined type can be used within a cel.bind() if the expression
// being replaced occurs within a presence test. Value types with a size() method or field selection
// support can be bound.
//
// In future iterations, support may also be added for indexer types which can be rewritten as an `in`
// expression; however, this would imply a rewrite of the inlined expression that may not be necessary
// in most cases.
func isBindable(matches []ast.NavigableExpr, inlined ast.Expr, inlinedType *Type) bool {
if inlinedType.IsAssignableType(NullType) ||
inlinedType.HasTrait(traits.SizerType) {
return true
}
for _, m := range matches {
if m.Kind() != ast.SelectKind {
continue
}
sel := m.AsSelect()
if sel.IsTestOnly() {
return false
}
}
return true
}
// matchVariable matches simple identifiers, select expressions, and presence test expressions
// which match the (potentially) qualified variable name provided as input.
//
// Note, this function does not support inlining against select expressions which includes optional
// field selection. This may be a future refinement.
func (opt *inliningOptimizer) matchVariable(varName string) ast.ExprMatcher {
return func(e ast.NavigableExpr) bool {
if e.Kind() == ast.IdentKind && e.AsIdent() == varName {
return true
}
if e.Kind() == ast.SelectKind {
sel := e.AsSelect()
// While the `ToQualifiedName` call could take the select directly, this
// would skip presence tests from possible matches, which we would like
// to include.
qualName, found := containers.ToQualifiedName(sel.Operand())
return found && qualName+"."+sel.FieldName() == varName
}
return false
}
}

View File

@ -47,17 +47,11 @@ func CheckedExprToAst(checkedExpr *exprpb.CheckedExpr) *Ast {
//
// Prefer CheckedExprToAst if loading expressions from storage.
func CheckedExprToAstWithSource(checkedExpr *exprpb.CheckedExpr, src Source) (*Ast, error) {
checkedAST, err := ast.CheckedExprToCheckedAST(checkedExpr)
checked, err := ast.ToAST(checkedExpr)
if err != nil {
return nil, err
}
return &Ast{
expr: checkedAST.Expr,
info: checkedAST.SourceInfo,
source: src,
refMap: checkedAST.ReferenceMap,
typeMap: checkedAST.TypeMap,
}, nil
return &Ast{source: src, impl: checked}, nil
}
// AstToCheckedExpr converts an Ast to an protobuf CheckedExpr value.
@ -67,13 +61,7 @@ func AstToCheckedExpr(a *Ast) (*exprpb.CheckedExpr, error) {
if !a.IsChecked() {
return nil, fmt.Errorf("cannot convert unchecked ast")
}
cAst := &ast.CheckedAST{
Expr: a.expr,
SourceInfo: a.info,
ReferenceMap: a.refMap,
TypeMap: a.typeMap,
}
return ast.CheckedASTToCheckedExpr(cAst)
return ast.ToProto(a.impl)
}
// ParsedExprToAst converts a parsed expression proto message to an Ast.
@ -89,18 +77,12 @@ func ParsedExprToAst(parsedExpr *exprpb.ParsedExpr) *Ast {
//
// Prefer ParsedExprToAst if loading expressions from storage.
func ParsedExprToAstWithSource(parsedExpr *exprpb.ParsedExpr, src Source) *Ast {
si := parsedExpr.GetSourceInfo()
if si == nil {
si = &exprpb.SourceInfo{}
}
info, _ := ast.ProtoToSourceInfo(parsedExpr.GetSourceInfo())
if src == nil {
src = common.NewInfoSource(si)
}
return &Ast{
expr: parsedExpr.GetExpr(),
info: si,
source: src,
src = common.NewInfoSource(parsedExpr.GetSourceInfo())
}
e, _ := ast.ProtoToExpr(parsedExpr.GetExpr())
return &Ast{source: src, impl: ast.NewAST(e, info)}
}
// AstToParsedExpr converts an Ast to an protobuf ParsedExpr value.
@ -116,9 +98,7 @@ func AstToParsedExpr(a *Ast) (*exprpb.ParsedExpr, error) {
// Note, the conversion may not be an exact replica of the original expression, but will produce
// a string that is semantically equivalent and whose textual representation is stable.
func AstToString(a *Ast) (string, error) {
expr := a.Expr()
info := a.SourceInfo()
return parser.Unparse(expr, info)
return parser.Unparse(a.impl.Expr(), a.impl.SourceInfo())
}
// RefValueToValue converts between ref.Val and api.expr.Value.

View File

@ -20,6 +20,7 @@ import (
"strings"
"time"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/stdlib"
@ -28,8 +29,6 @@ import (
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/interpreter"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
const (
@ -313,7 +312,7 @@ func (lib *optionalLib) CompileOptions() []EnvOption {
Types(types.OptionalType),
// Configure the optMap and optFlatMap macros.
Macros(NewReceiverMacro(optMapMacro, 2, optMap)),
Macros(ReceiverMacro(optMapMacro, 2, optMap)),
// Global and member functions for working with optional values.
Function(optionalOfFunc,
@ -374,7 +373,7 @@ func (lib *optionalLib) CompileOptions() []EnvOption {
Overload("optional_map_index_value", []*Type{OptionalType(mapTypeKV), paramTypeK}, optionalTypeV)),
}
if lib.version >= 1 {
opts = append(opts, Macros(NewReceiverMacro(optFlatMapMacro, 2, optFlatMap)))
opts = append(opts, Macros(ReceiverMacro(optFlatMapMacro, 2, optFlatMap)))
}
return opts
}
@ -386,57 +385,57 @@ func (lib *optionalLib) ProgramOptions() []ProgramOption {
}
}
func optMap(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
func optMap(meh MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *Error) {
varIdent := args[0]
varName := ""
switch varIdent.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
varName = varIdent.GetIdentExpr().GetName()
switch varIdent.Kind() {
case ast.IdentKind:
varName = varIdent.AsIdent()
default:
return nil, meh.NewError(varIdent.GetId(), "optMap() variable name must be a simple identifier")
return nil, meh.NewError(varIdent.ID(), "optMap() variable name must be a simple identifier")
}
mapExpr := args[1]
return meh.GlobalCall(
return meh.NewCall(
operators.Conditional,
meh.ReceiverCall(hasValueFunc, target),
meh.GlobalCall(optionalOfFunc,
meh.Fold(
unusedIterVar,
meh.NewMemberCall(hasValueFunc, target),
meh.NewCall(optionalOfFunc,
meh.NewComprehension(
meh.NewList(),
unusedIterVar,
varName,
meh.ReceiverCall(valueFunc, target),
meh.LiteralBool(false),
meh.Ident(varName),
meh.NewMemberCall(valueFunc, target),
meh.NewLiteral(types.False),
meh.NewIdent(varName),
mapExpr,
),
),
meh.GlobalCall(optionalNoneFunc),
meh.NewCall(optionalNoneFunc),
), nil
}
func optFlatMap(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
func optFlatMap(meh MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *Error) {
varIdent := args[0]
varName := ""
switch varIdent.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
varName = varIdent.GetIdentExpr().GetName()
switch varIdent.Kind() {
case ast.IdentKind:
varName = varIdent.AsIdent()
default:
return nil, meh.NewError(varIdent.GetId(), "optFlatMap() variable name must be a simple identifier")
return nil, meh.NewError(varIdent.ID(), "optFlatMap() variable name must be a simple identifier")
}
mapExpr := args[1]
return meh.GlobalCall(
return meh.NewCall(
operators.Conditional,
meh.ReceiverCall(hasValueFunc, target),
meh.Fold(
unusedIterVar,
meh.NewMemberCall(hasValueFunc, target),
meh.NewComprehension(
meh.NewList(),
unusedIterVar,
varName,
meh.ReceiverCall(valueFunc, target),
meh.LiteralBool(false),
meh.Ident(varName),
meh.NewMemberCall(valueFunc, target),
meh.NewLiteral(types.False),
meh.NewIdent(varName),
mapExpr,
),
meh.GlobalCall(optionalNoneFunc),
meh.NewCall(optionalNoneFunc),
), nil
}

View File

@ -15,6 +15,11 @@
package cel
import (
"fmt"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
@ -26,7 +31,14 @@ import (
// a Macro should be created per arg-count or as a var arg macro.
type Macro = parser.Macro
// MacroExpander converts a call and its associated arguments into a new CEL abstract syntax tree.
// MacroFactory defines an expansion function which converts a call and its arguments to a cel.Expr value.
type MacroFactory = parser.MacroExpander
// MacroExprFactory assists with the creation of Expr values in a manner which is consistent
// the internal semantics and id generation behaviors of the parser and checker libraries.
type MacroExprFactory = parser.ExprHelper
// MacroExpander converts a call and its associated arguments into a protobuf Expr representation.
//
// If the MacroExpander determines within the implementation that an expansion is not needed it may return
// a nil Expr value to indicate a non-match. However, if an expansion is to be performed, but the arguments
@ -36,48 +48,197 @@ type Macro = parser.Macro
// and produces as output an Expr ast node.
//
// Note: when the Macro.IsReceiverStyle() method returns true, the target argument will be nil.
type MacroExpander = parser.MacroExpander
type MacroExpander func(eh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error)
// MacroExprHelper exposes helper methods for creating new expressions within a CEL abstract syntax tree.
type MacroExprHelper = parser.ExprHelper
// ExprHelper assists with the manipulation of proto-based Expr values in a manner which is
// consistent with the source position and expression id generation code leveraged by both
// the parser and type-checker.
type MacroExprHelper interface {
// Copy the input expression with a brand new set of identifiers.
Copy(*exprpb.Expr) *exprpb.Expr
// LiteralBool creates an Expr value for a bool literal.
LiteralBool(value bool) *exprpb.Expr
// LiteralBytes creates an Expr value for a byte literal.
LiteralBytes(value []byte) *exprpb.Expr
// LiteralDouble creates an Expr value for double literal.
LiteralDouble(value float64) *exprpb.Expr
// LiteralInt creates an Expr value for an int literal.
LiteralInt(value int64) *exprpb.Expr
// LiteralString creates am Expr value for a string literal.
LiteralString(value string) *exprpb.Expr
// LiteralUint creates an Expr value for a uint literal.
LiteralUint(value uint64) *exprpb.Expr
// NewList creates a CreateList instruction where the list is comprised of the optional set
// of elements provided as arguments.
NewList(elems ...*exprpb.Expr) *exprpb.Expr
// NewMap creates a CreateStruct instruction for a map where the map is comprised of the
// optional set of key, value entries.
NewMap(entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr
// NewMapEntry creates a Map Entry for the key, value pair.
NewMapEntry(key *exprpb.Expr, val *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry
// NewObject creates a CreateStruct instruction for an object with a given type name and
// optional set of field initializers.
NewObject(typeName string, fieldInits ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr
// NewObjectFieldInit creates a new Object field initializer from the field name and value.
NewObjectFieldInit(field string, init *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry
// Fold creates a fold comprehension instruction.
//
// - iterVar is the iteration variable name.
// - iterRange represents the expression that resolves to a list or map where the elements or
// keys (respectively) will be iterated over.
// - accuVar is the accumulation variable name, typically parser.AccumulatorName.
// - accuInit is the initial expression whose value will be set for the accuVar prior to
// folding.
// - condition is the expression to test to determine whether to continue folding.
// - step is the expression to evaluation at the conclusion of a single fold iteration.
// - result is the computation to evaluate at the conclusion of the fold.
//
// The accuVar should not shadow variable names that you would like to reference within the
// environment in the step and condition expressions. Presently, the name __result__ is commonly
// used by built-in macros but this may change in the future.
Fold(iterVar string,
iterRange *exprpb.Expr,
accuVar string,
accuInit *exprpb.Expr,
condition *exprpb.Expr,
step *exprpb.Expr,
result *exprpb.Expr) *exprpb.Expr
// Ident creates an identifier Expr value.
Ident(name string) *exprpb.Expr
// AccuIdent returns an accumulator identifier for use with comprehension results.
AccuIdent() *exprpb.Expr
// GlobalCall creates a function call Expr value for a global (free) function.
GlobalCall(function string, args ...*exprpb.Expr) *exprpb.Expr
// ReceiverCall creates a function call Expr value for a receiver-style function.
ReceiverCall(function string, target *exprpb.Expr, args ...*exprpb.Expr) *exprpb.Expr
// PresenceTest creates a Select TestOnly Expr value for modelling has() semantics.
PresenceTest(operand *exprpb.Expr, field string) *exprpb.Expr
// Select create a field traversal Expr value.
Select(operand *exprpb.Expr, field string) *exprpb.Expr
// OffsetLocation returns the Location of the expression identifier.
OffsetLocation(exprID int64) common.Location
// NewError associates an error message with a given expression id.
NewError(exprID int64, message string) *Error
}
// GlobalMacro creates a Macro for a global function with the specified arg count.
func GlobalMacro(function string, argCount int, factory MacroFactory) Macro {
return parser.NewGlobalMacro(function, argCount, factory)
}
// ReceiverMacro creates a Macro for a receiver function matching the specified arg count.
func ReceiverMacro(function string, argCount int, factory MacroFactory) Macro {
return parser.NewReceiverMacro(function, argCount, factory)
}
// GlobalVarArgMacro creates a Macro for a global function with a variable arg count.
func GlobalVarArgMacro(function string, factory MacroFactory) Macro {
return parser.NewGlobalVarArgMacro(function, factory)
}
// ReceiverVarArgMacro creates a Macro for a receiver function matching a variable arg count.
func ReceiverVarArgMacro(function string, factory MacroFactory) Macro {
return parser.NewReceiverVarArgMacro(function, factory)
}
// NewGlobalMacro creates a Macro for a global function with the specified arg count.
//
// Deprecated: use GlobalMacro
func NewGlobalMacro(function string, argCount int, expander MacroExpander) Macro {
return parser.NewGlobalMacro(function, argCount, expander)
expand := adaptingExpander{expander}
return parser.NewGlobalMacro(function, argCount, expand.Expander)
}
// NewReceiverMacro creates a Macro for a receiver function matching the specified arg count.
//
// Deprecated: use ReceiverMacro
func NewReceiverMacro(function string, argCount int, expander MacroExpander) Macro {
return parser.NewReceiverMacro(function, argCount, expander)
expand := adaptingExpander{expander}
return parser.NewReceiverMacro(function, argCount, expand.Expander)
}
// NewGlobalVarArgMacro creates a Macro for a global function with a variable arg count.
//
// Deprecated: use GlobalVarArgMacro
func NewGlobalVarArgMacro(function string, expander MacroExpander) Macro {
return parser.NewGlobalVarArgMacro(function, expander)
expand := adaptingExpander{expander}
return parser.NewGlobalVarArgMacro(function, expand.Expander)
}
// NewReceiverVarArgMacro creates a Macro for a receiver function matching a variable arg count.
//
// Deprecated: use ReceiverVarArgMacro
func NewReceiverVarArgMacro(function string, expander MacroExpander) Macro {
return parser.NewReceiverVarArgMacro(function, expander)
expand := adaptingExpander{expander}
return parser.NewReceiverVarArgMacro(function, expand.Expander)
}
// HasMacroExpander expands the input call arguments into a presence test, e.g. has(<operand>.field)
func HasMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return parser.MakeHas(meh, target, args)
ph, err := toParserHelper(meh)
if err != nil {
return nil, err
}
arg, err := adaptToExpr(args[0])
if err != nil {
return nil, err
}
if arg.Kind() == ast.SelectKind {
s := arg.AsSelect()
return adaptToProto(ph.NewPresenceTest(s.Operand(), s.FieldName()))
}
return nil, ph.NewError(arg.ID(), "invalid argument to has() macro")
}
// ExistsMacroExpander expands the input call arguments into a comprehension that returns true if any of the
// elements in the range match the predicate expressions:
// <iterRange>.exists(<iterVar>, <predicate>)
func ExistsMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return parser.MakeExists(meh, target, args)
ph, err := toParserHelper(meh)
if err != nil {
return nil, err
}
out, err := parser.MakeExists(ph, mustAdaptToExpr(target), mustAdaptToExprs(args))
if err != nil {
return nil, err
}
return adaptToProto(out)
}
// ExistsOneMacroExpander expands the input call arguments into a comprehension that returns true if exactly
// one of the elements in the range match the predicate expressions:
// <iterRange>.exists_one(<iterVar>, <predicate>)
func ExistsOneMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return parser.MakeExistsOne(meh, target, args)
ph, err := toParserHelper(meh)
if err != nil {
return nil, err
}
out, err := parser.MakeExistsOne(ph, mustAdaptToExpr(target), mustAdaptToExprs(args))
if err != nil {
return nil, err
}
return adaptToProto(out)
}
// MapMacroExpander expands the input call arguments into a comprehension that transforms each element in the
@ -91,14 +252,30 @@ func ExistsOneMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*ex
// In the second form only iterVar values which return true when provided to the predicate expression
// are transformed.
func MapMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return parser.MakeMap(meh, target, args)
ph, err := toParserHelper(meh)
if err != nil {
return nil, err
}
out, err := parser.MakeMap(ph, mustAdaptToExpr(target), mustAdaptToExprs(args))
if err != nil {
return nil, err
}
return adaptToProto(out)
}
// FilterMacroExpander expands the input call arguments into a comprehension which produces a list which contains
// only elements which match the provided predicate expression:
// <iterRange>.filter(<iterVar>, <predicate>)
func FilterMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return parser.MakeFilter(meh, target, args)
ph, err := toParserHelper(meh)
if err != nil {
return nil, err
}
out, err := parser.MakeFilter(ph, mustAdaptToExpr(target), mustAdaptToExprs(args))
if err != nil {
return nil, err
}
return adaptToProto(out)
}
var (
@ -142,3 +319,258 @@ var (
// NoMacros provides an alias to an empty list of macros
NoMacros = []Macro{}
)
type adaptingExpander struct {
legacyExpander MacroExpander
}
func (adapt *adaptingExpander) Expander(eh parser.ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) {
var legacyTarget *exprpb.Expr = nil
var err *Error = nil
if target != nil {
legacyTarget, err = adaptToProto(target)
if err != nil {
return nil, err
}
}
legacyArgs := make([]*exprpb.Expr, len(args))
for i, arg := range args {
legacyArgs[i], err = adaptToProto(arg)
if err != nil {
return nil, err
}
}
ah := &adaptingHelper{modernHelper: eh}
legacyExpr, err := adapt.legacyExpander(ah, legacyTarget, legacyArgs)
if err != nil {
return nil, err
}
ex, err := adaptToExpr(legacyExpr)
if err != nil {
return nil, err
}
return ex, nil
}
func wrapErr(id int64, message string, err error) *common.Error {
return &common.Error{
Location: common.NoLocation,
Message: fmt.Sprintf("%s: %v", message, err),
ExprID: id,
}
}
type adaptingHelper struct {
modernHelper parser.ExprHelper
}
// Copy the input expression with a brand new set of identifiers.
func (ah *adaptingHelper) Copy(e *exprpb.Expr) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.Copy(mustAdaptToExpr(e)))
}
// LiteralBool creates an Expr value for a bool literal.
func (ah *adaptingHelper) LiteralBool(value bool) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Bool(value)))
}
// LiteralBytes creates an Expr value for a byte literal.
func (ah *adaptingHelper) LiteralBytes(value []byte) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Bytes(value)))
}
// LiteralDouble creates an Expr value for double literal.
func (ah *adaptingHelper) LiteralDouble(value float64) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Double(value)))
}
// LiteralInt creates an Expr value for an int literal.
func (ah *adaptingHelper) LiteralInt(value int64) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Int(value)))
}
// LiteralString creates am Expr value for a string literal.
func (ah *adaptingHelper) LiteralString(value string) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewLiteral(types.String(value)))
}
// LiteralUint creates an Expr value for a uint literal.
func (ah *adaptingHelper) LiteralUint(value uint64) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Uint(value)))
}
// NewList creates a CreateList instruction where the list is comprised of the optional set
// of elements provided as arguments.
func (ah *adaptingHelper) NewList(elems ...*exprpb.Expr) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewList(mustAdaptToExprs(elems)...))
}
// NewMap creates a CreateStruct instruction for a map where the map is comprised of the
// optional set of key, value entries.
func (ah *adaptingHelper) NewMap(entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr {
adaptedEntries := make([]ast.EntryExpr, len(entries))
for i, e := range entries {
adaptedEntries[i] = mustAdaptToEntryExpr(e)
}
return mustAdaptToProto(ah.modernHelper.NewMap(adaptedEntries...))
}
// NewMapEntry creates a Map Entry for the key, value pair.
func (ah *adaptingHelper) NewMapEntry(key *exprpb.Expr, val *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry {
return mustAdaptToProtoEntry(
ah.modernHelper.NewMapEntry(mustAdaptToExpr(key), mustAdaptToExpr(val), optional))
}
// NewObject creates a CreateStruct instruction for an object with a given type name and
// optional set of field initializers.
func (ah *adaptingHelper) NewObject(typeName string, fieldInits ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr {
adaptedEntries := make([]ast.EntryExpr, len(fieldInits))
for i, e := range fieldInits {
adaptedEntries[i] = mustAdaptToEntryExpr(e)
}
return mustAdaptToProto(ah.modernHelper.NewStruct(typeName, adaptedEntries...))
}
// NewObjectFieldInit creates a new Object field initializer from the field name and value.
func (ah *adaptingHelper) NewObjectFieldInit(field string, init *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry {
return mustAdaptToProtoEntry(
ah.modernHelper.NewStructField(field, mustAdaptToExpr(init), optional))
}
// Fold creates a fold comprehension instruction.
//
// - iterVar is the iteration variable name.
// - iterRange represents the expression that resolves to a list or map where the elements or
// keys (respectively) will be iterated over.
// - accuVar is the accumulation variable name, typically parser.AccumulatorName.
// - accuInit is the initial expression whose value will be set for the accuVar prior to
// folding.
// - condition is the expression to test to determine whether to continue folding.
// - step is the expression to evaluation at the conclusion of a single fold iteration.
// - result is the computation to evaluate at the conclusion of the fold.
//
// The accuVar should not shadow variable names that you would like to reference within the
// environment in the step and condition expressions. Presently, the name __result__ is commonly
// used by built-in macros but this may change in the future.
func (ah *adaptingHelper) Fold(iterVar string,
iterRange *exprpb.Expr,
accuVar string,
accuInit *exprpb.Expr,
condition *exprpb.Expr,
step *exprpb.Expr,
result *exprpb.Expr) *exprpb.Expr {
return mustAdaptToProto(
ah.modernHelper.NewComprehension(
mustAdaptToExpr(iterRange),
iterVar,
accuVar,
mustAdaptToExpr(accuInit),
mustAdaptToExpr(condition),
mustAdaptToExpr(step),
mustAdaptToExpr(result),
),
)
}
// Ident creates an identifier Expr value.
func (ah *adaptingHelper) Ident(name string) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewIdent(name))
}
// AccuIdent returns an accumulator identifier for use with comprehension results.
func (ah *adaptingHelper) AccuIdent() *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewAccuIdent())
}
// GlobalCall creates a function call Expr value for a global (free) function.
func (ah *adaptingHelper) GlobalCall(function string, args ...*exprpb.Expr) *exprpb.Expr {
return mustAdaptToProto(ah.modernHelper.NewCall(function, mustAdaptToExprs(args)...))
}
// ReceiverCall creates a function call Expr value for a receiver-style function.
func (ah *adaptingHelper) ReceiverCall(function string, target *exprpb.Expr, args ...*exprpb.Expr) *exprpb.Expr {
return mustAdaptToProto(
ah.modernHelper.NewMemberCall(function, mustAdaptToExpr(target), mustAdaptToExprs(args)...))
}
// PresenceTest creates a Select TestOnly Expr value for modelling has() semantics.
func (ah *adaptingHelper) PresenceTest(operand *exprpb.Expr, field string) *exprpb.Expr {
op := mustAdaptToExpr(operand)
return mustAdaptToProto(ah.modernHelper.NewPresenceTest(op, field))
}
// Select create a field traversal Expr value.
func (ah *adaptingHelper) Select(operand *exprpb.Expr, field string) *exprpb.Expr {
op := mustAdaptToExpr(operand)
return mustAdaptToProto(ah.modernHelper.NewSelect(op, field))
}
// OffsetLocation returns the Location of the expression identifier.
func (ah *adaptingHelper) OffsetLocation(exprID int64) common.Location {
return ah.modernHelper.OffsetLocation(exprID)
}
// NewError associates an error message with a given expression id.
func (ah *adaptingHelper) NewError(exprID int64, message string) *Error {
return ah.modernHelper.NewError(exprID, message)
}
func mustAdaptToExprs(exprs []*exprpb.Expr) []ast.Expr {
adapted := make([]ast.Expr, len(exprs))
for i, e := range exprs {
adapted[i] = mustAdaptToExpr(e)
}
return adapted
}
func mustAdaptToExpr(e *exprpb.Expr) ast.Expr {
out, _ := adaptToExpr(e)
return out
}
func adaptToExpr(e *exprpb.Expr) (ast.Expr, *Error) {
if e == nil {
return nil, nil
}
out, err := ast.ProtoToExpr(e)
if err != nil {
return nil, wrapErr(e.GetId(), "proto conversion failure", err)
}
return out, nil
}
func mustAdaptToEntryExpr(e *exprpb.Expr_CreateStruct_Entry) ast.EntryExpr {
out, _ := ast.ProtoToEntryExpr(e)
return out
}
func mustAdaptToProto(e ast.Expr) *exprpb.Expr {
out, _ := adaptToProto(e)
return out
}
func adaptToProto(e ast.Expr) (*exprpb.Expr, *Error) {
if e == nil {
return nil, nil
}
out, err := ast.ExprToProto(e)
if err != nil {
return nil, wrapErr(e.ID(), "expr conversion failure", err)
}
return out, nil
}
func mustAdaptToProtoEntry(e ast.EntryExpr) *exprpb.Expr_CreateStruct_Entry {
out, _ := ast.EntryExprToProto(e)
return out
}
func toParserHelper(meh MacroExprHelper) (parser.ExprHelper, *Error) {
ah, ok := meh.(*adaptingHelper)
if !ok {
return nil, common.NewError(0,
fmt.Sprintf("unsupported macro helper: %v (%T)", meh, meh),
common.NoLocation)
}
return ah.modernHelper, nil
}

509
vendor/github.com/google/cel-go/cel/optimizer.go generated vendored Normal file
View File

@ -0,0 +1,509 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cel
import (
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)
// StaticOptimizer contains a sequence of ASTOptimizer instances which will be applied in order.
//
// The static optimizer normalizes expression ids and type-checking run between optimization
// passes to ensure that the final optimized output is a valid expression with metadata consistent
// with what would have been generated from a parsed and checked expression.
//
// Note: source position information is best-effort and likely wrong, but optimized expressions
// should be suitable for calls to parser.Unparse.
type StaticOptimizer struct {
optimizers []ASTOptimizer
}
// NewStaticOptimizer creates a StaticOptimizer with a sequence of ASTOptimizer's to be applied
// to a checked expression.
func NewStaticOptimizer(optimizers ...ASTOptimizer) *StaticOptimizer {
return &StaticOptimizer{
optimizers: optimizers,
}
}
// Optimize applies a sequence of optimizations to an Ast within a given environment.
//
// If issues are encountered, the Issues.Err() return value will be non-nil.
func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) {
// Make a copy of the AST to be optimized.
optimized := ast.Copy(a.impl)
ids := newIDGenerator(ast.MaxID(a.impl))
// Create the optimizer context, could be pooled in the future.
issues := NewIssues(common.NewErrors(a.Source()))
baseFac := ast.NewExprFactory()
exprFac := &optimizerExprFactory{
idGenerator: ids,
fac: baseFac,
sourceInfo: optimized.SourceInfo(),
}
ctx := &OptimizerContext{
optimizerExprFactory: exprFac,
Env: env,
Issues: issues,
}
// Apply the optimizations sequentially.
for _, o := range opt.optimizers {
optimized = o.Optimize(ctx, optimized)
if issues.Err() != nil {
return nil, issues
}
// Normalize expression id metadata including coordination with macro call metadata.
freshIDGen := newIDGenerator(0)
info := optimized.SourceInfo()
expr := optimized.Expr()
normalizeIDs(freshIDGen.renumberStable, expr, info)
cleanupMacroRefs(expr, info)
// Recheck the updated expression for any possible type-agreement or validation errors.
parsed := &Ast{
source: a.Source(),
impl: ast.NewAST(expr, info)}
checked, iss := ctx.Check(parsed)
if iss.Err() != nil {
return nil, iss
}
optimized = checked.impl
}
// Return the optimized result.
return &Ast{
source: a.Source(),
impl: optimized,
}, nil
}
// normalizeIDs ensures that the metadata present with an AST is reset in a manner such
// that the ids within the expression correspond to the ids within macros.
func normalizeIDs(idGen ast.IDGenerator, optimized ast.Expr, info *ast.SourceInfo) {
optimized.RenumberIDs(idGen)
if len(info.MacroCalls()) == 0 {
return
}
// First, update the macro call ids themselves.
callIDMap := map[int64]int64{}
for id := range info.MacroCalls() {
callIDMap[id] = idGen(id)
}
// Then update the macro call definitions which refer to these ids, but
// ensure that the updates don't collide and remove macro entries which haven't
// been visited / updated yet.
type macroUpdate struct {
id int64
call ast.Expr
}
macroUpdates := []macroUpdate{}
for oldID, newID := range callIDMap {
call, found := info.GetMacroCall(oldID)
if !found {
continue
}
call.RenumberIDs(idGen)
macroUpdates = append(macroUpdates, macroUpdate{id: newID, call: call})
info.ClearMacroCall(oldID)
}
for _, u := range macroUpdates {
info.SetMacroCall(u.id, u.call)
}
}
func cleanupMacroRefs(expr ast.Expr, info *ast.SourceInfo) {
if len(info.MacroCalls()) == 0 {
return
}
// Sanitize the macro call references once the optimized expression has been computed
// and the ids normalized between the expression and the macros.
exprRefMap := make(map[int64]struct{})
ast.PostOrderVisit(expr, ast.NewExprVisitor(func(e ast.Expr) {
if e.ID() == 0 {
return
}
exprRefMap[e.ID()] = struct{}{}
}))
// Update the macro call id references to ensure that macro pointers are
// updated consistently across macros.
for _, call := range info.MacroCalls() {
ast.PostOrderVisit(call, ast.NewExprVisitor(func(e ast.Expr) {
if e.ID() == 0 {
return
}
exprRefMap[e.ID()] = struct{}{}
}))
}
for id := range info.MacroCalls() {
if _, found := exprRefMap[id]; !found {
info.ClearMacroCall(id)
}
}
}
// newIDGenerator ensures that new ids are only created the first time they are encountered.
func newIDGenerator(seed int64) *idGenerator {
return &idGenerator{
idMap: make(map[int64]int64),
seed: seed,
}
}
type idGenerator struct {
idMap map[int64]int64
seed int64
}
func (gen *idGenerator) nextID() int64 {
gen.seed++
return gen.seed
}
func (gen *idGenerator) renumberStable(id int64) int64 {
if id == 0 {
return 0
}
if newID, found := gen.idMap[id]; found {
return newID
}
nextID := gen.nextID()
gen.idMap[id] = nextID
return nextID
}
// OptimizerContext embeds Env and Issues instances to make it easy to type-check and evaluate
// subexpressions and report any errors encountered along the way. The context also embeds the
// optimizerExprFactory which can be used to generate new sub-expressions with expression ids
// consistent with the expectations of a parsed expression.
type OptimizerContext struct {
*Env
*optimizerExprFactory
*Issues
}
// ASTOptimizer applies an optimization over an AST and returns the optimized result.
type ASTOptimizer interface {
// Optimize optimizes a type-checked AST within an Environment and accumulates any issues.
Optimize(*OptimizerContext, *ast.AST) *ast.AST
}
type optimizerExprFactory struct {
*idGenerator
fac ast.ExprFactory
sourceInfo *ast.SourceInfo
}
// NewAST creates an AST from the current expression using the tracked source info which
// is modified and managed by the OptimizerContext.
func (opt *optimizerExprFactory) NewAST(expr ast.Expr) *ast.AST {
return ast.NewAST(expr, opt.sourceInfo)
}
// CopyAST creates a renumbered copy of `Expr` and `SourceInfo` values of the input AST, where the
// renumbering uses the same scheme as the core optimizer logic ensuring there are no collisions
// between copies.
//
// Use this method before attempting to merge the expression from AST into another.
func (opt *optimizerExprFactory) CopyAST(a *ast.AST) (ast.Expr, *ast.SourceInfo) {
idGen := newIDGenerator(opt.nextID())
defer func() { opt.seed = idGen.nextID() }()
copyExpr := opt.fac.CopyExpr(a.Expr())
copyInfo := ast.CopySourceInfo(a.SourceInfo())
normalizeIDs(idGen.renumberStable, copyExpr, copyInfo)
return copyExpr, copyInfo
}
// CopyASTAndMetadata copies the input AST and propagates the macro metadata into the AST being
// optimized.
func (opt *optimizerExprFactory) CopyASTAndMetadata(a *ast.AST) ast.Expr {
copyExpr, copyInfo := opt.CopyAST(a)
for macroID, call := range copyInfo.MacroCalls() {
opt.SetMacroCall(macroID, call)
}
return copyExpr
}
// ClearMacroCall clears the macro at the given expression id.
func (opt *optimizerExprFactory) ClearMacroCall(id int64) {
opt.sourceInfo.ClearMacroCall(id)
}
// SetMacroCall sets the macro call metadata for the given macro id within the tracked source info
// metadata.
func (opt *optimizerExprFactory) SetMacroCall(id int64, expr ast.Expr) {
opt.sourceInfo.SetMacroCall(id, expr)
}
// NewBindMacro creates an AST expression representing the expanded bind() macro, and a macro expression
// representing the unexpanded call signature to be inserted into the source info macro call metadata.
func (opt *optimizerExprFactory) NewBindMacro(macroID int64, varName string, varInit, remaining ast.Expr) (astExpr, macroExpr ast.Expr) {
varID := opt.nextID()
remainingID := opt.nextID()
remaining = opt.fac.CopyExpr(remaining)
remaining.RenumberIDs(func(id int64) int64 {
if id == macroID {
return remainingID
}
return id
})
if call, exists := opt.sourceInfo.GetMacroCall(macroID); exists {
opt.SetMacroCall(remainingID, opt.fac.CopyExpr(call))
}
astExpr = opt.fac.NewComprehension(macroID,
opt.fac.NewList(opt.nextID(), []ast.Expr{}, []int32{}),
"#unused",
varName,
opt.fac.CopyExpr(varInit),
opt.fac.NewLiteral(opt.nextID(), types.False),
opt.fac.NewIdent(varID, varName),
remaining)
macroExpr = opt.fac.NewMemberCall(0, "bind",
opt.fac.NewIdent(opt.nextID(), "cel"),
opt.fac.NewIdent(varID, varName),
opt.fac.CopyExpr(varInit),
opt.fac.CopyExpr(remaining))
opt.sanitizeMacro(macroID, macroExpr)
return
}
// NewCall creates a global function call invocation expression.
//
// Example:
//
// countByField(list, fieldName)
// - function: countByField
// - args: [list, fieldName]
func (opt *optimizerExprFactory) NewCall(function string, args ...ast.Expr) ast.Expr {
return opt.fac.NewCall(opt.nextID(), function, args...)
}
// NewMemberCall creates a member function call invocation expression where 'target' is the receiver of the call.
//
// Example:
//
// list.countByField(fieldName)
// - function: countByField
// - target: list
// - args: [fieldName]
func (opt *optimizerExprFactory) NewMemberCall(function string, target ast.Expr, args ...ast.Expr) ast.Expr {
return opt.fac.NewMemberCall(opt.nextID(), function, target, args...)
}
// NewIdent creates a new identifier expression.
//
// Examples:
//
// - simple_var_name
// - qualified.subpackage.var_name
func (opt *optimizerExprFactory) NewIdent(name string) ast.Expr {
return opt.fac.NewIdent(opt.nextID(), name)
}
// NewLiteral creates a new literal expression value.
//
// The range of valid values for a literal generated during optimization is different than for expressions
// generated via parsing / type-checking, as the ref.Val may be _any_ CEL value so long as the value can
// be converted back to a literal-like form.
func (opt *optimizerExprFactory) NewLiteral(value ref.Val) ast.Expr {
return opt.fac.NewLiteral(opt.nextID(), value)
}
// NewList creates a list expression with a set of optional indices.
//
// Examples:
//
// [a, b]
// - elems: [a, b]
// - optIndices: []
//
// [a, ?b, ?c]
// - elems: [a, b, c]
// - optIndices: [1, 2]
func (opt *optimizerExprFactory) NewList(elems []ast.Expr, optIndices []int32) ast.Expr {
return opt.fac.NewList(opt.nextID(), elems, optIndices)
}
// NewMap creates a map from a set of entry expressions which contain a key and value expression.
func (opt *optimizerExprFactory) NewMap(entries []ast.EntryExpr) ast.Expr {
return opt.fac.NewMap(opt.nextID(), entries)
}
// NewMapEntry creates a map entry with a key and value expression and a flag to indicate whether the
// entry is optional.
//
// Examples:
//
// {a: b}
// - key: a
// - value: b
// - optional: false
//
// {?a: ?b}
// - key: a
// - value: b
// - optional: true
func (opt *optimizerExprFactory) NewMapEntry(key, value ast.Expr, isOptional bool) ast.EntryExpr {
return opt.fac.NewMapEntry(opt.nextID(), key, value, isOptional)
}
// NewHasMacro generates a test-only select expression to be included within an AST and an unexpanded
// has() macro call signature to be inserted into the source info macro call metadata.
func (opt *optimizerExprFactory) NewHasMacro(macroID int64, s ast.Expr) (astExpr, macroExpr ast.Expr) {
sel := s.AsSelect()
astExpr = opt.fac.NewPresenceTest(macroID, sel.Operand(), sel.FieldName())
macroExpr = opt.fac.NewCall(0, "has",
opt.NewSelect(opt.fac.CopyExpr(sel.Operand()), sel.FieldName()))
opt.sanitizeMacro(macroID, macroExpr)
return
}
// NewSelect creates a select expression where a field value is selected from an operand.
//
// Example:
//
// msg.field_name
// - operand: msg
// - field: field_name
func (opt *optimizerExprFactory) NewSelect(operand ast.Expr, field string) ast.Expr {
return opt.fac.NewSelect(opt.nextID(), operand, field)
}
// NewStruct creates a new typed struct value with an set of field initializations.
//
// Example:
//
// pkg.TypeName{field: value}
// - typeName: pkg.TypeName
// - fields: [{field: value}]
func (opt *optimizerExprFactory) NewStruct(typeName string, fields []ast.EntryExpr) ast.Expr {
return opt.fac.NewStruct(opt.nextID(), typeName, fields)
}
// NewStructField creates a struct field initialization.
//
// Examples:
//
// {count: 3u}
// - field: count
// - value: 3u
// - optional: false
//
// {?count: x}
// - field: count
// - value: x
// - optional: true
func (opt *optimizerExprFactory) NewStructField(field string, value ast.Expr, isOptional bool) ast.EntryExpr {
return opt.fac.NewStructField(opt.nextID(), field, value, isOptional)
}
// UpdateExpr updates the target expression with the updated content while preserving macro metadata.
//
// There are four scenarios during the update to consider:
// 1. target is not macro, updated is not macro
// 2. target is macro, updated is not macro
// 3. target is macro, updated is macro
// 4. target is not macro, updated is macro
//
// When the target is a macro already, it may either be updated to a new macro function
// body if the update is also a macro, or it may be removed altogether if the update is
// a macro.
//
// When the update is a macro, then the target references within other macros must be
// updated to point to the new updated macro. Otherwise, other macros which pointed to
// the target body must be replaced with copies of the updated expression body.
func (opt *optimizerExprFactory) UpdateExpr(target, updated ast.Expr) {
// Update the expression
target.SetKindCase(updated)
// Early return if there's no macros present sa the source info reflects the
// macro set from the target and updated expressions.
if len(opt.sourceInfo.MacroCalls()) == 0 {
return
}
// Determine whether the target expression was a macro.
_, targetIsMacro := opt.sourceInfo.GetMacroCall(target.ID())
// Determine whether the updated expression was a macro.
updatedMacro, updatedIsMacro := opt.sourceInfo.GetMacroCall(updated.ID())
if updatedIsMacro {
// If the updated call was a macro, then updated id maps to target id,
// and the updated macro moves into the target id slot.
opt.sourceInfo.ClearMacroCall(updated.ID())
opt.sourceInfo.SetMacroCall(target.ID(), updatedMacro)
} else if targetIsMacro {
// Otherwise if the target expr was a macro, but is no longer, clear
// the macro reference.
opt.sourceInfo.ClearMacroCall(target.ID())
}
// Punch holes in the updated value where macros references exist.
macroExpr := opt.fac.CopyExpr(target)
macroRefVisitor := ast.NewExprVisitor(func(e ast.Expr) {
if _, exists := opt.sourceInfo.GetMacroCall(e.ID()); exists {
e.SetKindCase(nil)
}
})
ast.PostOrderVisit(macroExpr, macroRefVisitor)
// Update any references to the expression within a macro
macroVisitor := ast.NewExprVisitor(func(call ast.Expr) {
// Update the target expression to point to the macro expression which
// will be empty if the updated expression was a macro.
if call.ID() == target.ID() {
call.SetKindCase(opt.fac.CopyExpr(macroExpr))
}
// Update the macro call expression if it refers to the updated expression
// id which has since been remapped to the target id.
if call.ID() == updated.ID() {
// Either ensure the expression is a macro reference or a populated with
// the relevant sub-expression if the updated expr was not a macro.
if updatedIsMacro {
call.SetKindCase(nil)
} else {
call.SetKindCase(opt.fac.CopyExpr(macroExpr))
}
// Since SetKindCase does not renumber the id, ensure the references to
// the old 'updated' id are mapped to the target id.
call.RenumberIDs(func(id int64) int64 {
if id == updated.ID() {
return target.ID()
}
return id
})
}
})
for _, call := range opt.sourceInfo.MacroCalls() {
ast.PostOrderVisit(call, macroVisitor)
}
}
func (opt *optimizerExprFactory) sanitizeMacro(macroID int64, macroExpr ast.Expr) {
macroRefVisitor := ast.NewExprVisitor(func(e ast.Expr) {
if _, exists := opt.sourceInfo.GetMacroCall(e.ID()); exists && e.ID() != macroID {
e.SetKindCase(nil)
}
})
ast.PostOrderVisit(macroExpr, macroRefVisitor)
}

View File

@ -448,6 +448,8 @@ const (
OptTrackCost EvalOption = 1 << iota
// OptCheckStringFormat enables compile-time checking of string.format calls for syntax/cardinality.
//
// Deprecated: use ext.StringsValidateFormatCalls() as this option is now a no-op.
OptCheckStringFormat EvalOption = 1 << iota
)

View File

@ -19,7 +19,6 @@ import (
"fmt"
"sync"
celast "github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter"
@ -152,7 +151,7 @@ func (p *prog) clone() *prog {
// ProgramOption values.
//
// If the program cannot be configured the prog will be nil, with a non-nil error response.
func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) {
// Build the dispatcher, interpreter, and default program value.
disp := interpreter.NewDispatcher()
@ -213,34 +212,6 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
if len(p.regexOptimizations) > 0 {
decorators = append(decorators, interpreter.CompileRegexConstants(p.regexOptimizations...))
}
// Enable compile-time checking of syntax/cardinality for string.format calls.
if p.evalOpts&OptCheckStringFormat == OptCheckStringFormat {
var isValidType func(id int64, validTypes ...ref.Type) (bool, error)
if ast.IsChecked() {
isValidType = func(id int64, validTypes ...ref.Type) (bool, error) {
t := ast.typeMap[id]
if t.Kind() == DynKind {
return true, nil
}
for _, vt := range validTypes {
k, err := typeValueToKind(vt)
if err != nil {
return false, err
}
if t.Kind() == k {
return true, nil
}
}
return false, nil
}
} else {
// if the AST isn't type-checked, short-circuit validation
isValidType = func(id int64, validTypes ...ref.Type) (bool, error) {
return true, nil
}
}
decorators = append(decorators, interpreter.InterpolateFormattedString(isValidType))
}
// Enable exhaustive eval, state tracking and cost tracking last since they require a factory.
if p.evalOpts&(OptExhaustiveEval|OptTrackState|OptTrackCost) != 0 {
@ -274,33 +245,16 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
decs = append(decs, interpreter.Observe(observers...))
}
return p.clone().initInterpretable(ast, decs)
return p.clone().initInterpretable(a, decs)
}
return newProgGen(factory)
}
return p.initInterpretable(ast, decorators)
return p.initInterpretable(a, decorators)
}
func (p *prog) initInterpretable(ast *Ast, decs []interpreter.InterpretableDecorator) (*prog, error) {
// Unchecked programs do not contain type and reference information and may be slower to execute.
if !ast.IsChecked() {
interpretable, err :=
p.interpreter.NewUncheckedInterpretable(ast.Expr(), decs...)
if err != nil {
return nil, err
}
p.interpretable = interpretable
return p, nil
}
// When the AST has been checked it contains metadata that can be used to speed up program execution.
checked := &celast.CheckedAST{
Expr: ast.Expr(),
SourceInfo: ast.SourceInfo(),
TypeMap: ast.typeMap,
ReferenceMap: ast.refMap,
}
interpretable, err := p.interpreter.NewInterpretable(checked, decs...)
func (p *prog) initInterpretable(a *Ast, decs []interpreter.InterpretableDecorator) (*prog, error) {
// When the AST has been exprAST it contains metadata that can be used to speed up program execution.
interpretable, err := p.interpreter.NewInterpretable(a.impl, decs...)
if err != nil {
return nil, err
}
@ -580,8 +534,6 @@ func (p *evalActivationPool) Put(value any) {
}
var (
emptyEvalState = interpreter.NewEvalState()
// activationPool is an internally managed pool of Activation values that wrap map[string]any inputs
activationPool = newEvalActivationPool()

View File

@ -21,8 +21,6 @@ import (
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/overloads"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
const (
@ -69,7 +67,7 @@ type ASTValidator interface {
//
// See individual validators for more information on their configuration keys and configuration
// properties.
Validate(*Env, ValidatorConfig, *ast.CheckedAST, *Issues)
Validate(*Env, ValidatorConfig, *ast.AST, *Issues)
}
// ValidatorConfig provides an accessor method for querying validator configuration state.
@ -180,7 +178,7 @@ func ValidateComprehensionNestingLimit(limit int) ASTValidator {
return nestingLimitValidator{limit: limit}
}
type argChecker func(env *Env, call, arg ast.NavigableExpr) error
type argChecker func(env *Env, call, arg ast.Expr) error
func newFormatValidator(funcName string, argNum int, check argChecker) formatValidator {
return formatValidator{
@ -203,8 +201,8 @@ func (v formatValidator) Name() string {
// Validate searches the AST for uses of a given function name with a constant argument and performs a check
// on whether the argument is a valid literal value.
func (v formatValidator) Validate(e *Env, _ ValidatorConfig, a *ast.CheckedAST, iss *Issues) {
root := ast.NavigateCheckedAST(a)
func (v formatValidator) Validate(e *Env, _ ValidatorConfig, a *ast.AST, iss *Issues) {
root := ast.NavigateAST(a)
funcCalls := ast.MatchDescendants(root, ast.FunctionMatcher(v.funcName))
for _, call := range funcCalls {
callArgs := call.AsCall().Args()
@ -221,8 +219,8 @@ func (v formatValidator) Validate(e *Env, _ ValidatorConfig, a *ast.CheckedAST,
}
}
func evalCall(env *Env, call, arg ast.NavigableExpr) error {
ast := ParsedExprToAst(&exprpb.ParsedExpr{Expr: call.ToExpr()})
func evalCall(env *Env, call, arg ast.Expr) error {
ast := &Ast{impl: ast.NewAST(call, ast.NewSourceInfo(nil))}
prg, err := env.Program(ast)
if err != nil {
return err
@ -231,7 +229,7 @@ func evalCall(env *Env, call, arg ast.NavigableExpr) error {
return err
}
func compileRegex(_ *Env, _, arg ast.NavigableExpr) error {
func compileRegex(_ *Env, _, arg ast.Expr) error {
pattern := arg.AsLiteral().Value().(string)
_, err := regexp.Compile(pattern)
return err
@ -244,25 +242,14 @@ func (homogeneousAggregateLiteralValidator) Name() string {
return homogeneousValidatorName
}
// Configure implements the ASTValidatorConfigurer interface and currently sets the list of standard
// and exempt functions from homogeneous aggregate literal checks.
//
// TODO: Move this call into the string.format() ASTValidator once ported.
func (homogeneousAggregateLiteralValidator) Configure(c MutableValidatorConfig) error {
emptyList := []string{}
exemptFunctions := c.GetOrDefault(HomogeneousAggregateLiteralExemptFunctions, emptyList).([]string)
exemptFunctions = append(exemptFunctions, "format")
return c.Set(HomogeneousAggregateLiteralExemptFunctions, exemptFunctions)
}
// Validate validates that all lists and map literals have homogeneous types, i.e. don't contain dyn types.
//
// This validator makes an exception for list and map literals which occur at any level of nesting within
// string format calls.
func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig, a *ast.CheckedAST, iss *Issues) {
func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig, a *ast.AST, iss *Issues) {
var exemptedFunctions []string
exemptedFunctions = c.GetOrDefault(HomogeneousAggregateLiteralExemptFunctions, exemptedFunctions).([]string)
root := ast.NavigateCheckedAST(a)
root := ast.NavigateAST(a)
listExprs := ast.MatchDescendants(root, ast.KindMatcher(ast.ListKind))
for _, listExpr := range listExprs {
if inExemptFunction(listExpr, exemptedFunctions) {
@ -273,7 +260,7 @@ func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig
optIndices := l.OptionalIndices()
var elemType *Type
for i, e := range elements {
et := e.Type()
et := a.GetType(e.ID())
if isOptionalIndex(i, optIndices) {
et = et.Parameters()[0]
}
@ -296,9 +283,10 @@ func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig
entries := m.Entries()
var keyType, valType *Type
for _, e := range entries {
key, val := e.Key(), e.Value()
kt, vt := key.Type(), val.Type()
if e.IsOptional() {
mapEntry := e.AsMapEntry()
key, val := mapEntry.Key(), mapEntry.Value()
kt, vt := a.GetType(key.ID()), a.GetType(val.ID())
if mapEntry.IsOptional() {
vt = vt.Parameters()[0]
}
if keyType == nil && valType == nil {
@ -316,7 +304,8 @@ func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig
}
func inExemptFunction(e ast.NavigableExpr, exemptFunctions []string) bool {
if parent, found := e.Parent(); found {
parent, found := e.Parent()
for found {
if parent.Kind() == ast.CallKind {
fnName := parent.AsCall().FunctionName()
for _, exempt := range exemptFunctions {
@ -325,9 +314,7 @@ func inExemptFunction(e ast.NavigableExpr, exemptFunctions []string) bool {
}
}
}
if parent.Kind() == ast.ListKind || parent.Kind() == ast.MapKind {
return inExemptFunction(parent, exemptFunctions)
}
parent, found = parent.Parent()
}
return false
}
@ -353,8 +340,8 @@ func (v nestingLimitValidator) Name() string {
return "cel.lib.std.validate.comprehension_nesting_limit"
}
func (v nestingLimitValidator) Validate(e *Env, _ ValidatorConfig, a *ast.CheckedAST, iss *Issues) {
root := ast.NavigateCheckedAST(a)
func (v nestingLimitValidator) Validate(e *Env, _ ValidatorConfig, a *ast.AST, iss *Issues) {
root := ast.NavigateAST(a)
comprehensions := ast.MatchDescendants(root, ast.KindMatcher(ast.ComprehensionKind))
if len(comprehensions) <= v.limit {
return

View File

@ -60,7 +60,6 @@ go_test(
"//test:go_default_library",
"//test/proto2pb:go_default_library",
"//test/proto3pb:go_default_library",
"@com_github_antlr_antlr4_runtime_go_antlr_v4//:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
],
)

View File

@ -18,6 +18,7 @@ package checker
import (
"fmt"
"reflect"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
@ -25,139 +26,98 @@ import (
"github.com/google/cel-go/common/decls"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/types"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"github.com/google/cel-go/common/types/ref"
)
type checker struct {
*ast.AST
ast.ExprFactory
env *Env
errors *typeErrors
mappings *mapping
freeTypeVarCounter int
sourceInfo *exprpb.SourceInfo
types map[int64]*types.Type
references map[int64]*ast.ReferenceInfo
}
// Check performs type checking, giving a typed AST.
// The input is a ParsedExpr proto and an env which encapsulates
// type binding of variables, declarations of built-in functions,
// descriptions of protocol buffers, and a registry for errors.
// Returns a CheckedExpr proto, which might not be usable if
// there are errors in the error registry.
func Check(parsedExpr *exprpb.ParsedExpr, source common.Source, env *Env) (*ast.CheckedAST, *common.Errors) {
//
// The input is a parsed AST and an env which encapsulates type binding of variables,
// declarations of built-in functions, descriptions of protocol buffers, and a registry for
// errors.
//
// Returns a type-checked AST, which might not be usable if there are errors in the error
// registry.
func Check(parsed *ast.AST, source common.Source, env *Env) (*ast.AST, *common.Errors) {
errs := common.NewErrors(source)
typeMap := make(map[int64]*types.Type)
refMap := make(map[int64]*ast.ReferenceInfo)
c := checker{
AST: ast.NewCheckedAST(parsed, typeMap, refMap),
ExprFactory: ast.NewExprFactory(),
env: env,
errors: &typeErrors{errs: errs},
mappings: newMapping(),
freeTypeVarCounter: 0,
sourceInfo: parsedExpr.GetSourceInfo(),
types: make(map[int64]*types.Type),
references: make(map[int64]*ast.ReferenceInfo),
}
c.check(parsedExpr.GetExpr())
c.check(c.Expr())
// Walk over the final type map substituting any type parameters either by their bound value or
// by DYN.
m := make(map[int64]*types.Type)
for id, t := range c.types {
m[id] = substitute(c.mappings, t, true)
// Walk over the final type map substituting any type parameters either by their bound value
// or by DYN.
for id, t := range c.TypeMap() {
c.SetType(id, substitute(c.mappings, t, true))
}
return &ast.CheckedAST{
Expr: parsedExpr.GetExpr(),
SourceInfo: parsedExpr.GetSourceInfo(),
TypeMap: m,
ReferenceMap: c.references,
}, errs
return c.AST, errs
}
func (c *checker) check(e *exprpb.Expr) {
func (c *checker) check(e ast.Expr) {
if e == nil {
return
}
switch e.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
literal := e.GetConstExpr()
switch literal.GetConstantKind().(type) {
case *exprpb.Constant_BoolValue:
c.checkBoolLiteral(e)
case *exprpb.Constant_BytesValue:
c.checkBytesLiteral(e)
case *exprpb.Constant_DoubleValue:
c.checkDoubleLiteral(e)
case *exprpb.Constant_Int64Value:
c.checkInt64Literal(e)
case *exprpb.Constant_NullValue:
c.checkNullLiteral(e)
case *exprpb.Constant_StringValue:
c.checkStringLiteral(e)
case *exprpb.Constant_Uint64Value:
c.checkUint64Literal(e)
switch e.Kind() {
case ast.LiteralKind:
literal := ref.Val(e.AsLiteral())
switch literal.Type() {
case types.BoolType, types.BytesType, types.DoubleType, types.IntType,
types.NullType, types.StringType, types.UintType:
c.setType(e, literal.Type().(*types.Type))
default:
c.errors.unexpectedASTType(e.ID(), c.location(e), "literal", literal.Type().TypeName())
}
case *exprpb.Expr_IdentExpr:
case ast.IdentKind:
c.checkIdent(e)
case *exprpb.Expr_SelectExpr:
case ast.SelectKind:
c.checkSelect(e)
case *exprpb.Expr_CallExpr:
case ast.CallKind:
c.checkCall(e)
case *exprpb.Expr_ListExpr:
case ast.ListKind:
c.checkCreateList(e)
case *exprpb.Expr_StructExpr:
case ast.MapKind:
c.checkCreateMap(e)
case ast.StructKind:
c.checkCreateStruct(e)
case *exprpb.Expr_ComprehensionExpr:
case ast.ComprehensionKind:
c.checkComprehension(e)
default:
c.errors.unexpectedASTType(e.GetId(), c.location(e), e)
c.errors.unexpectedASTType(e.ID(), c.location(e), "unspecified", reflect.TypeOf(e).Name())
}
}
func (c *checker) checkInt64Literal(e *exprpb.Expr) {
c.setType(e, types.IntType)
}
func (c *checker) checkUint64Literal(e *exprpb.Expr) {
c.setType(e, types.UintType)
}
func (c *checker) checkStringLiteral(e *exprpb.Expr) {
c.setType(e, types.StringType)
}
func (c *checker) checkBytesLiteral(e *exprpb.Expr) {
c.setType(e, types.BytesType)
}
func (c *checker) checkDoubleLiteral(e *exprpb.Expr) {
c.setType(e, types.DoubleType)
}
func (c *checker) checkBoolLiteral(e *exprpb.Expr) {
c.setType(e, types.BoolType)
}
func (c *checker) checkNullLiteral(e *exprpb.Expr) {
c.setType(e, types.NullType)
}
func (c *checker) checkIdent(e *exprpb.Expr) {
identExpr := e.GetIdentExpr()
func (c *checker) checkIdent(e ast.Expr) {
identName := e.AsIdent()
// Check to see if the identifier is declared.
if ident := c.env.LookupIdent(identExpr.GetName()); ident != nil {
if ident := c.env.LookupIdent(identName); ident != nil {
c.setType(e, ident.Type())
c.setReference(e, ast.NewIdentReference(ident.Name(), ident.Value()))
// Overwrite the identifier with its fully qualified name.
identExpr.Name = ident.Name()
e.SetKindCase(c.NewIdent(e.ID(), ident.Name()))
return
}
c.setType(e, types.ErrorType)
c.errors.undeclaredReference(e.GetId(), c.location(e), c.env.container.Name(), identExpr.GetName())
c.errors.undeclaredReference(e.ID(), c.location(e), c.env.container.Name(), identName)
}
func (c *checker) checkSelect(e *exprpb.Expr) {
sel := e.GetSelectExpr()
func (c *checker) checkSelect(e ast.Expr) {
sel := e.AsSelect()
// Before traversing down the tree, try to interpret as qualified name.
qname, found := containers.ToQualifiedName(e)
if found {
@ -170,31 +130,26 @@ func (c *checker) checkSelect(e *exprpb.Expr) {
// variable name.
c.setType(e, ident.Type())
c.setReference(e, ast.NewIdentReference(ident.Name(), ident.Value()))
identName := ident.Name()
e.ExprKind = &exprpb.Expr_IdentExpr{
IdentExpr: &exprpb.Expr_Ident{
Name: identName,
},
}
e.SetKindCase(c.NewIdent(e.ID(), ident.Name()))
return
}
}
resultType := c.checkSelectField(e, sel.GetOperand(), sel.GetField(), false)
if sel.TestOnly {
resultType := c.checkSelectField(e, sel.Operand(), sel.FieldName(), false)
if sel.IsTestOnly() {
resultType = types.BoolType
}
c.setType(e, substitute(c.mappings, resultType, false))
}
func (c *checker) checkOptSelect(e *exprpb.Expr) {
func (c *checker) checkOptSelect(e ast.Expr) {
// Collect metadata related to the opt select call packaged by the parser.
call := e.GetCallExpr()
operand := call.GetArgs()[0]
field := call.GetArgs()[1]
call := e.AsCall()
operand := call.Args()[0]
field := call.Args()[1]
fieldName, isString := maybeUnwrapString(field)
if !isString {
c.errors.notAnOptionalFieldSelection(field.GetId(), c.location(field), field)
c.errors.notAnOptionalFieldSelection(field.ID(), c.location(field), field)
return
}
@ -204,7 +159,7 @@ func (c *checker) checkOptSelect(e *exprpb.Expr) {
c.setReference(e, ast.NewFunctionReference("select_optional_field"))
}
func (c *checker) checkSelectField(e, operand *exprpb.Expr, field string, optional bool) *types.Type {
func (c *checker) checkSelectField(e, operand ast.Expr, field string, optional bool) *types.Type {
// Interpret as field selection, first traversing down the operand.
c.check(operand)
operandType := substitute(c.mappings, c.getType(operand), false)
@ -222,7 +177,7 @@ func (c *checker) checkSelectField(e, operand *exprpb.Expr, field string, option
// Objects yield their field type declaration as the selection result type, but only if
// the field is defined.
messageType := targetType
if fieldType, found := c.lookupFieldType(e.GetId(), messageType.TypeName(), field); found {
if fieldType, found := c.lookupFieldType(e.ID(), messageType.TypeName(), field); found {
resultType = fieldType
}
case types.TypeParamKind:
@ -236,7 +191,7 @@ func (c *checker) checkSelectField(e, operand *exprpb.Expr, field string, option
// Dynamic / error values are treated as DYN type. Errors are handled this way as well
// in order to allow forward progress on the check.
if !isDynOrError(targetType) {
c.errors.typeDoesNotSupportFieldSelection(e.GetId(), c.location(e), targetType)
c.errors.typeDoesNotSupportFieldSelection(e.ID(), c.location(e), targetType)
}
resultType = types.DynType
}
@ -248,35 +203,34 @@ func (c *checker) checkSelectField(e, operand *exprpb.Expr, field string, option
return resultType
}
func (c *checker) checkCall(e *exprpb.Expr) {
func (c *checker) checkCall(e ast.Expr) {
// Note: similar logic exists within the `interpreter/planner.go`. If making changes here
// please consider the impact on planner.go and consolidate implementations or mirror code
// as appropriate.
call := e.GetCallExpr()
fnName := call.GetFunction()
call := e.AsCall()
fnName := call.FunctionName()
if fnName == operators.OptSelect {
c.checkOptSelect(e)
return
}
args := call.GetArgs()
args := call.Args()
// Traverse arguments.
for _, arg := range args {
c.check(arg)
}
target := call.GetTarget()
// Regular static call with simple name.
if target == nil {
if !call.IsMemberFunction() {
// Check for the existence of the function.
fn := c.env.LookupFunction(fnName)
if fn == nil {
c.errors.undeclaredReference(e.GetId(), c.location(e), c.env.container.Name(), fnName)
c.errors.undeclaredReference(e.ID(), c.location(e), c.env.container.Name(), fnName)
c.setType(e, types.ErrorType)
return
}
// Overwrite the function name with its fully qualified resolved name.
call.Function = fn.Name()
e.SetKindCase(c.NewCall(e.ID(), fn.Name(), args...))
// Check to see whether the overload resolves.
c.resolveOverloadOrError(e, fn, nil, args)
return
@ -287,6 +241,7 @@ func (c *checker) checkCall(e *exprpb.Expr) {
// target a.b.
//
// Check whether the target is a namespaced function name.
target := call.Target()
qualifiedPrefix, maybeQualified := containers.ToQualifiedName(target)
if maybeQualified {
maybeQualifiedName := qualifiedPrefix + "." + fnName
@ -295,15 +250,14 @@ func (c *checker) checkCall(e *exprpb.Expr) {
// The function name is namespaced and so preserving the target operand would
// be an inaccurate representation of the desired evaluation behavior.
// Overwrite with fully-qualified resolved function name sans receiver target.
call.Target = nil
call.Function = fn.Name()
e.SetKindCase(c.NewCall(e.ID(), fn.Name(), args...))
c.resolveOverloadOrError(e, fn, nil, args)
return
}
}
// Regular instance call.
c.check(call.Target)
c.check(target)
fn := c.env.LookupFunction(fnName)
// Function found, attempt overload resolution.
if fn != nil {
@ -312,11 +266,11 @@ func (c *checker) checkCall(e *exprpb.Expr) {
}
// Function name not declared, record error.
c.setType(e, types.ErrorType)
c.errors.undeclaredReference(e.GetId(), c.location(e), c.env.container.Name(), fnName)
c.errors.undeclaredReference(e.ID(), c.location(e), c.env.container.Name(), fnName)
}
func (c *checker) resolveOverloadOrError(
e *exprpb.Expr, fn *decls.FunctionDecl, target *exprpb.Expr, args []*exprpb.Expr) {
e ast.Expr, fn *decls.FunctionDecl, target ast.Expr, args []ast.Expr) {
// Attempt to resolve the overload.
resolution := c.resolveOverload(e, fn, target, args)
// No such overload, error noted in the resolveOverload call, type recorded here.
@ -330,7 +284,7 @@ func (c *checker) resolveOverloadOrError(
}
func (c *checker) resolveOverload(
call *exprpb.Expr, fn *decls.FunctionDecl, target *exprpb.Expr, args []*exprpb.Expr) *overloadResolution {
call ast.Expr, fn *decls.FunctionDecl, target ast.Expr, args []ast.Expr) *overloadResolution {
var argTypes []*types.Type
if target != nil {
@ -362,8 +316,8 @@ func (c *checker) resolveOverload(
for i, argType := range argTypes {
if !c.isAssignable(argType, types.BoolType) {
c.errors.typeMismatch(
args[i].GetId(),
c.locationByID(args[i].GetId()),
args[i].ID(),
c.locationByID(args[i].ID()),
types.BoolType,
argType)
resultType = types.ErrorType
@ -408,29 +362,29 @@ func (c *checker) resolveOverload(
for i, argType := range argTypes {
argTypes[i] = substitute(c.mappings, argType, true)
}
c.errors.noMatchingOverload(call.GetId(), c.location(call), fn.Name(), argTypes, target != nil)
c.errors.noMatchingOverload(call.ID(), c.location(call), fn.Name(), argTypes, target != nil)
return nil
}
return newResolution(checkedRef, resultType)
}
func (c *checker) checkCreateList(e *exprpb.Expr) {
create := e.GetListExpr()
func (c *checker) checkCreateList(e ast.Expr) {
create := e.AsList()
var elemsType *types.Type
optionalIndices := create.GetOptionalIndices()
optionalIndices := create.OptionalIndices()
optionals := make(map[int32]bool, len(optionalIndices))
for _, optInd := range optionalIndices {
optionals[optInd] = true
}
for i, e := range create.GetElements() {
for i, e := range create.Elements() {
c.check(e)
elemType := c.getType(e)
if optionals[int32(i)] {
var isOptional bool
elemType, isOptional = maybeUnwrapOptional(elemType)
if !isOptional && !isDyn(elemType) {
c.errors.typeMismatch(e.GetId(), c.location(e), types.NewOptionalType(elemType), elemType)
c.errors.typeMismatch(e.ID(), c.location(e), types.NewOptionalType(elemType), elemType)
}
}
elemsType = c.joinTypes(e, elemsType, elemType)
@ -442,32 +396,24 @@ func (c *checker) checkCreateList(e *exprpb.Expr) {
c.setType(e, types.NewListType(elemsType))
}
func (c *checker) checkCreateStruct(e *exprpb.Expr) {
str := e.GetStructExpr()
if str.GetMessageName() != "" {
c.checkCreateMessage(e)
} else {
c.checkCreateMap(e)
}
}
func (c *checker) checkCreateMap(e *exprpb.Expr) {
mapVal := e.GetStructExpr()
func (c *checker) checkCreateMap(e ast.Expr) {
mapVal := e.AsMap()
var mapKeyType *types.Type
var mapValueType *types.Type
for _, ent := range mapVal.GetEntries() {
key := ent.GetMapKey()
for _, e := range mapVal.Entries() {
entry := e.AsMapEntry()
key := entry.Key()
c.check(key)
mapKeyType = c.joinTypes(key, mapKeyType, c.getType(key))
val := ent.GetValue()
val := entry.Value()
c.check(val)
valType := c.getType(val)
if ent.GetOptionalEntry() {
if entry.IsOptional() {
var isOptional bool
valType, isOptional = maybeUnwrapOptional(valType)
if !isOptional && !isDyn(valType) {
c.errors.typeMismatch(val.GetId(), c.location(val), types.NewOptionalType(valType), valType)
c.errors.typeMismatch(val.ID(), c.location(val), types.NewOptionalType(valType), valType)
}
}
mapValueType = c.joinTypes(val, mapValueType, valType)
@ -480,25 +426,28 @@ func (c *checker) checkCreateMap(e *exprpb.Expr) {
c.setType(e, types.NewMapType(mapKeyType, mapValueType))
}
func (c *checker) checkCreateMessage(e *exprpb.Expr) {
msgVal := e.GetStructExpr()
func (c *checker) checkCreateStruct(e ast.Expr) {
msgVal := e.AsStruct()
// Determine the type of the message.
resultType := types.ErrorType
ident := c.env.LookupIdent(msgVal.GetMessageName())
ident := c.env.LookupIdent(msgVal.TypeName())
if ident == nil {
c.errors.undeclaredReference(
e.GetId(), c.location(e), c.env.container.Name(), msgVal.GetMessageName())
e.ID(), c.location(e), c.env.container.Name(), msgVal.TypeName())
c.setType(e, types.ErrorType)
return
}
// Ensure the type name is fully qualified in the AST.
typeName := ident.Name()
msgVal.MessageName = typeName
c.setReference(e, ast.NewIdentReference(ident.Name(), nil))
if msgVal.TypeName() != typeName {
e.SetKindCase(c.NewStruct(e.ID(), typeName, msgVal.Fields()))
msgVal = e.AsStruct()
}
c.setReference(e, ast.NewIdentReference(typeName, nil))
identKind := ident.Type().Kind()
if identKind != types.ErrorKind {
if identKind != types.TypeKind {
c.errors.notAType(e.GetId(), c.location(e), ident.Type().DeclaredTypeName())
c.errors.notAType(e.ID(), c.location(e), ident.Type().DeclaredTypeName())
} else {
resultType = ident.Type().Parameters()[0]
// Backwards compatibility test between well-known types and message types
@ -509,7 +458,7 @@ func (c *checker) checkCreateMessage(e *exprpb.Expr) {
} else if resultType.Kind() == types.StructKind {
typeName = resultType.DeclaredTypeName()
} else {
c.errors.notAMessageType(e.GetId(), c.location(e), resultType.DeclaredTypeName())
c.errors.notAMessageType(e.ID(), c.location(e), resultType.DeclaredTypeName())
resultType = types.ErrorType
}
}
@ -517,37 +466,38 @@ func (c *checker) checkCreateMessage(e *exprpb.Expr) {
c.setType(e, resultType)
// Check the field initializers.
for _, ent := range msgVal.GetEntries() {
field := ent.GetFieldKey()
value := ent.GetValue()
for _, f := range msgVal.Fields() {
field := f.AsStructField()
fieldName := field.Name()
value := field.Value()
c.check(value)
fieldType := types.ErrorType
ft, found := c.lookupFieldType(ent.GetId(), typeName, field)
ft, found := c.lookupFieldType(f.ID(), typeName, fieldName)
if found {
fieldType = ft
}
valType := c.getType(value)
if ent.GetOptionalEntry() {
if field.IsOptional() {
var isOptional bool
valType, isOptional = maybeUnwrapOptional(valType)
if !isOptional && !isDyn(valType) {
c.errors.typeMismatch(value.GetId(), c.location(value), types.NewOptionalType(valType), valType)
c.errors.typeMismatch(value.ID(), c.location(value), types.NewOptionalType(valType), valType)
}
}
if !c.isAssignable(fieldType, valType) {
c.errors.fieldTypeMismatch(ent.GetId(), c.locationByID(ent.GetId()), field, fieldType, valType)
c.errors.fieldTypeMismatch(f.ID(), c.locationByID(f.ID()), fieldName, fieldType, valType)
}
}
}
func (c *checker) checkComprehension(e *exprpb.Expr) {
comp := e.GetComprehensionExpr()
c.check(comp.GetIterRange())
c.check(comp.GetAccuInit())
accuType := c.getType(comp.GetAccuInit())
rangeType := substitute(c.mappings, c.getType(comp.GetIterRange()), false)
func (c *checker) checkComprehension(e ast.Expr) {
comp := e.AsComprehension()
c.check(comp.IterRange())
c.check(comp.AccuInit())
accuType := c.getType(comp.AccuInit())
rangeType := substitute(c.mappings, c.getType(comp.IterRange()), false)
var varType *types.Type
switch rangeType.Kind() {
@ -564,32 +514,32 @@ func (c *checker) checkComprehension(e *exprpb.Expr) {
// Set the range iteration variable to type DYN as well.
varType = types.DynType
default:
c.errors.notAComprehensionRange(comp.GetIterRange().GetId(), c.location(comp.GetIterRange()), rangeType)
c.errors.notAComprehensionRange(comp.IterRange().ID(), c.location(comp.IterRange()), rangeType)
varType = types.ErrorType
}
// Create a scope for the comprehension since it has a local accumulation variable.
// This scope will contain the accumulation variable used to compute the result.
c.env = c.env.enterScope()
c.env.AddIdents(decls.NewVariable(comp.GetAccuVar(), accuType))
c.env.AddIdents(decls.NewVariable(comp.AccuVar(), accuType))
// Create a block scope for the loop.
c.env = c.env.enterScope()
c.env.AddIdents(decls.NewVariable(comp.GetIterVar(), varType))
c.env.AddIdents(decls.NewVariable(comp.IterVar(), varType))
// Check the variable references in the condition and step.
c.check(comp.GetLoopCondition())
c.assertType(comp.GetLoopCondition(), types.BoolType)
c.check(comp.GetLoopStep())
c.assertType(comp.GetLoopStep(), accuType)
c.check(comp.LoopCondition())
c.assertType(comp.LoopCondition(), types.BoolType)
c.check(comp.LoopStep())
c.assertType(comp.LoopStep(), accuType)
// Exit the loop's block scope before checking the result.
c.env = c.env.exitScope()
c.check(comp.GetResult())
c.check(comp.Result())
// Exit the comprehension scope.
c.env = c.env.exitScope()
c.setType(e, substitute(c.mappings, c.getType(comp.GetResult()), false))
c.setType(e, substitute(c.mappings, c.getType(comp.Result()), false))
}
// Checks compatibility of joined types, and returns the most general common type.
func (c *checker) joinTypes(e *exprpb.Expr, previous, current *types.Type) *types.Type {
func (c *checker) joinTypes(e ast.Expr, previous, current *types.Type) *types.Type {
if previous == nil {
return current
}
@ -599,7 +549,7 @@ func (c *checker) joinTypes(e *exprpb.Expr, previous, current *types.Type) *type
if c.dynAggregateLiteralElementTypesEnabled() {
return types.DynType
}
c.errors.typeMismatch(e.GetId(), c.location(e), previous, current)
c.errors.typeMismatch(e.ID(), c.location(e), previous, current)
return types.ErrorType
}
@ -633,41 +583,41 @@ func (c *checker) isAssignableList(l1, l2 []*types.Type) bool {
return false
}
func maybeUnwrapString(e *exprpb.Expr) (string, bool) {
switch e.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
literal := e.GetConstExpr()
switch literal.GetConstantKind().(type) {
case *exprpb.Constant_StringValue:
return literal.GetStringValue(), true
func maybeUnwrapString(e ast.Expr) (string, bool) {
switch e.Kind() {
case ast.LiteralKind:
literal := e.AsLiteral()
switch v := literal.(type) {
case types.String:
return string(v), true
}
}
return "", false
}
func (c *checker) setType(e *exprpb.Expr, t *types.Type) {
if old, found := c.types[e.GetId()]; found && !old.IsExactType(t) {
c.errors.incompatibleType(e.GetId(), c.location(e), e, old, t)
func (c *checker) setType(e ast.Expr, t *types.Type) {
if old, found := c.TypeMap()[e.ID()]; found && !old.IsExactType(t) {
c.errors.incompatibleType(e.ID(), c.location(e), e, old, t)
return
}
c.types[e.GetId()] = t
c.SetType(e.ID(), t)
}
func (c *checker) getType(e *exprpb.Expr) *types.Type {
return c.types[e.GetId()]
func (c *checker) getType(e ast.Expr) *types.Type {
return c.TypeMap()[e.ID()]
}
func (c *checker) setReference(e *exprpb.Expr, r *ast.ReferenceInfo) {
if old, found := c.references[e.GetId()]; found && !old.Equals(r) {
c.errors.referenceRedefinition(e.GetId(), c.location(e), e, old, r)
func (c *checker) setReference(e ast.Expr, r *ast.ReferenceInfo) {
if old, found := c.ReferenceMap()[e.ID()]; found && !old.Equals(r) {
c.errors.referenceRedefinition(e.ID(), c.location(e), e, old, r)
return
}
c.references[e.GetId()] = r
c.SetReference(e.ID(), r)
}
func (c *checker) assertType(e *exprpb.Expr, t *types.Type) {
func (c *checker) assertType(e ast.Expr, t *types.Type) {
if !c.isAssignable(t, c.getType(e)) {
c.errors.typeMismatch(e.GetId(), c.location(e), t, c.getType(e))
c.errors.typeMismatch(e.ID(), c.location(e), t, c.getType(e))
}
}
@ -683,26 +633,12 @@ func newResolution(r *ast.ReferenceInfo, t *types.Type) *overloadResolution {
}
}
func (c *checker) location(e *exprpb.Expr) common.Location {
return c.locationByID(e.GetId())
func (c *checker) location(e ast.Expr) common.Location {
return c.locationByID(e.ID())
}
func (c *checker) locationByID(id int64) common.Location {
positions := c.sourceInfo.GetPositions()
var line = 1
if offset, found := positions[id]; found {
col := int(offset)
for _, lineOffset := range c.sourceInfo.GetLineOffsets() {
if lineOffset < offset {
line++
col = int(offset - lineOffset)
} else {
break
}
}
return common.NewLocation(line, col)
}
return common.NoLocation
return c.SourceInfo().GetStartLocation(id)
}
func (c *checker) lookupFieldType(exprID int64, structType, fieldName string) (*types.Type, bool) {

View File

@ -22,8 +22,6 @@ import (
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// WARNING: Any changes to cost calculations in this file require a corresponding change in interpreter/runtimecost.go
@ -58,7 +56,7 @@ type AstNode interface {
// Type returns the deduced type of the AstNode.
Type() *types.Type
// Expr returns the expression of the AstNode.
Expr() *exprpb.Expr
Expr() ast.Expr
// ComputedSize returns a size estimate of the AstNode derived from information available in the CEL expression.
// For constants and inline list and map declarations, the exact size is returned. For concatenated list, strings
// and bytes, the size is derived from the size estimates of the operands. nil is returned if there is no
@ -69,7 +67,7 @@ type AstNode interface {
type astNode struct {
path []string
t *types.Type
expr *exprpb.Expr
expr ast.Expr
derivedSize *SizeEstimate
}
@ -81,7 +79,7 @@ func (e astNode) Type() *types.Type {
return e.t
}
func (e astNode) Expr() *exprpb.Expr {
func (e astNode) Expr() ast.Expr {
return e.expr
}
@ -90,29 +88,27 @@ func (e astNode) ComputedSize() *SizeEstimate {
return e.derivedSize
}
var v uint64
switch ek := e.expr.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
switch ck := ek.ConstExpr.GetConstantKind().(type) {
case *exprpb.Constant_StringValue:
switch e.expr.Kind() {
case ast.LiteralKind:
switch ck := e.expr.AsLiteral().(type) {
case types.String:
// converting to runes here is an O(n) operation, but
// this is consistent with how size is computed at runtime,
// and how the language definition defines string size
v = uint64(len([]rune(ck.StringValue)))
case *exprpb.Constant_BytesValue:
v = uint64(len(ck.BytesValue))
case *exprpb.Constant_BoolValue, *exprpb.Constant_DoubleValue, *exprpb.Constant_DurationValue,
*exprpb.Constant_Int64Value, *exprpb.Constant_TimestampValue, *exprpb.Constant_Uint64Value,
*exprpb.Constant_NullValue:
v = uint64(len([]rune(ck)))
case types.Bytes:
v = uint64(len(ck))
case types.Bool, types.Double, types.Duration,
types.Int, types.Timestamp, types.Uint,
types.Null:
v = uint64(1)
default:
return nil
}
case *exprpb.Expr_ListExpr:
v = uint64(len(ek.ListExpr.GetElements()))
case *exprpb.Expr_StructExpr:
if ek.StructExpr.GetMessageName() == "" {
v = uint64(len(ek.StructExpr.GetEntries()))
}
case ast.ListKind:
v = uint64(e.expr.AsList().Size())
case ast.MapKind:
v = uint64(e.expr.AsMap().Size())
default:
return nil
}
@ -265,7 +261,7 @@ type coster struct {
iterRanges iterRangeScopes
// computedSizes tracks the computed sizes of call results.
computedSizes map[int64]SizeEstimate
checkedAST *ast.CheckedAST
checkedAST *ast.AST
estimator CostEstimator
overloadEstimators map[string]FunctionEstimator
// presenceTestCost will either be a zero or one based on whether has() macros count against cost computations.
@ -275,8 +271,8 @@ type coster struct {
// Use a stack of iterVar -> iterRange Expr Ids to handle shadowed variable names.
type iterRangeScopes map[string][]int64
func (vs iterRangeScopes) push(varName string, expr *exprpb.Expr) {
vs[varName] = append(vs[varName], expr.GetId())
func (vs iterRangeScopes) push(varName string, expr ast.Expr) {
vs[varName] = append(vs[varName], expr.ID())
}
func (vs iterRangeScopes) pop(varName string) {
@ -324,9 +320,9 @@ func OverloadCostEstimate(overloadID string, functionCoster FunctionEstimator) C
}
// Cost estimates the cost of the parsed and type checked CEL expression.
func Cost(checker *ast.CheckedAST, estimator CostEstimator, opts ...CostOption) (CostEstimate, error) {
func Cost(checked *ast.AST, estimator CostEstimator, opts ...CostOption) (CostEstimate, error) {
c := &coster{
checkedAST: checker,
checkedAST: checked,
estimator: estimator,
overloadEstimators: map[string]FunctionEstimator{},
exprPath: map[int64][]string{},
@ -340,28 +336,30 @@ func Cost(checker *ast.CheckedAST, estimator CostEstimator, opts ...CostOption)
return CostEstimate{}, err
}
}
return c.cost(checker.Expr), nil
return c.cost(checked.Expr()), nil
}
func (c *coster) cost(e *exprpb.Expr) CostEstimate {
func (c *coster) cost(e ast.Expr) CostEstimate {
if e == nil {
return CostEstimate{}
}
var cost CostEstimate
switch e.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
switch e.Kind() {
case ast.LiteralKind:
cost = constCost
case *exprpb.Expr_IdentExpr:
case ast.IdentKind:
cost = c.costIdent(e)
case *exprpb.Expr_SelectExpr:
case ast.SelectKind:
cost = c.costSelect(e)
case *exprpb.Expr_CallExpr:
case ast.CallKind:
cost = c.costCall(e)
case *exprpb.Expr_ListExpr:
case ast.ListKind:
cost = c.costCreateList(e)
case *exprpb.Expr_StructExpr:
case ast.MapKind:
cost = c.costCreateMap(e)
case ast.StructKind:
cost = c.costCreateStruct(e)
case *exprpb.Expr_ComprehensionExpr:
case ast.ComprehensionKind:
cost = c.costComprehension(e)
default:
return CostEstimate{}
@ -369,53 +367,51 @@ func (c *coster) cost(e *exprpb.Expr) CostEstimate {
return cost
}
func (c *coster) costIdent(e *exprpb.Expr) CostEstimate {
identExpr := e.GetIdentExpr()
func (c *coster) costIdent(e ast.Expr) CostEstimate {
identName := e.AsIdent()
// build and track the field path
if iterRange, ok := c.iterRanges.peek(identExpr.GetName()); ok {
switch c.checkedAST.TypeMap[iterRange].Kind() {
if iterRange, ok := c.iterRanges.peek(identName); ok {
switch c.checkedAST.GetType(iterRange).Kind() {
case types.ListKind:
c.addPath(e, append(c.exprPath[iterRange], "@items"))
case types.MapKind:
c.addPath(e, append(c.exprPath[iterRange], "@keys"))
}
} else {
c.addPath(e, []string{identExpr.GetName()})
c.addPath(e, []string{identName})
}
return selectAndIdentCost
}
func (c *coster) costSelect(e *exprpb.Expr) CostEstimate {
sel := e.GetSelectExpr()
func (c *coster) costSelect(e ast.Expr) CostEstimate {
sel := e.AsSelect()
var sum CostEstimate
if sel.GetTestOnly() {
if sel.IsTestOnly() {
// recurse, but do not add any cost
// this is equivalent to how evalTestOnly increments the runtime cost counter
// but does not add any additional cost for the qualifier, except here we do
// the reverse (ident adds cost)
sum = sum.Add(c.presenceTestCost)
sum = sum.Add(c.cost(sel.GetOperand()))
sum = sum.Add(c.cost(sel.Operand()))
return sum
}
sum = sum.Add(c.cost(sel.GetOperand()))
targetType := c.getType(sel.GetOperand())
sum = sum.Add(c.cost(sel.Operand()))
targetType := c.getType(sel.Operand())
switch targetType.Kind() {
case types.MapKind, types.StructKind, types.TypeParamKind:
sum = sum.Add(selectAndIdentCost)
}
// build and track the field path
c.addPath(e, append(c.getPath(sel.GetOperand()), sel.GetField()))
c.addPath(e, append(c.getPath(sel.Operand()), sel.FieldName()))
return sum
}
func (c *coster) costCall(e *exprpb.Expr) CostEstimate {
call := e.GetCallExpr()
target := call.GetTarget()
args := call.GetArgs()
func (c *coster) costCall(e ast.Expr) CostEstimate {
call := e.AsCall()
args := call.Args()
var sum CostEstimate
@ -426,22 +422,20 @@ func (c *coster) costCall(e *exprpb.Expr) CostEstimate {
argTypes[i] = c.newAstNode(arg)
}
ref := c.checkedAST.ReferenceMap[e.GetId()]
if ref == nil || len(ref.OverloadIDs) == 0 {
overloadIDs := c.checkedAST.GetOverloadIDs(e.ID())
if len(overloadIDs) == 0 {
return CostEstimate{}
}
var targetType AstNode
if target != nil {
if call.Target != nil {
sum = sum.Add(c.cost(call.GetTarget()))
targetType = c.newAstNode(call.GetTarget())
}
if call.IsMemberFunction() {
sum = sum.Add(c.cost(call.Target()))
targetType = c.newAstNode(call.Target())
}
// Pick a cost estimate range that covers all the overload cost estimation ranges
fnCost := CostEstimate{Min: uint64(math.MaxUint64), Max: 0}
var resultSize *SizeEstimate
for _, overload := range ref.OverloadIDs {
overloadCost := c.functionCost(call.GetFunction(), overload, &targetType, argTypes, argCosts)
for _, overload := range overloadIDs {
overloadCost := c.functionCost(call.FunctionName(), overload, &targetType, argTypes, argCosts)
fnCost = fnCost.Union(overloadCost.CostEstimate)
if overloadCost.ResultSize != nil {
if resultSize == nil {
@ -464,64 +458,56 @@ func (c *coster) costCall(e *exprpb.Expr) CostEstimate {
}
}
if resultSize != nil {
c.computedSizes[e.GetId()] = *resultSize
c.computedSizes[e.ID()] = *resultSize
}
return sum.Add(fnCost)
}
func (c *coster) costCreateList(e *exprpb.Expr) CostEstimate {
create := e.GetListExpr()
func (c *coster) costCreateList(e ast.Expr) CostEstimate {
create := e.AsList()
var sum CostEstimate
for _, e := range create.GetElements() {
for _, e := range create.Elements() {
sum = sum.Add(c.cost(e))
}
return sum.Add(createListBaseCost)
}
func (c *coster) costCreateStruct(e *exprpb.Expr) CostEstimate {
str := e.GetStructExpr()
if str.MessageName != "" {
return c.costCreateMessage(e)
}
return c.costCreateMap(e)
}
func (c *coster) costCreateMap(e *exprpb.Expr) CostEstimate {
mapVal := e.GetStructExpr()
func (c *coster) costCreateMap(e ast.Expr) CostEstimate {
mapVal := e.AsMap()
var sum CostEstimate
for _, ent := range mapVal.GetEntries() {
key := ent.GetMapKey()
sum = sum.Add(c.cost(key))
sum = sum.Add(c.cost(ent.GetValue()))
for _, ent := range mapVal.Entries() {
entry := ent.AsMapEntry()
sum = sum.Add(c.cost(entry.Key()))
sum = sum.Add(c.cost(entry.Value()))
}
return sum.Add(createMapBaseCost)
}
func (c *coster) costCreateMessage(e *exprpb.Expr) CostEstimate {
msgVal := e.GetStructExpr()
func (c *coster) costCreateStruct(e ast.Expr) CostEstimate {
msgVal := e.AsStruct()
var sum CostEstimate
for _, ent := range msgVal.GetEntries() {
sum = sum.Add(c.cost(ent.GetValue()))
for _, ent := range msgVal.Fields() {
field := ent.AsStructField()
sum = sum.Add(c.cost(field.Value()))
}
return sum.Add(createMessageBaseCost)
}
func (c *coster) costComprehension(e *exprpb.Expr) CostEstimate {
comp := e.GetComprehensionExpr()
func (c *coster) costComprehension(e ast.Expr) CostEstimate {
comp := e.AsComprehension()
var sum CostEstimate
sum = sum.Add(c.cost(comp.GetIterRange()))
sum = sum.Add(c.cost(comp.GetAccuInit()))
sum = sum.Add(c.cost(comp.IterRange()))
sum = sum.Add(c.cost(comp.AccuInit()))
// Track the iterRange of each IterVar for field path construction
c.iterRanges.push(comp.GetIterVar(), comp.GetIterRange())
loopCost := c.cost(comp.GetLoopCondition())
stepCost := c.cost(comp.GetLoopStep())
c.iterRanges.pop(comp.GetIterVar())
sum = sum.Add(c.cost(comp.Result))
rangeCnt := c.sizeEstimate(c.newAstNode(comp.GetIterRange()))
c.iterRanges.push(comp.IterVar(), comp.IterRange())
loopCost := c.cost(comp.LoopCondition())
stepCost := c.cost(comp.LoopStep())
c.iterRanges.pop(comp.IterVar())
sum = sum.Add(c.cost(comp.Result()))
rangeCnt := c.sizeEstimate(c.newAstNode(comp.IterRange()))
c.computedSizes[e.GetId()] = rangeCnt
c.computedSizes[e.ID()] = rangeCnt
rangeCost := rangeCnt.MultiplyByCost(stepCost.Add(loopCost))
sum = sum.Add(rangeCost)
@ -674,26 +660,26 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args
return CallEstimate{CostEstimate: CostEstimate{Min: 1, Max: 1}.Add(argCostSum())}
}
func (c *coster) getType(e *exprpb.Expr) *types.Type {
return c.checkedAST.TypeMap[e.GetId()]
func (c *coster) getType(e ast.Expr) *types.Type {
return c.checkedAST.GetType(e.ID())
}
func (c *coster) getPath(e *exprpb.Expr) []string {
return c.exprPath[e.GetId()]
func (c *coster) getPath(e ast.Expr) []string {
return c.exprPath[e.ID()]
}
func (c *coster) addPath(e *exprpb.Expr, path []string) {
c.exprPath[e.GetId()] = path
func (c *coster) addPath(e ast.Expr, path []string) {
c.exprPath[e.ID()] = path
}
func (c *coster) newAstNode(e *exprpb.Expr) *astNode {
func (c *coster) newAstNode(e ast.Expr) *astNode {
path := c.getPath(e)
if len(path) > 0 && path[0] == parser.AccumulatorName {
// only provide paths to root vars; omit accumulator vars
path = nil
}
var derivedSize *SizeEstimate
if size, ok := c.computedSizes[e.GetId()]; ok {
if size, ok := c.computedSizes[e.ID()]; ok {
derivedSize = &size
}
return &astNode{

View File

@ -67,7 +67,7 @@ func NewAbstractType(name string, paramTypes ...*exprpb.Type) *exprpb.Type {
// NewOptionalType constructs an abstract type indicating that the parameterized type
// may be contained within the object.
func NewOptionalType(paramType *exprpb.Type) *exprpb.Type {
return NewAbstractType("optional", paramType)
return NewAbstractType("optional_type", paramType)
}
// NewFunctionType creates a function invocation contract, typically only used

View File

@ -146,6 +146,14 @@ func (e *Env) LookupIdent(name string) *decls.VariableDecl {
return decl
}
if i, found := e.provider.FindIdent(candidate); found {
if t, ok := i.(*types.Type); ok {
decl := decls.NewVariable(candidate, types.NewTypeTypeWithParam(t))
e.declarations.AddIdent(decl)
return decl
}
}
// Next try to import this as an enum value by splitting the name in a type prefix and
// the enum inside.
if enumValue := e.provider.EnumValue(candidate); enumValue.Type() != types.ErrType {

View File

@ -15,13 +15,9 @@
package checker
import (
"reflect"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// typeErrors is a specialization of Errors.
@ -34,9 +30,9 @@ func (e *typeErrors) fieldTypeMismatch(id int64, l common.Location, name string,
name, FormatCELType(field), FormatCELType(value))
}
func (e *typeErrors) incompatibleType(id int64, l common.Location, ex *exprpb.Expr, prev, next *types.Type) {
func (e *typeErrors) incompatibleType(id int64, l common.Location, ex ast.Expr, prev, next *types.Type) {
e.errs.ReportErrorAtID(id, l,
"incompatible type already exists for expression: %v(%d) old:%v, new:%v", ex, ex.GetId(), prev, next)
"incompatible type already exists for expression: %v(%d) old:%v, new:%v", ex, ex.ID(), prev, next)
}
func (e *typeErrors) noMatchingOverload(id int64, l common.Location, name string, args []*types.Type, isInstance bool) {
@ -49,7 +45,7 @@ func (e *typeErrors) notAComprehensionRange(id int64, l common.Location, t *type
FormatCELType(t))
}
func (e *typeErrors) notAnOptionalFieldSelection(id int64, l common.Location, field *exprpb.Expr) {
func (e *typeErrors) notAnOptionalFieldSelection(id int64, l common.Location, field ast.Expr) {
e.errs.ReportErrorAtID(id, l, "unsupported optional field selection: %v", field)
}
@ -61,9 +57,9 @@ func (e *typeErrors) notAMessageType(id int64, l common.Location, typeName strin
e.errs.ReportErrorAtID(id, l, "'%s' is not a message type", typeName)
}
func (e *typeErrors) referenceRedefinition(id int64, l common.Location, ex *exprpb.Expr, prev, next *ast.ReferenceInfo) {
func (e *typeErrors) referenceRedefinition(id int64, l common.Location, ex ast.Expr, prev, next *ast.ReferenceInfo) {
e.errs.ReportErrorAtID(id, l,
"reference already exists for expression: %v(%d) old:%v, new:%v", ex, ex.GetId(), prev, next)
"reference already exists for expression: %v(%d) old:%v, new:%v", ex, ex.ID(), prev, next)
}
func (e *typeErrors) typeDoesNotSupportFieldSelection(id int64, l common.Location, t *types.Type) {
@ -87,6 +83,6 @@ func (e *typeErrors) unexpectedFailedResolution(id int64, l common.Location, typ
e.errs.ReportErrorAtID(id, l, "unexpected failed resolution of '%s'", typeName)
}
func (e *typeErrors) unexpectedASTType(id int64, l common.Location, ex *exprpb.Expr) {
e.errs.ReportErrorAtID(id, l, "unrecognized ast type: %v", reflect.TypeOf(ex))
func (e *typeErrors) unexpectedASTType(id int64, l common.Location, kind, typeName string) {
e.errs.ReportErrorAtID(id, l, "unexpected %s type: %v", kind, typeName)
}

View File

@ -17,40 +17,40 @@ package checker
import (
"sort"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/debug"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
type semanticAdorner struct {
checks *exprpb.CheckedExpr
checked *ast.AST
}
var _ debug.Adorner = &semanticAdorner{}
func (a *semanticAdorner) GetMetadata(elem any) string {
result := ""
e, isExpr := elem.(*exprpb.Expr)
e, isExpr := elem.(ast.Expr)
if !isExpr {
return result
}
t := a.checks.TypeMap[e.GetId()]
t := a.checked.TypeMap()[e.ID()]
if t != nil {
result += "~"
result += FormatCheckedType(t)
result += FormatCELType(t)
}
switch e.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr,
*exprpb.Expr_CallExpr,
*exprpb.Expr_StructExpr,
*exprpb.Expr_SelectExpr:
if ref, found := a.checks.ReferenceMap[e.GetId()]; found {
if len(ref.GetOverloadId()) == 0 {
switch e.Kind() {
case ast.IdentKind,
ast.CallKind,
ast.ListKind,
ast.StructKind,
ast.SelectKind:
if ref, found := a.checked.ReferenceMap()[e.ID()]; found {
if len(ref.OverloadIDs) == 0 {
result += "^" + ref.Name
} else {
sort.Strings(ref.GetOverloadId())
for i, overload := range ref.GetOverloadId() {
sort.Strings(ref.OverloadIDs)
for i, overload := range ref.OverloadIDs {
if i == 0 {
result += "^"
} else {
@ -68,7 +68,7 @@ func (a *semanticAdorner) GetMetadata(elem any) string {
// Print returns a string representation of the Expr message,
// annotated with types from the CheckedExpr. The Expr must
// be a sub-expression embedded in the CheckedExpr.
func Print(e *exprpb.Expr, checks *exprpb.CheckedExpr) string {
a := &semanticAdorner{checks: checks}
func Print(e ast.Expr, checked *ast.AST) string {
a := &semanticAdorner{checked: checked}
return debug.ToAdornedDebugString(e, a)
}

View File

@ -41,7 +41,7 @@ func isError(t *types.Type) bool {
func isOptional(t *types.Type) bool {
if t.Kind() == types.OpaqueKind {
return t.TypeName() == "optional"
return t.TypeName() == "optional_type"
}
return false
}
@ -137,7 +137,11 @@ func internalIsAssignable(m *mapping, t1, t2 *types.Type) bool {
case types.BoolKind, types.BytesKind, types.DoubleKind, types.IntKind, types.StringKind, types.UintKind,
types.AnyKind, types.DurationKind, types.TimestampKind,
types.StructKind:
return t1.IsAssignableType(t2)
// Test whether t2 is assignable from t1. The order of this check won't usually matter;
// however, there may be cases where type capabilities are expanded beyond what is supported
// in the current common/types package. For example, an interface designation for a group of
// Struct types.
return t2.IsAssignableType(t1)
case types.TypeKind:
return kind2 == types.TypeKind
case types.OpaqueKind, types.ListKind, types.MapKind:
@ -256,7 +260,7 @@ func notReferencedIn(m *mapping, t, withinType *types.Type) bool {
return true
}
return notReferencedIn(m, t, wtSub)
case types.OpaqueKind, types.ListKind, types.MapKind:
case types.OpaqueKind, types.ListKind, types.MapKind, types.TypeKind:
for _, pt := range withinType.Parameters() {
if !notReferencedIn(m, t, pt) {
return false
@ -288,7 +292,8 @@ func substitute(m *mapping, t *types.Type, typeParamToDyn bool) *types.Type {
substitute(m, t.Parameters()[1], typeParamToDyn))
case types.TypeKind:
if len(t.Parameters()) > 0 {
return types.NewTypeTypeWithParam(substitute(m, t.Parameters()[0], typeParamToDyn))
tParam := t.Parameters()[0]
return types.NewTypeTypeWithParam(substitute(m, tParam, typeParamToDyn))
}
return t
default:

View File

@ -1,12 +1,7 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
package(
default_visibility = [
"//cel:__subpackages__",
"//checker:__subpackages__",
"//common:__subpackages__",
"//interpreter:__subpackages__",
],
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
@ -14,10 +9,14 @@ go_library(
name = "go_default_library",
srcs = [
"ast.go",
"conversion.go",
"expr.go",
"factory.go",
"navigable.go",
],
importpath = "github.com/google/cel-go/common/ast",
deps = [
"//common:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
@ -29,7 +28,9 @@ go_test(
name = "go_default_test",
srcs = [
"ast_test.go",
"conversion_test.go",
"expr_test.go",
"navigable_test.go",
],
embed = [
":go_default_library",
@ -48,5 +49,6 @@ go_test(
"//test/proto3pb:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//encoding/prototext:go_default_library",
],
)
)

View File

@ -16,74 +16,361 @@
package ast
import (
"fmt"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
structpb "google.golang.org/protobuf/types/known/structpb"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// CheckedAST contains a protobuf expression and source info along with CEL-native type and reference information.
type CheckedAST struct {
Expr *exprpb.Expr
SourceInfo *exprpb.SourceInfo
TypeMap map[int64]*types.Type
ReferenceMap map[int64]*ReferenceInfo
// AST contains a protobuf expression and source info along with CEL-native type and reference information.
type AST struct {
expr Expr
sourceInfo *SourceInfo
typeMap map[int64]*types.Type
refMap map[int64]*ReferenceInfo
}
// CheckedASTToCheckedExpr converts a CheckedAST to a CheckedExpr protobouf.
func CheckedASTToCheckedExpr(ast *CheckedAST) (*exprpb.CheckedExpr, error) {
refMap := make(map[int64]*exprpb.Reference, len(ast.ReferenceMap))
for id, ref := range ast.ReferenceMap {
r, err := ReferenceInfoToReferenceExpr(ref)
if err != nil {
return nil, err
}
refMap[id] = r
// Expr returns the root ast.Expr value in the AST.
func (a *AST) Expr() Expr {
if a == nil {
return nilExpr
}
typeMap := make(map[int64]*exprpb.Type, len(ast.TypeMap))
for id, typ := range ast.TypeMap {
t, err := types.TypeToExprType(typ)
if err != nil {
return nil, err
}
typeMap[id] = t
}
return &exprpb.CheckedExpr{
Expr: ast.Expr,
SourceInfo: ast.SourceInfo,
ReferenceMap: refMap,
TypeMap: typeMap,
}, nil
return a.expr
}
// CheckedExprToCheckedAST converts a CheckedExpr protobuf to a CheckedAST instance.
func CheckedExprToCheckedAST(checked *exprpb.CheckedExpr) (*CheckedAST, error) {
refMap := make(map[int64]*ReferenceInfo, len(checked.GetReferenceMap()))
for id, ref := range checked.GetReferenceMap() {
r, err := ReferenceExprToReferenceInfo(ref)
if err != nil {
return nil, err
}
refMap[id] = r
// SourceInfo returns the source metadata associated with the parse / type-check passes.
func (a *AST) SourceInfo() *SourceInfo {
if a == nil {
return nil
}
typeMap := make(map[int64]*types.Type, len(checked.GetTypeMap()))
for id, typ := range checked.GetTypeMap() {
t, err := types.ExprTypeToType(typ)
if err != nil {
return nil, err
}
typeMap[id] = t
return a.sourceInfo
}
// GetType returns the type for the expression at the given id, if one exists, else types.DynType.
func (a *AST) GetType(id int64) *types.Type {
if t, found := a.TypeMap()[id]; found {
return t
}
return &CheckedAST{
Expr: checked.GetExpr(),
SourceInfo: checked.GetSourceInfo(),
ReferenceMap: refMap,
TypeMap: typeMap,
}, nil
return types.DynType
}
// SetType sets the type of the expression node at the given id.
func (a *AST) SetType(id int64, t *types.Type) {
if a == nil {
return
}
a.typeMap[id] = t
}
// TypeMap returns the map of expression ids to type-checked types.
//
// If the AST is not type-checked, the map will be empty.
func (a *AST) TypeMap() map[int64]*types.Type {
if a == nil {
return map[int64]*types.Type{}
}
return a.typeMap
}
// GetOverloadIDs returns the set of overload function names for a given expression id.
//
// If the expression id is not a function call, or the AST is not type-checked, the result will be empty.
func (a *AST) GetOverloadIDs(id int64) []string {
if ref, found := a.ReferenceMap()[id]; found {
return ref.OverloadIDs
}
return []string{}
}
// ReferenceMap returns the map of expression id to identifier, constant, and function references.
func (a *AST) ReferenceMap() map[int64]*ReferenceInfo {
if a == nil {
return map[int64]*ReferenceInfo{}
}
return a.refMap
}
// SetReference adds a reference to the checked AST type map.
func (a *AST) SetReference(id int64, r *ReferenceInfo) {
if a == nil {
return
}
a.refMap[id] = r
}
// IsChecked returns whether the AST is type-checked.
func (a *AST) IsChecked() bool {
return a != nil && len(a.TypeMap()) > 0
}
// NewAST creates a base AST instance with an ast.Expr and ast.SourceInfo value.
func NewAST(e Expr, sourceInfo *SourceInfo) *AST {
if e == nil {
e = nilExpr
}
return &AST{
expr: e,
sourceInfo: sourceInfo,
typeMap: make(map[int64]*types.Type),
refMap: make(map[int64]*ReferenceInfo),
}
}
// NewCheckedAST wraps an parsed AST and augments it with type and reference metadata.
func NewCheckedAST(parsed *AST, typeMap map[int64]*types.Type, refMap map[int64]*ReferenceInfo) *AST {
return &AST{
expr: parsed.Expr(),
sourceInfo: parsed.SourceInfo(),
typeMap: typeMap,
refMap: refMap,
}
}
// Copy creates a deep copy of the Expr and SourceInfo values in the input AST.
//
// Copies of the Expr value are generated using an internal default ExprFactory.
func Copy(a *AST) *AST {
if a == nil {
return nil
}
e := defaultFactory.CopyExpr(a.expr)
if !a.IsChecked() {
return NewAST(e, CopySourceInfo(a.SourceInfo()))
}
typesCopy := make(map[int64]*types.Type, len(a.typeMap))
for id, t := range a.typeMap {
typesCopy[id] = t
}
refsCopy := make(map[int64]*ReferenceInfo, len(a.refMap))
for id, r := range a.refMap {
refsCopy[id] = r
}
return NewCheckedAST(NewAST(e, CopySourceInfo(a.SourceInfo())), typesCopy, refsCopy)
}
// MaxID returns the upper-bound, non-inclusive, of ids present within the AST's Expr value.
func MaxID(a *AST) int64 {
visitor := &maxIDVisitor{maxID: 1}
PostOrderVisit(a.Expr(), visitor)
for id, call := range a.SourceInfo().MacroCalls() {
PostOrderVisit(call, visitor)
if id > visitor.maxID {
visitor.maxID = id + 1
}
}
return visitor.maxID + 1
}
// NewSourceInfo creates a simple SourceInfo object from an input common.Source value.
func NewSourceInfo(src common.Source) *SourceInfo {
var lineOffsets []int32
var desc string
baseLine := int32(0)
baseCol := int32(0)
if src != nil {
desc = src.Description()
lineOffsets = src.LineOffsets()
// Determine whether the source metadata should be computed relative
// to a base line and column value. This can be determined by requesting
// the location for offset 0 from the source object.
if loc, found := src.OffsetLocation(0); found {
baseLine = int32(loc.Line()) - 1
baseCol = int32(loc.Column())
}
}
return &SourceInfo{
desc: desc,
lines: lineOffsets,
baseLine: baseLine,
baseCol: baseCol,
offsetRanges: make(map[int64]OffsetRange),
macroCalls: make(map[int64]Expr),
}
}
// CopySourceInfo creates a deep copy of the MacroCalls within the input SourceInfo.
//
// Copies of macro Expr values are generated using an internal default ExprFactory.
func CopySourceInfo(info *SourceInfo) *SourceInfo {
if info == nil {
return nil
}
rangesCopy := make(map[int64]OffsetRange, len(info.offsetRanges))
for id, off := range info.offsetRanges {
rangesCopy[id] = off
}
callsCopy := make(map[int64]Expr, len(info.macroCalls))
for id, call := range info.macroCalls {
callsCopy[id] = defaultFactory.CopyExpr(call)
}
return &SourceInfo{
syntax: info.syntax,
desc: info.desc,
lines: info.lines,
baseLine: info.baseLine,
baseCol: info.baseCol,
offsetRanges: rangesCopy,
macroCalls: callsCopy,
}
}
// SourceInfo records basic information about the expression as a textual input and
// as a parsed expression value.
type SourceInfo struct {
syntax string
desc string
lines []int32
baseLine int32
baseCol int32
offsetRanges map[int64]OffsetRange
macroCalls map[int64]Expr
}
// SyntaxVersion returns the syntax version associated with the text expression.
func (s *SourceInfo) SyntaxVersion() string {
if s == nil {
return ""
}
return s.syntax
}
// Description provides information about where the expression came from.
func (s *SourceInfo) Description() string {
if s == nil {
return ""
}
return s.desc
}
// LineOffsets returns a list of the 0-based character offsets in the input text where newlines appear.
func (s *SourceInfo) LineOffsets() []int32 {
if s == nil {
return []int32{}
}
return s.lines
}
// MacroCalls returns a map of expression id to ast.Expr value where the id represents the expression
// node where the macro was inserted into the AST, and the ast.Expr value represents the original call
// signature which was replaced.
func (s *SourceInfo) MacroCalls() map[int64]Expr {
if s == nil {
return map[int64]Expr{}
}
return s.macroCalls
}
// GetMacroCall returns the original ast.Expr value for the given expression if it was generated via
// a macro replacement.
//
// Note, parsing options must be enabled to track macro calls before this method will return a value.
func (s *SourceInfo) GetMacroCall(id int64) (Expr, bool) {
e, found := s.MacroCalls()[id]
return e, found
}
// SetMacroCall records a macro call at a specific location.
func (s *SourceInfo) SetMacroCall(id int64, e Expr) {
if s != nil {
s.macroCalls[id] = e
}
}
// ClearMacroCall removes the macro call at the given expression id.
func (s *SourceInfo) ClearMacroCall(id int64) {
if s != nil {
delete(s.macroCalls, id)
}
}
// OffsetRanges returns a map of expression id to OffsetRange values where the range indicates either:
// the start and end position in the input stream where the expression occurs, or the start position
// only. If the range only captures start position, the stop position of the range will be equal to
// the start.
func (s *SourceInfo) OffsetRanges() map[int64]OffsetRange {
if s == nil {
return map[int64]OffsetRange{}
}
return s.offsetRanges
}
// GetOffsetRange retrieves an OffsetRange for the given expression id if one exists.
func (s *SourceInfo) GetOffsetRange(id int64) (OffsetRange, bool) {
if s == nil {
return OffsetRange{}, false
}
o, found := s.offsetRanges[id]
return o, found
}
// SetOffsetRange sets the OffsetRange for the given expression id.
func (s *SourceInfo) SetOffsetRange(id int64, o OffsetRange) {
if s == nil {
return
}
s.offsetRanges[id] = o
}
// GetStartLocation calculates the human-readable 1-based line and 0-based column of the first character
// of the expression node at the id.
func (s *SourceInfo) GetStartLocation(id int64) common.Location {
if o, found := s.GetOffsetRange(id); found {
line := 1
col := int(o.Start)
for _, lineOffset := range s.LineOffsets() {
if lineOffset < o.Start {
line++
col = int(o.Start - lineOffset)
} else {
break
}
}
return common.NewLocation(line, col)
}
return common.NoLocation
}
// GetStopLocation calculates the human-readable 1-based line and 0-based column of the last character for
// the expression node at the given id.
//
// If the SourceInfo was generated from a serialized protobuf representation, the stop location will
// be identical to the start location for the expression.
func (s *SourceInfo) GetStopLocation(id int64) common.Location {
if o, found := s.GetOffsetRange(id); found {
line := 1
col := int(o.Stop)
for _, lineOffset := range s.LineOffsets() {
if lineOffset < o.Stop {
line++
col = int(o.Stop - lineOffset)
} else {
break
}
}
return common.NewLocation(line, col)
}
return common.NoLocation
}
// ComputeOffset calculates the 0-based character offset from a 1-based line and 0-based column.
func (s *SourceInfo) ComputeOffset(line, col int32) int32 {
if s != nil {
line = s.baseLine + line
col = s.baseCol + col
}
if line == 1 {
return col
}
if line < 1 || line > int32(len(s.LineOffsets())) {
return -1
}
offset := s.LineOffsets()[line-2]
return offset + col
}
// OffsetRange captures the start and stop positions of a section of text in the input expression.
type OffsetRange struct {
Start int32
Stop int32
}
// ReferenceInfo contains a CEL native representation of an identifier reference which may refer to
@ -149,78 +436,21 @@ func (r *ReferenceInfo) Equals(other *ReferenceInfo) bool {
return true
}
// ReferenceInfoToReferenceExpr converts a ReferenceInfo instance to a protobuf Reference suitable for serialization.
func ReferenceInfoToReferenceExpr(info *ReferenceInfo) (*exprpb.Reference, error) {
c, err := ValToConstant(info.Value)
if err != nil {
return nil, err
}
return &exprpb.Reference{
Name: info.Name,
OverloadId: info.OverloadIDs,
Value: c,
}, nil
type maxIDVisitor struct {
maxID int64
*baseVisitor
}
// ReferenceExprToReferenceInfo converts a protobuf Reference into a CEL-native ReferenceInfo instance.
func ReferenceExprToReferenceInfo(ref *exprpb.Reference) (*ReferenceInfo, error) {
v, err := ConstantToVal(ref.GetValue())
if err != nil {
return nil, err
// VisitExpr updates the max identifier if the incoming expression id is greater than previously observed.
func (v *maxIDVisitor) VisitExpr(e Expr) {
if v.maxID < e.ID() {
v.maxID = e.ID()
}
return &ReferenceInfo{
Name: ref.GetName(),
OverloadIDs: ref.GetOverloadId(),
Value: v,
}, nil
}
// ValToConstant converts a CEL-native ref.Val to a protobuf Constant.
//
// Only simple scalar types are supported by this method.
func ValToConstant(v ref.Val) (*exprpb.Constant, error) {
if v == nil {
return nil, nil
// VisitEntryExpr updates the max identifier if the incoming entry id is greater than previously observed.
func (v *maxIDVisitor) VisitEntryExpr(e EntryExpr) {
if v.maxID < e.ID() {
v.maxID = e.ID()
}
switch v.Type() {
case types.BoolType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: v.Value().(bool)}}, nil
case types.BytesType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: v.Value().([]byte)}}, nil
case types.DoubleType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: v.Value().(float64)}}, nil
case types.IntType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: v.Value().(int64)}}, nil
case types.NullType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: structpb.NullValue_NULL_VALUE}}, nil
case types.StringType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: v.Value().(string)}}, nil
case types.UintType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: v.Value().(uint64)}}, nil
}
return nil, fmt.Errorf("unsupported constant kind: %v", v.Type())
}
// ConstantToVal converts a protobuf Constant to a CEL-native ref.Val.
func ConstantToVal(c *exprpb.Constant) (ref.Val, error) {
if c == nil {
return nil, nil
}
switch c.GetConstantKind().(type) {
case *exprpb.Constant_BoolValue:
return types.Bool(c.GetBoolValue()), nil
case *exprpb.Constant_BytesValue:
return types.Bytes(c.GetBytesValue()), nil
case *exprpb.Constant_DoubleValue:
return types.Double(c.GetDoubleValue()), nil
case *exprpb.Constant_Int64Value:
return types.Int(c.GetInt64Value()), nil
case *exprpb.Constant_NullValue:
return types.NullValue, nil
case *exprpb.Constant_StringValue:
return types.String(c.GetStringValue()), nil
case *exprpb.Constant_Uint64Value:
return types.Uint(c.GetUint64Value()), nil
}
return nil, fmt.Errorf("unsupported constant kind: %v", c.GetConstantKind())
}

View File

@ -0,0 +1,632 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ast
import (
"fmt"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
structpb "google.golang.org/protobuf/types/known/structpb"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// ToProto converts an AST to a CheckedExpr protobouf.
func ToProto(ast *AST) (*exprpb.CheckedExpr, error) {
refMap := make(map[int64]*exprpb.Reference, len(ast.ReferenceMap()))
for id, ref := range ast.ReferenceMap() {
r, err := ReferenceInfoToProto(ref)
if err != nil {
return nil, err
}
refMap[id] = r
}
typeMap := make(map[int64]*exprpb.Type, len(ast.TypeMap()))
for id, typ := range ast.TypeMap() {
t, err := types.TypeToExprType(typ)
if err != nil {
return nil, err
}
typeMap[id] = t
}
e, err := ExprToProto(ast.Expr())
if err != nil {
return nil, err
}
info, err := SourceInfoToProto(ast.SourceInfo())
if err != nil {
return nil, err
}
return &exprpb.CheckedExpr{
Expr: e,
SourceInfo: info,
ReferenceMap: refMap,
TypeMap: typeMap,
}, nil
}
// ToAST converts a CheckedExpr protobuf to an AST instance.
func ToAST(checked *exprpb.CheckedExpr) (*AST, error) {
refMap := make(map[int64]*ReferenceInfo, len(checked.GetReferenceMap()))
for id, ref := range checked.GetReferenceMap() {
r, err := ProtoToReferenceInfo(ref)
if err != nil {
return nil, err
}
refMap[id] = r
}
typeMap := make(map[int64]*types.Type, len(checked.GetTypeMap()))
for id, typ := range checked.GetTypeMap() {
t, err := types.ExprTypeToType(typ)
if err != nil {
return nil, err
}
typeMap[id] = t
}
info, err := ProtoToSourceInfo(checked.GetSourceInfo())
if err != nil {
return nil, err
}
root, err := ProtoToExpr(checked.GetExpr())
if err != nil {
return nil, err
}
ast := NewCheckedAST(NewAST(root, info), typeMap, refMap)
return ast, nil
}
// ProtoToExpr converts a protobuf Expr value to an ast.Expr value.
func ProtoToExpr(e *exprpb.Expr) (Expr, error) {
factory := NewExprFactory()
return exprInternal(factory, e)
}
// ProtoToEntryExpr converts a protobuf struct/map entry to an ast.EntryExpr
func ProtoToEntryExpr(e *exprpb.Expr_CreateStruct_Entry) (EntryExpr, error) {
factory := NewExprFactory()
switch e.GetKeyKind().(type) {
case *exprpb.Expr_CreateStruct_Entry_FieldKey:
return exprStructField(factory, e.GetId(), e)
case *exprpb.Expr_CreateStruct_Entry_MapKey:
return exprMapEntry(factory, e.GetId(), e)
}
return nil, fmt.Errorf("unsupported expr entry kind: %v", e)
}
func exprInternal(factory ExprFactory, e *exprpb.Expr) (Expr, error) {
id := e.GetId()
switch e.GetExprKind().(type) {
case *exprpb.Expr_CallExpr:
return exprCall(factory, id, e.GetCallExpr())
case *exprpb.Expr_ComprehensionExpr:
return exprComprehension(factory, id, e.GetComprehensionExpr())
case *exprpb.Expr_ConstExpr:
return exprLiteral(factory, id, e.GetConstExpr())
case *exprpb.Expr_IdentExpr:
return exprIdent(factory, id, e.GetIdentExpr())
case *exprpb.Expr_ListExpr:
return exprList(factory, id, e.GetListExpr())
case *exprpb.Expr_SelectExpr:
return exprSelect(factory, id, e.GetSelectExpr())
case *exprpb.Expr_StructExpr:
s := e.GetStructExpr()
if s.GetMessageName() != "" {
return exprStruct(factory, id, s)
}
return exprMap(factory, id, s)
}
return factory.NewUnspecifiedExpr(id), nil
}
func exprCall(factory ExprFactory, id int64, call *exprpb.Expr_Call) (Expr, error) {
var err error
args := make([]Expr, len(call.GetArgs()))
for i, a := range call.GetArgs() {
args[i], err = exprInternal(factory, a)
if err != nil {
return nil, err
}
}
if call.GetTarget() == nil {
return factory.NewCall(id, call.GetFunction(), args...), nil
}
target, err := exprInternal(factory, call.GetTarget())
if err != nil {
return nil, err
}
return factory.NewMemberCall(id, call.GetFunction(), target, args...), nil
}
func exprComprehension(factory ExprFactory, id int64, comp *exprpb.Expr_Comprehension) (Expr, error) {
iterRange, err := exprInternal(factory, comp.GetIterRange())
if err != nil {
return nil, err
}
accuInit, err := exprInternal(factory, comp.GetAccuInit())
if err != nil {
return nil, err
}
loopCond, err := exprInternal(factory, comp.GetLoopCondition())
if err != nil {
return nil, err
}
loopStep, err := exprInternal(factory, comp.GetLoopStep())
if err != nil {
return nil, err
}
result, err := exprInternal(factory, comp.GetResult())
if err != nil {
return nil, err
}
return factory.NewComprehension(id,
iterRange,
comp.GetIterVar(),
comp.GetAccuVar(),
accuInit,
loopCond,
loopStep,
result), nil
}
func exprLiteral(factory ExprFactory, id int64, c *exprpb.Constant) (Expr, error) {
val, err := ConstantToVal(c)
if err != nil {
return nil, err
}
return factory.NewLiteral(id, val), nil
}
func exprIdent(factory ExprFactory, id int64, i *exprpb.Expr_Ident) (Expr, error) {
return factory.NewIdent(id, i.GetName()), nil
}
func exprList(factory ExprFactory, id int64, l *exprpb.Expr_CreateList) (Expr, error) {
elems := make([]Expr, len(l.GetElements()))
for i, e := range l.GetElements() {
elem, err := exprInternal(factory, e)
if err != nil {
return nil, err
}
elems[i] = elem
}
return factory.NewList(id, elems, l.GetOptionalIndices()), nil
}
func exprMap(factory ExprFactory, id int64, s *exprpb.Expr_CreateStruct) (Expr, error) {
entries := make([]EntryExpr, len(s.GetEntries()))
var err error
for i, entry := range s.GetEntries() {
entries[i], err = exprMapEntry(factory, entry.GetId(), entry)
if err != nil {
return nil, err
}
}
return factory.NewMap(id, entries), nil
}
func exprMapEntry(factory ExprFactory, id int64, e *exprpb.Expr_CreateStruct_Entry) (EntryExpr, error) {
k, err := exprInternal(factory, e.GetMapKey())
if err != nil {
return nil, err
}
v, err := exprInternal(factory, e.GetValue())
if err != nil {
return nil, err
}
return factory.NewMapEntry(id, k, v, e.GetOptionalEntry()), nil
}
func exprSelect(factory ExprFactory, id int64, s *exprpb.Expr_Select) (Expr, error) {
op, err := exprInternal(factory, s.GetOperand())
if err != nil {
return nil, err
}
if s.GetTestOnly() {
return factory.NewPresenceTest(id, op, s.GetField()), nil
}
return factory.NewSelect(id, op, s.GetField()), nil
}
func exprStruct(factory ExprFactory, id int64, s *exprpb.Expr_CreateStruct) (Expr, error) {
fields := make([]EntryExpr, len(s.GetEntries()))
var err error
for i, field := range s.GetEntries() {
fields[i], err = exprStructField(factory, field.GetId(), field)
if err != nil {
return nil, err
}
}
return factory.NewStruct(id, s.GetMessageName(), fields), nil
}
func exprStructField(factory ExprFactory, id int64, f *exprpb.Expr_CreateStruct_Entry) (EntryExpr, error) {
v, err := exprInternal(factory, f.GetValue())
if err != nil {
return nil, err
}
return factory.NewStructField(id, f.GetFieldKey(), v, f.GetOptionalEntry()), nil
}
// ExprToProto serializes an ast.Expr value to a protobuf Expr representation.
func ExprToProto(e Expr) (*exprpb.Expr, error) {
if e == nil {
return &exprpb.Expr{}, nil
}
switch e.Kind() {
case CallKind:
return protoCall(e.ID(), e.AsCall())
case ComprehensionKind:
return protoComprehension(e.ID(), e.AsComprehension())
case IdentKind:
return protoIdent(e.ID(), e.AsIdent())
case ListKind:
return protoList(e.ID(), e.AsList())
case LiteralKind:
return protoLiteral(e.ID(), e.AsLiteral())
case MapKind:
return protoMap(e.ID(), e.AsMap())
case SelectKind:
return protoSelect(e.ID(), e.AsSelect())
case StructKind:
return protoStruct(e.ID(), e.AsStruct())
case UnspecifiedExprKind:
// Handle the case where a macro reference may be getting translated.
// A nested macro 'pointer' is a non-zero expression id with no kind set.
if e.ID() != 0 {
return &exprpb.Expr{Id: e.ID()}, nil
}
return &exprpb.Expr{}, nil
}
return nil, fmt.Errorf("unsupported expr kind: %v", e)
}
// EntryExprToProto converts an ast.EntryExpr to a protobuf CreateStruct entry
func EntryExprToProto(e EntryExpr) (*exprpb.Expr_CreateStruct_Entry, error) {
switch e.Kind() {
case MapEntryKind:
return protoMapEntry(e.ID(), e.AsMapEntry())
case StructFieldKind:
return protoStructField(e.ID(), e.AsStructField())
case UnspecifiedEntryExprKind:
return &exprpb.Expr_CreateStruct_Entry{}, nil
}
return nil, fmt.Errorf("unsupported expr entry kind: %v", e)
}
func protoCall(id int64, call CallExpr) (*exprpb.Expr, error) {
var err error
var target *exprpb.Expr
if call.IsMemberFunction() {
target, err = ExprToProto(call.Target())
if err != nil {
return nil, err
}
}
callArgs := call.Args()
args := make([]*exprpb.Expr, len(callArgs))
for i, a := range callArgs {
args[i], err = ExprToProto(a)
if err != nil {
return nil, err
}
}
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_CallExpr{
CallExpr: &exprpb.Expr_Call{
Function: call.FunctionName(),
Target: target,
Args: args,
},
},
}, nil
}
func protoComprehension(id int64, comp ComprehensionExpr) (*exprpb.Expr, error) {
iterRange, err := ExprToProto(comp.IterRange())
if err != nil {
return nil, err
}
accuInit, err := ExprToProto(comp.AccuInit())
if err != nil {
return nil, err
}
loopCond, err := ExprToProto(comp.LoopCondition())
if err != nil {
return nil, err
}
loopStep, err := ExprToProto(comp.LoopStep())
if err != nil {
return nil, err
}
result, err := ExprToProto(comp.Result())
if err != nil {
return nil, err
}
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_ComprehensionExpr{
ComprehensionExpr: &exprpb.Expr_Comprehension{
IterVar: comp.IterVar(),
IterRange: iterRange,
AccuVar: comp.AccuVar(),
AccuInit: accuInit,
LoopCondition: loopCond,
LoopStep: loopStep,
Result: result,
},
},
}, nil
}
func protoIdent(id int64, name string) (*exprpb.Expr, error) {
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_IdentExpr{
IdentExpr: &exprpb.Expr_Ident{
Name: name,
},
},
}, nil
}
func protoList(id int64, list ListExpr) (*exprpb.Expr, error) {
var err error
elems := make([]*exprpb.Expr, list.Size())
for i, e := range list.Elements() {
elems[i], err = ExprToProto(e)
if err != nil {
return nil, err
}
}
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_ListExpr{
ListExpr: &exprpb.Expr_CreateList{
Elements: elems,
OptionalIndices: list.OptionalIndices(),
},
},
}, nil
}
func protoLiteral(id int64, val ref.Val) (*exprpb.Expr, error) {
c, err := ValToConstant(val)
if err != nil {
return nil, err
}
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_ConstExpr{
ConstExpr: c,
},
}, nil
}
func protoMap(id int64, m MapExpr) (*exprpb.Expr, error) {
entries := make([]*exprpb.Expr_CreateStruct_Entry, len(m.Entries()))
var err error
for i, e := range m.Entries() {
entries[i], err = EntryExprToProto(e)
if err != nil {
return nil, err
}
}
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_StructExpr{
StructExpr: &exprpb.Expr_CreateStruct{
Entries: entries,
},
},
}, nil
}
func protoMapEntry(id int64, e MapEntry) (*exprpb.Expr_CreateStruct_Entry, error) {
k, err := ExprToProto(e.Key())
if err != nil {
return nil, err
}
v, err := ExprToProto(e.Value())
if err != nil {
return nil, err
}
return &exprpb.Expr_CreateStruct_Entry{
Id: id,
KeyKind: &exprpb.Expr_CreateStruct_Entry_MapKey{
MapKey: k,
},
Value: v,
OptionalEntry: e.IsOptional(),
}, nil
}
func protoSelect(id int64, s SelectExpr) (*exprpb.Expr, error) {
op, err := ExprToProto(s.Operand())
if err != nil {
return nil, err
}
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_SelectExpr{
SelectExpr: &exprpb.Expr_Select{
Operand: op,
Field: s.FieldName(),
TestOnly: s.IsTestOnly(),
},
},
}, nil
}
func protoStruct(id int64, s StructExpr) (*exprpb.Expr, error) {
entries := make([]*exprpb.Expr_CreateStruct_Entry, len(s.Fields()))
var err error
for i, e := range s.Fields() {
entries[i], err = EntryExprToProto(e)
if err != nil {
return nil, err
}
}
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_StructExpr{
StructExpr: &exprpb.Expr_CreateStruct{
MessageName: s.TypeName(),
Entries: entries,
},
},
}, nil
}
func protoStructField(id int64, f StructField) (*exprpb.Expr_CreateStruct_Entry, error) {
v, err := ExprToProto(f.Value())
if err != nil {
return nil, err
}
return &exprpb.Expr_CreateStruct_Entry{
Id: id,
KeyKind: &exprpb.Expr_CreateStruct_Entry_FieldKey{
FieldKey: f.Name(),
},
Value: v,
OptionalEntry: f.IsOptional(),
}, nil
}
// SourceInfoToProto serializes an ast.SourceInfo value to a protobuf SourceInfo object.
func SourceInfoToProto(info *SourceInfo) (*exprpb.SourceInfo, error) {
if info == nil {
return &exprpb.SourceInfo{}, nil
}
sourceInfo := &exprpb.SourceInfo{
SyntaxVersion: info.SyntaxVersion(),
Location: info.Description(),
LineOffsets: info.LineOffsets(),
Positions: make(map[int64]int32, len(info.OffsetRanges())),
MacroCalls: make(map[int64]*exprpb.Expr, len(info.MacroCalls())),
}
for id, offset := range info.OffsetRanges() {
sourceInfo.Positions[id] = offset.Start
}
for id, e := range info.MacroCalls() {
call, err := ExprToProto(e)
if err != nil {
return nil, err
}
sourceInfo.MacroCalls[id] = call
}
return sourceInfo, nil
}
// ProtoToSourceInfo deserializes the protobuf into a native SourceInfo value.
func ProtoToSourceInfo(info *exprpb.SourceInfo) (*SourceInfo, error) {
sourceInfo := &SourceInfo{
syntax: info.GetSyntaxVersion(),
desc: info.GetLocation(),
lines: info.GetLineOffsets(),
offsetRanges: make(map[int64]OffsetRange, len(info.GetPositions())),
macroCalls: make(map[int64]Expr, len(info.GetMacroCalls())),
}
for id, offset := range info.GetPositions() {
sourceInfo.SetOffsetRange(id, OffsetRange{Start: offset, Stop: offset})
}
for id, e := range info.GetMacroCalls() {
call, err := ProtoToExpr(e)
if err != nil {
return nil, err
}
sourceInfo.SetMacroCall(id, call)
}
return sourceInfo, nil
}
// ReferenceInfoToProto converts a ReferenceInfo instance to a protobuf Reference suitable for serialization.
func ReferenceInfoToProto(info *ReferenceInfo) (*exprpb.Reference, error) {
c, err := ValToConstant(info.Value)
if err != nil {
return nil, err
}
return &exprpb.Reference{
Name: info.Name,
OverloadId: info.OverloadIDs,
Value: c,
}, nil
}
// ProtoToReferenceInfo converts a protobuf Reference into a CEL-native ReferenceInfo instance.
func ProtoToReferenceInfo(ref *exprpb.Reference) (*ReferenceInfo, error) {
v, err := ConstantToVal(ref.GetValue())
if err != nil {
return nil, err
}
return &ReferenceInfo{
Name: ref.GetName(),
OverloadIDs: ref.GetOverloadId(),
Value: v,
}, nil
}
// ValToConstant converts a CEL-native ref.Val to a protobuf Constant.
//
// Only simple scalar types are supported by this method.
func ValToConstant(v ref.Val) (*exprpb.Constant, error) {
if v == nil {
return nil, nil
}
switch v.Type() {
case types.BoolType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: v.Value().(bool)}}, nil
case types.BytesType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: v.Value().([]byte)}}, nil
case types.DoubleType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: v.Value().(float64)}}, nil
case types.IntType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: v.Value().(int64)}}, nil
case types.NullType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: structpb.NullValue_NULL_VALUE}}, nil
case types.StringType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: v.Value().(string)}}, nil
case types.UintType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: v.Value().(uint64)}}, nil
}
return nil, fmt.Errorf("unsupported constant kind: %v", v.Type())
}
// ConstantToVal converts a protobuf Constant to a CEL-native ref.Val.
func ConstantToVal(c *exprpb.Constant) (ref.Val, error) {
if c == nil {
return nil, nil
}
switch c.GetConstantKind().(type) {
case *exprpb.Constant_BoolValue:
return types.Bool(c.GetBoolValue()), nil
case *exprpb.Constant_BytesValue:
return types.Bytes(c.GetBytesValue()), nil
case *exprpb.Constant_DoubleValue:
return types.Double(c.GetDoubleValue()), nil
case *exprpb.Constant_Int64Value:
return types.Int(c.GetInt64Value()), nil
case *exprpb.Constant_NullValue:
return types.NullValue, nil
case *exprpb.Constant_StringValue:
return types.String(c.GetStringValue()), nil
case *exprpb.Constant_Uint64Value:
return types.Uint(c.GetUint64Value()), nil
}
return nil, fmt.Errorf("unsupported constant kind: %v", c.GetConstantKind())
}

File diff suppressed because it is too large Load Diff

303
vendor/github.com/google/cel-go/common/ast/factory.go generated vendored Normal file
View File

@ -0,0 +1,303 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ast
import "github.com/google/cel-go/common/types/ref"
// ExprFactory interfaces defines a set of methods necessary for building native expression values.
type ExprFactory interface {
// CopyExpr creates a deep copy of the input Expr value.
CopyExpr(Expr) Expr
// CopyEntryExpr creates a deep copy of the input EntryExpr value.
CopyEntryExpr(EntryExpr) EntryExpr
// NewCall creates an Expr value representing a global function call.
NewCall(id int64, function string, args ...Expr) Expr
// NewComprehension creates an Expr value representing a comprehension over a value range.
NewComprehension(id int64, iterRange Expr, iterVar, accuVar string, accuInit, loopCondition, loopStep, result Expr) Expr
// NewMemberCall creates an Expr value representing a member function call.
NewMemberCall(id int64, function string, receiver Expr, args ...Expr) Expr
// NewIdent creates an Expr value representing an identifier.
NewIdent(id int64, name string) Expr
// NewAccuIdent creates an Expr value representing an accumulator identifier within a
//comprehension.
NewAccuIdent(id int64) Expr
// NewLiteral creates an Expr value representing a literal value, such as a string or integer.
NewLiteral(id int64, value ref.Val) Expr
// NewList creates an Expr value representing a list literal expression with optional indices.
//
// Optional indicies will typically be empty unless the CEL optional types are enabled.
NewList(id int64, elems []Expr, optIndices []int32) Expr
// NewMap creates an Expr value representing a map literal expression
NewMap(id int64, entries []EntryExpr) Expr
// NewMapEntry creates a MapEntry with a given key, value, and a flag indicating whether
// the key is optionally set.
NewMapEntry(id int64, key, value Expr, isOptional bool) EntryExpr
// NewPresenceTest creates an Expr representing a field presence test on an operand expression.
NewPresenceTest(id int64, operand Expr, field string) Expr
// NewSelect creates an Expr representing a field selection on an operand expression.
NewSelect(id int64, operand Expr, field string) Expr
// NewStruct creates an Expr value representing a struct literal with a given type name and a
// set of field initializers.
NewStruct(id int64, typeName string, fields []EntryExpr) Expr
// NewStructField creates a StructField with a given field name, value, and a flag indicating
// whether the field is optionally set.
NewStructField(id int64, field string, value Expr, isOptional bool) EntryExpr
// NewUnspecifiedExpr creates an empty expression node.
NewUnspecifiedExpr(id int64) Expr
isExprFactory()
}
type baseExprFactory struct{}
// NewExprFactory creates an ExprFactory instance.
func NewExprFactory() ExprFactory {
return &baseExprFactory{}
}
func (fac *baseExprFactory) NewCall(id int64, function string, args ...Expr) Expr {
if len(args) == 0 {
args = []Expr{}
}
return fac.newExpr(
id,
&baseCallExpr{
function: function,
target: nilExpr,
args: args,
isMember: false,
})
}
func (fac *baseExprFactory) NewMemberCall(id int64, function string, target Expr, args ...Expr) Expr {
if len(args) == 0 {
args = []Expr{}
}
return fac.newExpr(
id,
&baseCallExpr{
function: function,
target: target,
args: args,
isMember: true,
})
}
func (fac *baseExprFactory) NewComprehension(id int64, iterRange Expr, iterVar, accuVar string, accuInit, loopCond, loopStep, result Expr) Expr {
return fac.newExpr(
id,
&baseComprehensionExpr{
iterRange: iterRange,
iterVar: iterVar,
accuVar: accuVar,
accuInit: accuInit,
loopCond: loopCond,
loopStep: loopStep,
result: result,
})
}
func (fac *baseExprFactory) NewIdent(id int64, name string) Expr {
return fac.newExpr(id, baseIdentExpr(name))
}
func (fac *baseExprFactory) NewAccuIdent(id int64) Expr {
return fac.NewIdent(id, "__result__")
}
func (fac *baseExprFactory) NewLiteral(id int64, value ref.Val) Expr {
return fac.newExpr(id, &baseLiteral{Val: value})
}
func (fac *baseExprFactory) NewList(id int64, elems []Expr, optIndices []int32) Expr {
optIndexMap := make(map[int32]struct{}, len(optIndices))
for _, idx := range optIndices {
optIndexMap[idx] = struct{}{}
}
return fac.newExpr(id,
&baseListExpr{
elements: elems,
optIndices: optIndices,
optIndexMap: optIndexMap,
})
}
func (fac *baseExprFactory) NewMap(id int64, entries []EntryExpr) Expr {
return fac.newExpr(id, &baseMapExpr{entries: entries})
}
func (fac *baseExprFactory) NewMapEntry(id int64, key, value Expr, isOptional bool) EntryExpr {
return fac.newEntryExpr(
id,
&baseMapEntry{
key: key,
value: value,
isOptional: isOptional,
})
}
func (fac *baseExprFactory) NewPresenceTest(id int64, operand Expr, field string) Expr {
return fac.newExpr(
id,
&baseSelectExpr{
operand: operand,
field: field,
testOnly: true,
})
}
func (fac *baseExprFactory) NewSelect(id int64, operand Expr, field string) Expr {
return fac.newExpr(
id,
&baseSelectExpr{
operand: operand,
field: field,
})
}
func (fac *baseExprFactory) NewStruct(id int64, typeName string, fields []EntryExpr) Expr {
return fac.newExpr(
id,
&baseStructExpr{
typeName: typeName,
fields: fields,
})
}
func (fac *baseExprFactory) NewStructField(id int64, field string, value Expr, isOptional bool) EntryExpr {
return fac.newEntryExpr(
id,
&baseStructField{
field: field,
value: value,
isOptional: isOptional,
})
}
func (fac *baseExprFactory) NewUnspecifiedExpr(id int64) Expr {
return fac.newExpr(id, nil)
}
func (fac *baseExprFactory) CopyExpr(e Expr) Expr {
// unwrap navigable expressions to avoid unnecessary allocations during copying.
if nav, ok := e.(*navigableExprImpl); ok {
e = nav.Expr
}
switch e.Kind() {
case CallKind:
c := e.AsCall()
argsCopy := make([]Expr, len(c.Args()))
for i, arg := range c.Args() {
argsCopy[i] = fac.CopyExpr(arg)
}
if !c.IsMemberFunction() {
return fac.NewCall(e.ID(), c.FunctionName(), argsCopy...)
}
return fac.NewMemberCall(e.ID(), c.FunctionName(), fac.CopyExpr(c.Target()), argsCopy...)
case ComprehensionKind:
compre := e.AsComprehension()
return fac.NewComprehension(e.ID(),
fac.CopyExpr(compre.IterRange()),
compre.IterVar(),
compre.AccuVar(),
fac.CopyExpr(compre.AccuInit()),
fac.CopyExpr(compre.LoopCondition()),
fac.CopyExpr(compre.LoopStep()),
fac.CopyExpr(compre.Result()))
case IdentKind:
return fac.NewIdent(e.ID(), e.AsIdent())
case ListKind:
l := e.AsList()
elemsCopy := make([]Expr, l.Size())
for i, elem := range l.Elements() {
elemsCopy[i] = fac.CopyExpr(elem)
}
return fac.NewList(e.ID(), elemsCopy, l.OptionalIndices())
case LiteralKind:
return fac.NewLiteral(e.ID(), e.AsLiteral())
case MapKind:
m := e.AsMap()
entriesCopy := make([]EntryExpr, m.Size())
for i, entry := range m.Entries() {
entriesCopy[i] = fac.CopyEntryExpr(entry)
}
return fac.NewMap(e.ID(), entriesCopy)
case SelectKind:
s := e.AsSelect()
if s.IsTestOnly() {
return fac.NewPresenceTest(e.ID(), fac.CopyExpr(s.Operand()), s.FieldName())
}
return fac.NewSelect(e.ID(), fac.CopyExpr(s.Operand()), s.FieldName())
case StructKind:
s := e.AsStruct()
fieldsCopy := make([]EntryExpr, len(s.Fields()))
for i, field := range s.Fields() {
fieldsCopy[i] = fac.CopyEntryExpr(field)
}
return fac.NewStruct(e.ID(), s.TypeName(), fieldsCopy)
default:
return fac.NewUnspecifiedExpr(e.ID())
}
}
func (fac *baseExprFactory) CopyEntryExpr(e EntryExpr) EntryExpr {
switch e.Kind() {
case MapEntryKind:
entry := e.AsMapEntry()
return fac.NewMapEntry(e.ID(),
fac.CopyExpr(entry.Key()), fac.CopyExpr(entry.Value()), entry.IsOptional())
case StructFieldKind:
field := e.AsStructField()
return fac.NewStructField(e.ID(),
field.Name(), fac.CopyExpr(field.Value()), field.IsOptional())
default:
return fac.newEntryExpr(e.ID(), nil)
}
}
func (*baseExprFactory) isExprFactory() {}
func (fac *baseExprFactory) newExpr(id int64, e exprKindCase) Expr {
return &expr{
id: id,
exprKindCase: e,
}
}
func (fac *baseExprFactory) newEntryExpr(id int64, e entryExprKindCase) EntryExpr {
return &entryExpr{
id: id,
entryExprKindCase: e,
}
}
var (
defaultFactory = &baseExprFactory{}
)

652
vendor/github.com/google/cel-go/common/ast/navigable.go generated vendored Normal file
View File

@ -0,0 +1,652 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ast
import (
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)
// NavigableExpr represents the base navigable expression value with methods to inspect the
// parent and child expressions.
type NavigableExpr interface {
Expr
// Type of the expression.
//
// If the expression is type-checked, the type check metadata is returned. If the expression
// has not been type-checked, the types.DynType value is returned.
Type() *types.Type
// Parent returns the parent expression node, if one exists.
Parent() (NavigableExpr, bool)
// Children returns a list of child expression nodes.
Children() []NavigableExpr
// Depth indicates the depth in the expression tree.
//
// The root expression has depth 0.
Depth() int
}
// NavigateAST converts an AST to a NavigableExpr
func NavigateAST(ast *AST) NavigableExpr {
return NavigateExpr(ast, ast.Expr())
}
// NavigateExpr creates a NavigableExpr whose type information is backed by the input AST.
//
// If the expression is already a NavigableExpr, the parent and depth information will be
// propagated on the new NavigableExpr value; otherwise, the expr value will be treated
// as though it is the root of the expression graph with a depth of 0.
func NavigateExpr(ast *AST, expr Expr) NavigableExpr {
depth := 0
var parent NavigableExpr = nil
if nav, ok := expr.(NavigableExpr); ok {
depth = nav.Depth()
parent, _ = nav.Parent()
}
return newNavigableExpr(ast, parent, expr, depth)
}
// ExprMatcher takes a NavigableExpr in and indicates whether the value is a match.
//
// This function type should be use with the `Match` and `MatchList` calls.
type ExprMatcher func(NavigableExpr) bool
// ConstantValueMatcher returns an ExprMatcher which will return true if the input NavigableExpr
// is comprised of all constant values, such as a simple literal or even list and map literal.
func ConstantValueMatcher() ExprMatcher {
return matchIsConstantValue
}
// KindMatcher returns an ExprMatcher which will return true if the input NavigableExpr.Kind() matches
// the specified `kind`.
func KindMatcher(kind ExprKind) ExprMatcher {
return func(e NavigableExpr) bool {
return e.Kind() == kind
}
}
// FunctionMatcher returns an ExprMatcher which will match NavigableExpr nodes of CallKind type whose
// function name is equal to `funcName`.
func FunctionMatcher(funcName string) ExprMatcher {
return func(e NavigableExpr) bool {
if e.Kind() != CallKind {
return false
}
return e.AsCall().FunctionName() == funcName
}
}
// AllMatcher returns true for all descendants of a NavigableExpr, effectively flattening them into a list.
//
// Such a result would work well with subsequent MatchList calls.
func AllMatcher() ExprMatcher {
return func(NavigableExpr) bool {
return true
}
}
// MatchDescendants takes a NavigableExpr and ExprMatcher and produces a list of NavigableExpr values
// matching the input criteria in post-order (bottom up).
func MatchDescendants(expr NavigableExpr, matcher ExprMatcher) []NavigableExpr {
matches := []NavigableExpr{}
navVisitor := &baseVisitor{
visitExpr: func(e Expr) {
nav := e.(NavigableExpr)
if matcher(nav) {
matches = append(matches, nav)
}
},
}
visit(expr, navVisitor, postOrder, 0, 0)
return matches
}
// MatchSubset applies an ExprMatcher to a list of NavigableExpr values and their descendants, producing a
// subset of NavigableExpr values which match.
func MatchSubset(exprs []NavigableExpr, matcher ExprMatcher) []NavigableExpr {
matches := []NavigableExpr{}
navVisitor := &baseVisitor{
visitExpr: func(e Expr) {
nav := e.(NavigableExpr)
if matcher(nav) {
matches = append(matches, nav)
}
},
}
for _, expr := range exprs {
visit(expr, navVisitor, postOrder, 0, 1)
}
return matches
}
// Visitor defines an object for visiting Expr and EntryExpr nodes within an expression graph.
type Visitor interface {
// VisitExpr visits the input expression.
VisitExpr(Expr)
// VisitEntryExpr visits the input entry expression, i.e. a struct field or map entry.
VisitEntryExpr(EntryExpr)
}
type baseVisitor struct {
visitExpr func(Expr)
visitEntryExpr func(EntryExpr)
}
// VisitExpr visits the Expr if the internal expr visitor has been configured.
func (v *baseVisitor) VisitExpr(e Expr) {
if v.visitExpr != nil {
v.visitExpr(e)
}
}
// VisitEntryExpr visits the entry if the internal expr entry visitor has been configured.
func (v *baseVisitor) VisitEntryExpr(e EntryExpr) {
if v.visitEntryExpr != nil {
v.visitEntryExpr(e)
}
}
// NewExprVisitor creates a visitor which only visits expression nodes.
func NewExprVisitor(v func(Expr)) Visitor {
return &baseVisitor{
visitExpr: v,
visitEntryExpr: nil,
}
}
// PostOrderVisit walks the expression graph and calls the visitor in post-order (bottom-up).
func PostOrderVisit(expr Expr, visitor Visitor) {
visit(expr, visitor, postOrder, 0, 0)
}
// PreOrderVisit walks the expression graph and calls the visitor in pre-order (top-down).
func PreOrderVisit(expr Expr, visitor Visitor) {
visit(expr, visitor, preOrder, 0, 0)
}
type visitOrder int
const (
preOrder = iota + 1
postOrder
)
// TODO: consider exposing a way to configure a limit for the max visit depth.
// It's possible that we could want to configure this on the NewExprVisitor()
// and through MatchDescendents() / MaxID().
func visit(expr Expr, visitor Visitor, order visitOrder, depth, maxDepth int) {
if maxDepth > 0 && depth == maxDepth {
return
}
if order == preOrder {
visitor.VisitExpr(expr)
}
switch expr.Kind() {
case CallKind:
c := expr.AsCall()
if c.IsMemberFunction() {
visit(c.Target(), visitor, order, depth+1, maxDepth)
}
for _, arg := range c.Args() {
visit(arg, visitor, order, depth+1, maxDepth)
}
case ComprehensionKind:
c := expr.AsComprehension()
visit(c.IterRange(), visitor, order, depth+1, maxDepth)
visit(c.AccuInit(), visitor, order, depth+1, maxDepth)
visit(c.LoopCondition(), visitor, order, depth+1, maxDepth)
visit(c.LoopStep(), visitor, order, depth+1, maxDepth)
visit(c.Result(), visitor, order, depth+1, maxDepth)
case ListKind:
l := expr.AsList()
for _, elem := range l.Elements() {
visit(elem, visitor, order, depth+1, maxDepth)
}
case MapKind:
m := expr.AsMap()
for _, e := range m.Entries() {
if order == preOrder {
visitor.VisitEntryExpr(e)
}
entry := e.AsMapEntry()
visit(entry.Key(), visitor, order, depth+1, maxDepth)
visit(entry.Value(), visitor, order, depth+1, maxDepth)
if order == postOrder {
visitor.VisitEntryExpr(e)
}
}
case SelectKind:
visit(expr.AsSelect().Operand(), visitor, order, depth+1, maxDepth)
case StructKind:
s := expr.AsStruct()
for _, f := range s.Fields() {
visitor.VisitEntryExpr(f)
visit(f.AsStructField().Value(), visitor, order, depth+1, maxDepth)
}
}
if order == postOrder {
visitor.VisitExpr(expr)
}
}
func matchIsConstantValue(e NavigableExpr) bool {
if e.Kind() == LiteralKind {
return true
}
if e.Kind() == StructKind || e.Kind() == MapKind || e.Kind() == ListKind {
for _, child := range e.Children() {
if !matchIsConstantValue(child) {
return false
}
}
return true
}
return false
}
func newNavigableExpr(ast *AST, parent NavigableExpr, expr Expr, depth int) NavigableExpr {
// Reduce navigable expression nesting by unwrapping the embedded Expr value.
if nav, ok := expr.(*navigableExprImpl); ok {
expr = nav.Expr
}
nav := &navigableExprImpl{
Expr: expr,
depth: depth,
ast: ast,
parent: parent,
createChildren: getChildFactory(expr),
}
return nav
}
type navigableExprImpl struct {
Expr
depth int
ast *AST
parent NavigableExpr
createChildren childFactory
}
func (nav *navigableExprImpl) Parent() (NavigableExpr, bool) {
if nav.parent != nil {
return nav.parent, true
}
return nil, false
}
func (nav *navigableExprImpl) ID() int64 {
return nav.Expr.ID()
}
func (nav *navigableExprImpl) Kind() ExprKind {
return nav.Expr.Kind()
}
func (nav *navigableExprImpl) Type() *types.Type {
return nav.ast.GetType(nav.ID())
}
func (nav *navigableExprImpl) Children() []NavigableExpr {
return nav.createChildren(nav)
}
func (nav *navigableExprImpl) Depth() int {
return nav.depth
}
func (nav *navigableExprImpl) AsCall() CallExpr {
return navigableCallImpl{navigableExprImpl: nav}
}
func (nav *navigableExprImpl) AsComprehension() ComprehensionExpr {
return navigableComprehensionImpl{navigableExprImpl: nav}
}
func (nav *navigableExprImpl) AsIdent() string {
return nav.Expr.AsIdent()
}
func (nav *navigableExprImpl) AsList() ListExpr {
return navigableListImpl{navigableExprImpl: nav}
}
func (nav *navigableExprImpl) AsLiteral() ref.Val {
return nav.Expr.AsLiteral()
}
func (nav *navigableExprImpl) AsMap() MapExpr {
return navigableMapImpl{navigableExprImpl: nav}
}
func (nav *navigableExprImpl) AsSelect() SelectExpr {
return navigableSelectImpl{navigableExprImpl: nav}
}
func (nav *navigableExprImpl) AsStruct() StructExpr {
return navigableStructImpl{navigableExprImpl: nav}
}
func (nav *navigableExprImpl) createChild(e Expr) NavigableExpr {
return newNavigableExpr(nav.ast, nav, e, nav.depth+1)
}
func (nav *navigableExprImpl) isExpr() {}
type navigableCallImpl struct {
*navigableExprImpl
}
func (call navigableCallImpl) FunctionName() string {
return call.Expr.AsCall().FunctionName()
}
func (call navigableCallImpl) IsMemberFunction() bool {
return call.Expr.AsCall().IsMemberFunction()
}
func (call navigableCallImpl) Target() Expr {
t := call.Expr.AsCall().Target()
if t != nil {
return call.createChild(t)
}
return nil
}
func (call navigableCallImpl) Args() []Expr {
args := call.Expr.AsCall().Args()
navArgs := make([]Expr, len(args))
for i, a := range args {
navArgs[i] = call.createChild(a)
}
return navArgs
}
type navigableComprehensionImpl struct {
*navigableExprImpl
}
func (comp navigableComprehensionImpl) IterRange() Expr {
return comp.createChild(comp.Expr.AsComprehension().IterRange())
}
func (comp navigableComprehensionImpl) IterVar() string {
return comp.Expr.AsComprehension().IterVar()
}
func (comp navigableComprehensionImpl) AccuVar() string {
return comp.Expr.AsComprehension().AccuVar()
}
func (comp navigableComprehensionImpl) AccuInit() Expr {
return comp.createChild(comp.Expr.AsComprehension().AccuInit())
}
func (comp navigableComprehensionImpl) LoopCondition() Expr {
return comp.createChild(comp.Expr.AsComprehension().LoopCondition())
}
func (comp navigableComprehensionImpl) LoopStep() Expr {
return comp.createChild(comp.Expr.AsComprehension().LoopStep())
}
func (comp navigableComprehensionImpl) Result() Expr {
return comp.createChild(comp.Expr.AsComprehension().Result())
}
type navigableListImpl struct {
*navigableExprImpl
}
func (l navigableListImpl) Elements() []Expr {
pbElems := l.Expr.AsList().Elements()
elems := make([]Expr, len(pbElems))
for i := 0; i < len(pbElems); i++ {
elems[i] = l.createChild(pbElems[i])
}
return elems
}
func (l navigableListImpl) IsOptional(index int32) bool {
return l.Expr.AsList().IsOptional(index)
}
func (l navigableListImpl) OptionalIndices() []int32 {
return l.Expr.AsList().OptionalIndices()
}
func (l navigableListImpl) Size() int {
return l.Expr.AsList().Size()
}
type navigableMapImpl struct {
*navigableExprImpl
}
func (m navigableMapImpl) Entries() []EntryExpr {
mapExpr := m.Expr.AsMap()
entries := make([]EntryExpr, len(mapExpr.Entries()))
for i, e := range mapExpr.Entries() {
entry := e.AsMapEntry()
entries[i] = &entryExpr{
id: e.ID(),
entryExprKindCase: navigableEntryImpl{
key: m.createChild(entry.Key()),
val: m.createChild(entry.Value()),
isOpt: entry.IsOptional(),
},
}
}
return entries
}
func (m navigableMapImpl) Size() int {
return m.Expr.AsMap().Size()
}
type navigableEntryImpl struct {
key NavigableExpr
val NavigableExpr
isOpt bool
}
func (e navigableEntryImpl) Kind() EntryExprKind {
return MapEntryKind
}
func (e navigableEntryImpl) Key() Expr {
return e.key
}
func (e navigableEntryImpl) Value() Expr {
return e.val
}
func (e navigableEntryImpl) IsOptional() bool {
return e.isOpt
}
func (e navigableEntryImpl) renumberIDs(IDGenerator) {}
func (e navigableEntryImpl) isEntryExpr() {}
type navigableSelectImpl struct {
*navigableExprImpl
}
func (sel navigableSelectImpl) FieldName() string {
return sel.Expr.AsSelect().FieldName()
}
func (sel navigableSelectImpl) IsTestOnly() bool {
return sel.Expr.AsSelect().IsTestOnly()
}
func (sel navigableSelectImpl) Operand() Expr {
return sel.createChild(sel.Expr.AsSelect().Operand())
}
type navigableStructImpl struct {
*navigableExprImpl
}
func (s navigableStructImpl) TypeName() string {
return s.Expr.AsStruct().TypeName()
}
func (s navigableStructImpl) Fields() []EntryExpr {
fieldInits := s.Expr.AsStruct().Fields()
fields := make([]EntryExpr, len(fieldInits))
for i, f := range fieldInits {
field := f.AsStructField()
fields[i] = &entryExpr{
id: f.ID(),
entryExprKindCase: navigableFieldImpl{
name: field.Name(),
val: s.createChild(field.Value()),
isOpt: field.IsOptional(),
},
}
}
return fields
}
type navigableFieldImpl struct {
name string
val NavigableExpr
isOpt bool
}
func (f navigableFieldImpl) Kind() EntryExprKind {
return StructFieldKind
}
func (f navigableFieldImpl) Name() string {
return f.name
}
func (f navigableFieldImpl) Value() Expr {
return f.val
}
func (f navigableFieldImpl) IsOptional() bool {
return f.isOpt
}
func (f navigableFieldImpl) renumberIDs(IDGenerator) {}
func (f navigableFieldImpl) isEntryExpr() {}
func getChildFactory(expr Expr) childFactory {
if expr == nil {
return noopFactory
}
switch expr.Kind() {
case LiteralKind:
return noopFactory
case IdentKind:
return noopFactory
case SelectKind:
return selectFactory
case CallKind:
return callArgFactory
case ListKind:
return listElemFactory
case MapKind:
return mapEntryFactory
case StructKind:
return structEntryFactory
case ComprehensionKind:
return comprehensionFactory
default:
return noopFactory
}
}
type childFactory func(*navigableExprImpl) []NavigableExpr
func noopFactory(*navigableExprImpl) []NavigableExpr {
return nil
}
func selectFactory(nav *navigableExprImpl) []NavigableExpr {
return []NavigableExpr{nav.createChild(nav.AsSelect().Operand())}
}
func callArgFactory(nav *navigableExprImpl) []NavigableExpr {
call := nav.Expr.AsCall()
argCount := len(call.Args())
if call.IsMemberFunction() {
argCount++
}
navExprs := make([]NavigableExpr, argCount)
i := 0
if call.IsMemberFunction() {
navExprs[i] = nav.createChild(call.Target())
i++
}
for _, arg := range call.Args() {
navExprs[i] = nav.createChild(arg)
i++
}
return navExprs
}
func listElemFactory(nav *navigableExprImpl) []NavigableExpr {
l := nav.Expr.AsList()
navExprs := make([]NavigableExpr, len(l.Elements()))
for i, e := range l.Elements() {
navExprs[i] = nav.createChild(e)
}
return navExprs
}
func structEntryFactory(nav *navigableExprImpl) []NavigableExpr {
s := nav.Expr.AsStruct()
entries := make([]NavigableExpr, len(s.Fields()))
for i, e := range s.Fields() {
f := e.AsStructField()
entries[i] = nav.createChild(f.Value())
}
return entries
}
func mapEntryFactory(nav *navigableExprImpl) []NavigableExpr {
m := nav.Expr.AsMap()
entries := make([]NavigableExpr, len(m.Entries())*2)
j := 0
for _, e := range m.Entries() {
mapEntry := e.AsMapEntry()
entries[j] = nav.createChild(mapEntry.Key())
entries[j+1] = nav.createChild(mapEntry.Value())
j += 2
}
return entries
}
func comprehensionFactory(nav *navigableExprImpl) []NavigableExpr {
compre := nav.Expr.AsComprehension()
return []NavigableExpr{
nav.createChild(compre.IterRange()),
nav.createChild(compre.AccuInit()),
nav.createChild(compre.LoopCondition()),
nav.createChild(compre.LoopStep()),
nav.createChild(compre.Result()),
}
}

View File

@ -12,7 +12,7 @@ go_library(
],
importpath = "github.com/google/cel-go/common/containers",
deps = [
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"//common/ast:go_default_library",
],
)
@ -26,6 +26,6 @@ go_test(
":go_default_library",
],
deps = [
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"//common/ast:go_default_library",
],
)

View File

@ -20,7 +20,7 @@ import (
"fmt"
"strings"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"github.com/google/cel-go/common/ast"
)
var (
@ -297,19 +297,19 @@ func Name(name string) ContainerOption {
// ToQualifiedName converts an expression AST into a qualified name if possible, with a boolean
// 'found' value that indicates if the conversion is successful.
func ToQualifiedName(e *exprpb.Expr) (string, bool) {
switch e.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
id := e.GetIdentExpr()
return id.GetName(), true
case *exprpb.Expr_SelectExpr:
sel := e.GetSelectExpr()
func ToQualifiedName(e ast.Expr) (string, bool) {
switch e.Kind() {
case ast.IdentKind:
id := e.AsIdent()
return id, true
case ast.SelectKind:
sel := e.AsSelect()
// Test only expressions are not valid as qualified names.
if sel.GetTestOnly() {
if sel.IsTestOnly() {
return "", false
}
if qual, found := ToQualifiedName(sel.GetOperand()); found {
return qual + "." + sel.GetField(), true
if qual, found := ToQualifiedName(sel.Operand()); found {
return qual + "." + sel.FieldName(), true
}
}
return "", false

View File

@ -13,6 +13,8 @@ go_library(
importpath = "github.com/google/cel-go/common/debug",
deps = [
"//common:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"//common/ast:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
],
)

View File

@ -22,7 +22,9 @@ import (
"strconv"
"strings"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)
// Adorner returns debug metadata that will be tacked on to the string
@ -38,7 +40,7 @@ type Writer interface {
// Buffer pushes an expression into an internal queue of expressions to
// write to a string.
Buffer(e *exprpb.Expr)
Buffer(e ast.Expr)
}
type emptyDebugAdorner struct {
@ -51,12 +53,12 @@ func (a *emptyDebugAdorner) GetMetadata(e any) string {
}
// ToDebugString gives the unadorned string representation of the Expr.
func ToDebugString(e *exprpb.Expr) string {
func ToDebugString(e ast.Expr) string {
return ToAdornedDebugString(e, emptyAdorner)
}
// ToAdornedDebugString gives the adorned string representation of the Expr.
func ToAdornedDebugString(e *exprpb.Expr, adorner Adorner) string {
func ToAdornedDebugString(e ast.Expr, adorner Adorner) string {
w := newDebugWriter(adorner)
w.Buffer(e)
return w.String()
@ -78,49 +80,51 @@ func newDebugWriter(a Adorner) *debugWriter {
}
}
func (w *debugWriter) Buffer(e *exprpb.Expr) {
func (w *debugWriter) Buffer(e ast.Expr) {
if e == nil {
return
}
switch e.ExprKind.(type) {
case *exprpb.Expr_ConstExpr:
w.append(formatLiteral(e.GetConstExpr()))
case *exprpb.Expr_IdentExpr:
w.append(e.GetIdentExpr().Name)
case *exprpb.Expr_SelectExpr:
w.appendSelect(e.GetSelectExpr())
case *exprpb.Expr_CallExpr:
w.appendCall(e.GetCallExpr())
case *exprpb.Expr_ListExpr:
w.appendList(e.GetListExpr())
case *exprpb.Expr_StructExpr:
w.appendStruct(e.GetStructExpr())
case *exprpb.Expr_ComprehensionExpr:
w.appendComprehension(e.GetComprehensionExpr())
switch e.Kind() {
case ast.LiteralKind:
w.append(formatLiteral(e.AsLiteral()))
case ast.IdentKind:
w.append(e.AsIdent())
case ast.SelectKind:
w.appendSelect(e.AsSelect())
case ast.CallKind:
w.appendCall(e.AsCall())
case ast.ListKind:
w.appendList(e.AsList())
case ast.MapKind:
w.appendMap(e.AsMap())
case ast.StructKind:
w.appendStruct(e.AsStruct())
case ast.ComprehensionKind:
w.appendComprehension(e.AsComprehension())
}
w.adorn(e)
}
func (w *debugWriter) appendSelect(sel *exprpb.Expr_Select) {
w.Buffer(sel.GetOperand())
func (w *debugWriter) appendSelect(sel ast.SelectExpr) {
w.Buffer(sel.Operand())
w.append(".")
w.append(sel.GetField())
if sel.TestOnly {
w.append(sel.FieldName())
if sel.IsTestOnly() {
w.append("~test-only~")
}
}
func (w *debugWriter) appendCall(call *exprpb.Expr_Call) {
if call.Target != nil {
w.Buffer(call.GetTarget())
func (w *debugWriter) appendCall(call ast.CallExpr) {
if call.IsMemberFunction() {
w.Buffer(call.Target())
w.append(".")
}
w.append(call.GetFunction())
w.append(call.FunctionName())
w.append("(")
if len(call.GetArgs()) > 0 {
if len(call.Args()) > 0 {
w.addIndent()
w.appendLine()
for i, arg := range call.GetArgs() {
for i, arg := range call.Args() {
if i > 0 {
w.append(",")
w.appendLine()
@ -133,12 +137,12 @@ func (w *debugWriter) appendCall(call *exprpb.Expr_Call) {
w.append(")")
}
func (w *debugWriter) appendList(list *exprpb.Expr_CreateList) {
func (w *debugWriter) appendList(list ast.ListExpr) {
w.append("[")
if len(list.GetElements()) > 0 {
if len(list.Elements()) > 0 {
w.appendLine()
w.addIndent()
for i, elem := range list.GetElements() {
for i, elem := range list.Elements() {
if i > 0 {
w.append(",")
w.appendLine()
@ -151,32 +155,25 @@ func (w *debugWriter) appendList(list *exprpb.Expr_CreateList) {
w.append("]")
}
func (w *debugWriter) appendStruct(obj *exprpb.Expr_CreateStruct) {
if obj.MessageName != "" {
w.appendObject(obj)
} else {
w.appendMap(obj)
}
}
func (w *debugWriter) appendObject(obj *exprpb.Expr_CreateStruct) {
w.append(obj.GetMessageName())
func (w *debugWriter) appendStruct(obj ast.StructExpr) {
w.append(obj.TypeName())
w.append("{")
if len(obj.GetEntries()) > 0 {
if len(obj.Fields()) > 0 {
w.appendLine()
w.addIndent()
for i, entry := range obj.GetEntries() {
for i, f := range obj.Fields() {
field := f.AsStructField()
if i > 0 {
w.append(",")
w.appendLine()
}
if entry.GetOptionalEntry() {
if field.IsOptional() {
w.append("?")
}
w.append(entry.GetFieldKey())
w.append(field.Name())
w.append(":")
w.Buffer(entry.GetValue())
w.adorn(entry)
w.Buffer(field.Value())
w.adorn(f)
}
w.removeIndent()
w.appendLine()
@ -184,23 +181,24 @@ func (w *debugWriter) appendObject(obj *exprpb.Expr_CreateStruct) {
w.append("}")
}
func (w *debugWriter) appendMap(obj *exprpb.Expr_CreateStruct) {
func (w *debugWriter) appendMap(m ast.MapExpr) {
w.append("{")
if len(obj.GetEntries()) > 0 {
if m.Size() > 0 {
w.appendLine()
w.addIndent()
for i, entry := range obj.GetEntries() {
for i, e := range m.Entries() {
entry := e.AsMapEntry()
if i > 0 {
w.append(",")
w.appendLine()
}
if entry.GetOptionalEntry() {
if entry.IsOptional() {
w.append("?")
}
w.Buffer(entry.GetMapKey())
w.Buffer(entry.Key())
w.append(":")
w.Buffer(entry.GetValue())
w.adorn(entry)
w.Buffer(entry.Value())
w.adorn(e)
}
w.removeIndent()
w.appendLine()
@ -208,62 +206,62 @@ func (w *debugWriter) appendMap(obj *exprpb.Expr_CreateStruct) {
w.append("}")
}
func (w *debugWriter) appendComprehension(comprehension *exprpb.Expr_Comprehension) {
func (w *debugWriter) appendComprehension(comprehension ast.ComprehensionExpr) {
w.append("__comprehension__(")
w.addIndent()
w.appendLine()
w.append("// Variable")
w.appendLine()
w.append(comprehension.GetIterVar())
w.append(comprehension.IterVar())
w.append(",")
w.appendLine()
w.append("// Target")
w.appendLine()
w.Buffer(comprehension.GetIterRange())
w.Buffer(comprehension.IterRange())
w.append(",")
w.appendLine()
w.append("// Accumulator")
w.appendLine()
w.append(comprehension.GetAccuVar())
w.append(comprehension.AccuVar())
w.append(",")
w.appendLine()
w.append("// Init")
w.appendLine()
w.Buffer(comprehension.GetAccuInit())
w.Buffer(comprehension.AccuInit())
w.append(",")
w.appendLine()
w.append("// LoopCondition")
w.appendLine()
w.Buffer(comprehension.GetLoopCondition())
w.Buffer(comprehension.LoopCondition())
w.append(",")
w.appendLine()
w.append("// LoopStep")
w.appendLine()
w.Buffer(comprehension.GetLoopStep())
w.Buffer(comprehension.LoopStep())
w.append(",")
w.appendLine()
w.append("// Result")
w.appendLine()
w.Buffer(comprehension.GetResult())
w.Buffer(comprehension.Result())
w.append(")")
w.removeIndent()
}
func formatLiteral(c *exprpb.Constant) string {
switch c.GetConstantKind().(type) {
case *exprpb.Constant_BoolValue:
return fmt.Sprintf("%t", c.GetBoolValue())
case *exprpb.Constant_BytesValue:
return fmt.Sprintf("b\"%s\"", string(c.GetBytesValue()))
case *exprpb.Constant_DoubleValue:
return fmt.Sprintf("%v", c.GetDoubleValue())
case *exprpb.Constant_Int64Value:
return fmt.Sprintf("%d", c.GetInt64Value())
case *exprpb.Constant_StringValue:
return strconv.Quote(c.GetStringValue())
case *exprpb.Constant_Uint64Value:
return fmt.Sprintf("%du", c.GetUint64Value())
case *exprpb.Constant_NullValue:
func formatLiteral(c ref.Val) string {
switch v := c.(type) {
case types.Bool:
return fmt.Sprintf("%t", v)
case types.Bytes:
return fmt.Sprintf("b\"%s\"", string(v))
case types.Double:
return fmt.Sprintf("%v", float64(v))
case types.Int:
return fmt.Sprintf("%d", int64(v))
case types.String:
return strconv.Quote(string(v))
case types.Uint:
return fmt.Sprintf("%du", uint64(v))
case types.Null:
return "null"
default:
panic("Unknown constant type")

View File

@ -64,7 +64,7 @@ func (e *Errors) GetErrors() []*Error {
// Append creates a new Errors object with the current and input errors.
func (e *Errors) Append(errs []*Error) *Errors {
return &Errors{
errors: append(e.errors, errs...),
errors: append(e.errors[:], errs...),
source: e.source,
numErrors: e.numErrors + len(errs),
maxErrorsToReport: e.maxErrorsToReport,

View File

@ -31,6 +31,7 @@ type Error interface {
// Err type which extends the built-in go error and implements ref.Val.
type Err struct {
error
id int64
}
var (
@ -58,7 +59,24 @@ var (
// NewErr creates a new Err described by the format string and args.
// TODO: Audit the use of this function and standardize the error messages and codes.
func NewErr(format string, args ...any) ref.Val {
return &Err{fmt.Errorf(format, args...)}
return &Err{error: fmt.Errorf(format, args...)}
}
// NewErrWithNodeID creates a new Err described by the format string and args.
// TODO: Audit the use of this function and standardize the error messages and codes.
func NewErrWithNodeID(id int64, format string, args ...any) ref.Val {
return &Err{error: fmt.Errorf(format, args...), id: id}
}
// LabelErrNode returns val unaltered it is not an Err or if the error has a non-zero
// AST node ID already present. Otherwise the id is added to the error for
// recovery with the Err.NodeID method.
func LabelErrNode(id int64, val ref.Val) ref.Val {
if err, ok := val.(*Err); ok && err.id == 0 {
err.id = id
return err
}
return val
}
// NoSuchOverloadErr returns a new types.Err instance with a no such overload message.
@ -124,6 +142,11 @@ func (e *Err) Value() any {
return e.error
}
// NodeID returns the AST node ID of the expression that returned the error.
func (e *Err) NodeID() int64 {
return e.id
}
// Is implements errors.Is.
func (e *Err) Is(target error) bool {
return e.error.Error() == target.Error()

View File

@ -90,6 +90,18 @@ func (i Int) ConvertToNative(typeDesc reflect.Type) (any, error) {
return nil, err
}
return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil
case reflect.Int8:
v, err := int64ToInt8Checked(int64(i))
if err != nil {
return nil, err
}
return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil
case reflect.Int16:
v, err := int64ToInt16Checked(int64(i))
if err != nil {
return nil, err
}
return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil
case reflect.Int64:
return reflect.ValueOf(i).Convert(typeDesc).Interface(), nil
case reflect.Ptr:

View File

@ -190,7 +190,13 @@ func (l *baseList) ConvertToNative(typeDesc reflect.Type) (any, error) {
// Allow the element ConvertToNative() function to determine whether conversion is possible.
otherElemType := typeDesc.Elem()
elemCount := l.size
nativeList := reflect.MakeSlice(typeDesc, elemCount, elemCount)
var nativeList reflect.Value
if typeDesc.Kind() == reflect.Array {
nativeList = reflect.New(reflect.ArrayOf(elemCount, typeDesc)).Elem().Index(0)
} else {
nativeList = reflect.MakeSlice(typeDesc, elemCount, elemCount)
}
for i := 0; i < elemCount; i++ {
elem := l.NativeToValue(l.get(i))
nativeElemVal, err := elem.ConvertToNative(otherElemType)

View File

@ -24,7 +24,7 @@ import (
var (
// OptionalType indicates the runtime type of an optional value.
OptionalType = NewOpaqueType("optional")
OptionalType = NewOpaqueType("optional_type")
// OptionalNone is a sentinel value which is used to indicate an empty optional value.
OptionalNone = &Optional{}

View File

@ -326,6 +326,26 @@ func int64ToUint64Checked(v int64) (uint64, error) {
return uint64(v), nil
}
// int64ToInt8Checked converts an int64 to an int8 value.
//
// If the conversion fails due to overflow the error return value will be non-nil.
func int64ToInt8Checked(v int64) (int8, error) {
if v < math.MinInt8 || v > math.MaxInt8 {
return 0, errIntOverflow
}
return int8(v), nil
}
// int64ToInt16Checked converts an int64 to an int16 value.
//
// If the conversion fails due to overflow the error return value will be non-nil.
func int64ToInt16Checked(v int64) (int16, error) {
if v < math.MinInt16 || v > math.MaxInt16 {
return 0, errIntOverflow
}
return int16(v), nil
}
// int64ToInt32Checked converts an int64 to an int32 value.
//
// If the conversion fails due to overflow the error return value will be non-nil.
@ -336,6 +356,26 @@ func int64ToInt32Checked(v int64) (int32, error) {
return int32(v), nil
}
// uint64ToUint8Checked converts a uint64 to a uint8 value.
//
// If the conversion fails due to overflow the error return value will be non-nil.
func uint64ToUint8Checked(v uint64) (uint8, error) {
if v > math.MaxUint8 {
return 0, errUintOverflow
}
return uint8(v), nil
}
// uint64ToUint16Checked converts a uint64 to a uint16 value.
//
// If the conversion fails due to overflow the error return value will be non-nil.
func uint64ToUint16Checked(v uint64) (uint16, error) {
if v > math.MaxUint16 {
return 0, errUintOverflow
}
return uint16(v), nil
}
// uint64ToUint32Checked converts a uint64 to a uint32 value.
//
// If the conversion fails due to overflow the error return value will be non-nil.

View File

@ -54,6 +54,10 @@ type Provider interface {
// Returns false if not found.
FindStructType(structType string) (*Type, bool)
// FindStructFieldNames returns thet field names associated with the type, if the type
// is found.
FindStructFieldNames(structType string) ([]string, bool)
// FieldStructFieldType returns the field type for a checked type value. Returns
// false if the field could not be found.
FindStructFieldType(structType, fieldName string) (*FieldType, bool)
@ -154,7 +158,7 @@ func (p *Registry) EnumValue(enumName string) ref.Val {
return Int(enumVal.Value())
}
// FieldFieldType returns the field type for a checked type value. Returns false if
// FindFieldType returns the field type for a checked type value. Returns false if
// the field could not be found.
//
// Deprecated: use FindStructFieldType
@ -173,7 +177,24 @@ func (p *Registry) FindFieldType(structType, fieldName string) (*ref.FieldType,
GetFrom: field.GetFrom}, true
}
// FieldStructFieldType returns the field type for a checked type value. Returns
// FindStructFieldNames returns the set of field names for the given struct type,
// if the type exists in the registry.
func (p *Registry) FindStructFieldNames(structType string) ([]string, bool) {
msgType, found := p.pbdb.DescribeType(structType)
if !found {
return []string{}, false
}
fieldMap := msgType.FieldMap()
fields := make([]string, len(fieldMap))
idx := 0
for f := range fieldMap {
fields[idx] = f
idx++
}
return fields, true
}
// FindStructFieldType returns the field type for a checked type value. Returns
// false if the field could not be found.
func (p *Registry) FindStructFieldType(structType, fieldName string) (*FieldType, bool) {
msgType, found := p.pbdb.DescribeType(structType)
@ -255,7 +276,7 @@ func (p *Registry) NewValue(structType string, fields map[string]ref.Val) ref.Va
}
err := msgSetField(msg, field, value)
if err != nil {
return &Err{err}
return &Err{error: err}
}
}
return p.NativeToValue(msg.Interface())
@ -569,12 +590,33 @@ func nativeToValue(a Adapter, value any) (ref.Val, bool) {
return NewDynamicMap(a, v), true
// type aliases of primitive types cannot be asserted as that type, but rather need
// to be downcast to int32 before being converted to a CEL representation.
case reflect.Bool:
boolTupe := reflect.TypeOf(false)
return Bool(refValue.Convert(boolTupe).Interface().(bool)), true
case reflect.Int:
intType := reflect.TypeOf(int(0))
return Int(refValue.Convert(intType).Interface().(int)), true
case reflect.Int8:
intType := reflect.TypeOf(int8(0))
return Int(refValue.Convert(intType).Interface().(int8)), true
case reflect.Int16:
intType := reflect.TypeOf(int16(0))
return Int(refValue.Convert(intType).Interface().(int16)), true
case reflect.Int32:
intType := reflect.TypeOf(int32(0))
return Int(refValue.Convert(intType).Interface().(int32)), true
case reflect.Int64:
intType := reflect.TypeOf(int64(0))
return Int(refValue.Convert(intType).Interface().(int64)), true
case reflect.Uint:
uintType := reflect.TypeOf(uint(0))
return Uint(refValue.Convert(uintType).Interface().(uint)), true
case reflect.Uint8:
uintType := reflect.TypeOf(uint8(0))
return Uint(refValue.Convert(uintType).Interface().(uint8)), true
case reflect.Uint16:
uintType := reflect.TypeOf(uint16(0))
return Uint(refValue.Convert(uintType).Interface().(uint16)), true
case reflect.Uint32:
uintType := reflect.TypeOf(uint32(0))
return Uint(refValue.Convert(uintType).Interface().(uint32)), true
@ -587,6 +629,9 @@ func nativeToValue(a Adapter, value any) (ref.Val, bool) {
case reflect.Float64:
doubleType := reflect.TypeOf(float64(0))
return Double(refValue.Convert(doubleType).Interface().(float64)), true
case reflect.String:
stringType := reflect.TypeOf("")
return String(refValue.Convert(stringType).Interface().(string)), true
}
}
return nil, false

View File

@ -66,10 +66,7 @@ func (s String) Compare(other ref.Val) ref.Val {
func (s String) ConvertToNative(typeDesc reflect.Type) (any, error) {
switch typeDesc.Kind() {
case reflect.String:
if reflect.TypeOf(s).AssignableTo(typeDesc) {
return s, nil
}
return s.Value(), nil
return reflect.ValueOf(s).Convert(typeDesc).Interface(), nil
case reflect.Ptr:
switch typeDesc {
case anyValueType:
@ -158,7 +155,7 @@ func (s String) Match(pattern ref.Val) ref.Val {
}
matched, err := regexp.MatchString(pat.Value().(string), s.Value().(string))
if err != nil {
return &Err{err}
return &Err{error: err}
}
return Bool(matched)
}

View File

@ -373,6 +373,23 @@ func (t *Type) TypeName() string {
return t.runtimeTypeName
}
// WithTraits creates a copy of the current Type and sets the trait mask to the traits parameter.
//
// This method should be used with Opaque types where the type acts like a container, e.g. vector.
func (t *Type) WithTraits(traits int) *Type {
if t == nil {
return nil
}
return &Type{
kind: t.kind,
parameters: t.parameters,
runtimeTypeName: t.runtimeTypeName,
isAssignableType: t.isAssignableType,
isAssignableRuntimeType: t.isAssignableRuntimeType,
traitMask: traits,
}
}
// String returns a human-readable definition of the type name.
func (t *Type) String() string {
if len(t.Parameters()) == 0 {
@ -496,7 +513,7 @@ func NewNullableType(wrapped *Type) *Type {
// NewOptionalType creates an abstract parameterized type instance corresponding to CEL's notion of optional.
func NewOptionalType(param *Type) *Type {
return NewOpaqueType("optional", param)
return NewOpaqueType("optional_type", param)
}
// NewOpaqueType creates an abstract parameterized type with a given name.

View File

@ -80,6 +80,18 @@ func (i Uint) ConvertToNative(typeDesc reflect.Type) (any, error) {
return 0, err
}
return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil
case reflect.Uint8:
v, err := uint64ToUint8Checked(uint64(i))
if err != nil {
return 0, err
}
return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil
case reflect.Uint16:
v, err := uint64ToUint16Checked(uint64(i))
if err != nil {
return 0, err
}
return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil
case reflect.Uint64:
return reflect.ValueOf(i).Convert(typeDesc).Interface(), nil
case reflect.Ptr:

View File

@ -7,7 +7,9 @@ package(
go_library(
name = "go_default_library",
srcs = [
"bindings.go",
"encoders.go",
"formatting.go",
"guards.go",
"lists.go",
"math.go",
@ -21,14 +23,14 @@ go_library(
deps = [
"//cel:go_default_library",
"//checker:go_default_library",
"//checker/decls:go_default_library",
"//common/ast:go_default_library",
"//common/overloads:go_default_library",
"//common/operators:go_default_library",
"//common/types:go_default_library",
"//common/types/pb:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"//interpreter:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//reflect/protoreflect:go_default_library",
"@org_golang_google_protobuf//types/known/structpb",
@ -61,7 +63,6 @@ go_test(
"//test:go_default_library",
"//test/proto2pb:go_default_library",
"//test/proto3pb:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//types/known/wrapperspb:go_default_library",
"@org_golang_google_protobuf//encoding/protojson:go_default_library",

View File

@ -414,3 +414,17 @@ Examples:
'TacoCat'.upperAscii() // returns 'TACOCAT'
'TacoCÆt Xii'.upperAscii() // returns 'TACOCÆT XII'
### Reverse
Returns a new string whose characters are the same as the target string, only formatted in
reverse order.
This function relies on converting strings to rune arrays in order to reverse.
It can be located in Version 3 of strings.
<string>.reverse() -> <string>
Examples:
'gums'.reverse() // returns 'smug'
'John Smith'.reverse() // returns 'htimS nhoJ'

View File

@ -16,8 +16,8 @@ package ext
import (
"github.com/google/cel-go/cel"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
)
// Bindings returns a cel.EnvOption to configure support for local variable
@ -61,7 +61,7 @@ func (celBindings) CompileOptions() []cel.EnvOption {
return []cel.EnvOption{
cel.Macros(
// cel.bind(var, <init>, <expr>)
cel.NewReceiverMacro(bindMacro, 3, celBind),
cel.ReceiverMacro(bindMacro, 3, celBind),
),
}
}
@ -70,27 +70,27 @@ func (celBindings) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
func celBind(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) {
func celBind(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
if !macroTargetMatchesNamespace(celNamespace, target) {
return nil, nil
}
varIdent := args[0]
varName := ""
switch varIdent.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
varName = varIdent.GetIdentExpr().GetName()
switch varIdent.Kind() {
case ast.IdentKind:
varName = varIdent.AsIdent()
default:
return nil, meh.NewError(varIdent.GetId(), "cel.bind() variable names must be simple identifiers")
return nil, mef.NewError(varIdent.ID(), "cel.bind() variable names must be simple identifiers")
}
varInit := args[1]
resultExpr := args[2]
return meh.Fold(
return mef.NewComprehension(
mef.NewList(),
unusedIterVar,
meh.NewList(),
varName,
varInit,
meh.LiteralBool(false),
meh.Ident(varName),
mef.NewLiteral(types.False),
mef.NewIdent(varName),
resultExpr,
), nil
}

904
vendor/github.com/google/cel-go/ext/formatting.go generated vendored Normal file
View File

@ -0,0 +1,904 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ext
import (
"errors"
"fmt"
"math"
"sort"
"strconv"
"strings"
"unicode"
"golang.org/x/text/language"
"golang.org/x/text/message"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
)
type clauseImpl func(ref.Val, string) (string, error)
func clauseForType(argType ref.Type) (clauseImpl, error) {
switch argType {
case types.IntType, types.UintType:
return formatDecimal, nil
case types.StringType, types.BytesType, types.BoolType, types.NullType, types.TypeType:
return FormatString, nil
case types.TimestampType, types.DurationType:
// special case to ensure timestamps/durations get printed as CEL literals
return func(arg ref.Val, locale string) (string, error) {
argStrVal := arg.ConvertToType(types.StringType)
argStr := argStrVal.Value().(string)
if arg.Type() == types.TimestampType {
return fmt.Sprintf("timestamp(%q)", argStr), nil
}
if arg.Type() == types.DurationType {
return fmt.Sprintf("duration(%q)", argStr), nil
}
return "", fmt.Errorf("cannot convert argument of type %s to timestamp/duration", arg.Type().TypeName())
}, nil
case types.ListType:
return formatList, nil
case types.MapType:
return formatMap, nil
case types.DoubleType:
// avoid formatFixed so we can output a period as the decimal separator in order
// to always be a valid CEL literal
return func(arg ref.Val, locale string) (string, error) {
argDouble, ok := arg.Value().(float64)
if !ok {
return "", fmt.Errorf("couldn't convert %s to float64", arg.Type().TypeName())
}
fmtStr := fmt.Sprintf("%%.%df", defaultPrecision)
return fmt.Sprintf(fmtStr, argDouble), nil
}, nil
case types.TypeType:
return func(arg ref.Val, locale string) (string, error) {
return fmt.Sprintf("type(%s)", arg.Value().(string)), nil
}, nil
default:
return nil, fmt.Errorf("no formatting function for %s", argType.TypeName())
}
}
func formatList(arg ref.Val, locale string) (string, error) {
argList := arg.(traits.Lister)
argIterator := argList.Iterator()
var listStrBuilder strings.Builder
_, err := listStrBuilder.WriteRune('[')
if err != nil {
return "", fmt.Errorf("error writing to list string: %w", err)
}
for argIterator.HasNext() == types.True {
member := argIterator.Next()
memberFormat, err := clauseForType(member.Type())
if err != nil {
return "", err
}
unquotedStr, err := memberFormat(member, locale)
if err != nil {
return "", err
}
str := quoteForCEL(member, unquotedStr)
_, err = listStrBuilder.WriteString(str)
if err != nil {
return "", fmt.Errorf("error writing to list string: %w", err)
}
if argIterator.HasNext() == types.True {
_, err = listStrBuilder.WriteString(", ")
if err != nil {
return "", fmt.Errorf("error writing to list string: %w", err)
}
}
}
_, err = listStrBuilder.WriteRune(']')
if err != nil {
return "", fmt.Errorf("error writing to list string: %w", err)
}
return listStrBuilder.String(), nil
}
func formatMap(arg ref.Val, locale string) (string, error) {
argMap := arg.(traits.Mapper)
argIterator := argMap.Iterator()
type mapPair struct {
key string
value string
}
argPairs := make([]mapPair, argMap.Size().Value().(int64))
i := 0
for argIterator.HasNext() == types.True {
key := argIterator.Next()
var keyFormat clauseImpl
switch key.Type() {
case types.StringType, types.BoolType:
keyFormat = FormatString
case types.IntType, types.UintType:
keyFormat = formatDecimal
default:
return "", fmt.Errorf("no formatting function for map key of type %s", key.Type().TypeName())
}
unquotedKeyStr, err := keyFormat(key, locale)
if err != nil {
return "", err
}
keyStr := quoteForCEL(key, unquotedKeyStr)
value, found := argMap.Find(key)
if !found {
return "", fmt.Errorf("could not find key: %q", key)
}
valueFormat, err := clauseForType(value.Type())
if err != nil {
return "", err
}
unquotedValueStr, err := valueFormat(value, locale)
if err != nil {
return "", err
}
valueStr := quoteForCEL(value, unquotedValueStr)
argPairs[i] = mapPair{keyStr, valueStr}
i++
}
sort.SliceStable(argPairs, func(x, y int) bool {
return argPairs[x].key < argPairs[y].key
})
var mapStrBuilder strings.Builder
_, err := mapStrBuilder.WriteRune('{')
if err != nil {
return "", fmt.Errorf("error writing to map string: %w", err)
}
for i, entry := range argPairs {
_, err = mapStrBuilder.WriteString(fmt.Sprintf("%s:%s", entry.key, entry.value))
if err != nil {
return "", fmt.Errorf("error writing to map string: %w", err)
}
if i < len(argPairs)-1 {
_, err = mapStrBuilder.WriteString(", ")
if err != nil {
return "", fmt.Errorf("error writing to map string: %w", err)
}
}
}
_, err = mapStrBuilder.WriteRune('}')
if err != nil {
return "", fmt.Errorf("error writing to map string: %w", err)
}
return mapStrBuilder.String(), nil
}
// quoteForCEL takes a formatted, unquoted value and quotes it in a manner suitable
// for embedding directly in CEL.
func quoteForCEL(refVal ref.Val, unquotedValue string) string {
switch refVal.Type() {
case types.StringType:
return fmt.Sprintf("%q", unquotedValue)
case types.BytesType:
return fmt.Sprintf("b%q", unquotedValue)
case types.DoubleType:
// special case to handle infinity/NaN
num := refVal.Value().(float64)
if math.IsInf(num, 1) || math.IsInf(num, -1) || math.IsNaN(num) {
return fmt.Sprintf("%q", unquotedValue)
}
return unquotedValue
default:
return unquotedValue
}
}
// FormatString returns the string representation of a CEL value.
//
// It is used to implement the %s specifier in the (string).format() extension function.
func FormatString(arg ref.Val, locale string) (string, error) {
switch arg.Type() {
case types.ListType:
return formatList(arg, locale)
case types.MapType:
return formatMap(arg, locale)
case types.IntType, types.UintType, types.DoubleType,
types.BoolType, types.StringType, types.TimestampType, types.BytesType, types.DurationType, types.TypeType:
argStrVal := arg.ConvertToType(types.StringType)
argStr, ok := argStrVal.Value().(string)
if !ok {
return "", fmt.Errorf("could not convert argument %q to string", argStrVal)
}
return argStr, nil
case types.NullType:
return "null", nil
default:
return "", stringFormatError(runtimeID, arg.Type().TypeName())
}
}
func formatDecimal(arg ref.Val, locale string) (string, error) {
switch arg.Type() {
case types.IntType:
argInt, ok := arg.ConvertToType(types.IntType).Value().(int64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to int64", arg.Value())
}
return fmt.Sprintf("%d", argInt), nil
case types.UintType:
argInt, ok := arg.ConvertToType(types.UintType).Value().(uint64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to uint64", arg.Value())
}
return fmt.Sprintf("%d", argInt), nil
default:
return "", decimalFormatError(runtimeID, arg.Type().TypeName())
}
}
func matchLanguage(locale string) (language.Tag, error) {
matcher, err := makeMatcher(locale)
if err != nil {
return language.Und, err
}
tag, _ := language.MatchStrings(matcher, locale)
return tag, nil
}
func makeMatcher(locale string) (language.Matcher, error) {
tags := make([]language.Tag, 0)
tag, err := language.Parse(locale)
if err != nil {
return nil, err
}
tags = append(tags, tag)
return language.NewMatcher(tags), nil
}
type stringFormatter struct{}
func (c *stringFormatter) String(arg ref.Val, locale string) (string, error) {
return FormatString(arg, locale)
}
func (c *stringFormatter) Decimal(arg ref.Val, locale string) (string, error) {
return formatDecimal(arg, locale)
}
func (c *stringFormatter) Fixed(precision *int) func(ref.Val, string) (string, error) {
if precision == nil {
precision = new(int)
*precision = defaultPrecision
}
return func(arg ref.Val, locale string) (string, error) {
strException := false
if arg.Type() == types.StringType {
argStr := arg.Value().(string)
if argStr == "NaN" || argStr == "Infinity" || argStr == "-Infinity" {
strException = true
}
}
if arg.Type() != types.DoubleType && !strException {
return "", fixedPointFormatError(runtimeID, arg.Type().TypeName())
}
argFloatVal := arg.ConvertToType(types.DoubleType)
argFloat, ok := argFloatVal.Value().(float64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to float64", argFloatVal.Value())
}
fmtStr := fmt.Sprintf("%%.%df", *precision)
matchedLocale, err := matchLanguage(locale)
if err != nil {
return "", fmt.Errorf("error matching locale: %w", err)
}
return message.NewPrinter(matchedLocale).Sprintf(fmtStr, argFloat), nil
}
}
func (c *stringFormatter) Scientific(precision *int) func(ref.Val, string) (string, error) {
if precision == nil {
precision = new(int)
*precision = defaultPrecision
}
return func(arg ref.Val, locale string) (string, error) {
strException := false
if arg.Type() == types.StringType {
argStr := arg.Value().(string)
if argStr == "NaN" || argStr == "Infinity" || argStr == "-Infinity" {
strException = true
}
}
if arg.Type() != types.DoubleType && !strException {
return "", scientificFormatError(runtimeID, arg.Type().TypeName())
}
argFloatVal := arg.ConvertToType(types.DoubleType)
argFloat, ok := argFloatVal.Value().(float64)
if !ok {
return "", fmt.Errorf("could not convert \"%v\" to float64", argFloatVal.Value())
}
matchedLocale, err := matchLanguage(locale)
if err != nil {
return "", fmt.Errorf("error matching locale: %w", err)
}
fmtStr := fmt.Sprintf("%%%de", *precision)
return message.NewPrinter(matchedLocale).Sprintf(fmtStr, argFloat), nil
}
}
func (c *stringFormatter) Binary(arg ref.Val, locale string) (string, error) {
switch arg.Type() {
case types.IntType:
argInt := arg.Value().(int64)
// locale is intentionally unused as integers formatted as binary
// strings are locale-independent
return fmt.Sprintf("%b", argInt), nil
case types.UintType:
argInt := arg.Value().(uint64)
return fmt.Sprintf("%b", argInt), nil
case types.BoolType:
argBool := arg.Value().(bool)
if argBool {
return "1", nil
}
return "0", nil
default:
return "", binaryFormatError(runtimeID, arg.Type().TypeName())
}
}
func (c *stringFormatter) Hex(useUpper bool) func(ref.Val, string) (string, error) {
return func(arg ref.Val, locale string) (string, error) {
fmtStr := "%x"
if useUpper {
fmtStr = "%X"
}
switch arg.Type() {
case types.StringType, types.BytesType:
if arg.Type() == types.BytesType {
return fmt.Sprintf(fmtStr, arg.Value().([]byte)), nil
}
return fmt.Sprintf(fmtStr, arg.Value().(string)), nil
case types.IntType:
argInt, ok := arg.Value().(int64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to int64", arg.Value())
}
return fmt.Sprintf(fmtStr, argInt), nil
case types.UintType:
argInt, ok := arg.Value().(uint64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to uint64", arg.Value())
}
return fmt.Sprintf(fmtStr, argInt), nil
default:
return "", hexFormatError(runtimeID, arg.Type().TypeName())
}
}
}
func (c *stringFormatter) Octal(arg ref.Val, locale string) (string, error) {
switch arg.Type() {
case types.IntType:
argInt := arg.Value().(int64)
return fmt.Sprintf("%o", argInt), nil
case types.UintType:
argInt := arg.Value().(uint64)
return fmt.Sprintf("%o", argInt), nil
default:
return "", octalFormatError(runtimeID, arg.Type().TypeName())
}
}
// stringFormatValidator implements the cel.ASTValidator interface allowing for static validation
// of string.format calls.
type stringFormatValidator struct{}
// Name returns the name of the validator.
func (stringFormatValidator) Name() string {
return "cel.lib.ext.validate.functions.string.format"
}
// Configure implements the ASTValidatorConfigurer interface and augments the list of functions to skip
// during homogeneous aggregate literal type-checks.
func (stringFormatValidator) Configure(config cel.MutableValidatorConfig) error {
functions := config.GetOrDefault(cel.HomogeneousAggregateLiteralExemptFunctions, []string{}).([]string)
functions = append(functions, "format")
return config.Set(cel.HomogeneousAggregateLiteralExemptFunctions, functions)
}
// Validate parses all literal format strings and type checks the format clause against the argument
// at the corresponding ordinal within the list literal argument to the function, if one is specified.
func (stringFormatValidator) Validate(env *cel.Env, _ cel.ValidatorConfig, a *ast.AST, iss *cel.Issues) {
root := ast.NavigateAST(a)
formatCallExprs := ast.MatchDescendants(root, matchConstantFormatStringWithListLiteralArgs(a))
for _, e := range formatCallExprs {
call := e.AsCall()
formatStr := call.Target().AsLiteral().Value().(string)
args := call.Args()[0].AsList().Elements()
formatCheck := &stringFormatChecker{
args: args,
ast: a,
}
// use a placeholder locale, since locale doesn't affect syntax
_, err := parseFormatString(formatStr, formatCheck, formatCheck, "en_US")
if err != nil {
iss.ReportErrorAtID(getErrorExprID(e.ID(), err), err.Error())
continue
}
seenArgs := formatCheck.argsRequested
if len(args) > seenArgs {
iss.ReportErrorAtID(e.ID(),
"too many arguments supplied to string.format (expected %d, got %d)", seenArgs, len(args))
}
}
}
// getErrorExprID determines which list literal argument triggered a type-disagreement for the
// purposes of more accurate error message reports.
func getErrorExprID(id int64, err error) int64 {
fmtErr, ok := err.(formatError)
if ok {
return fmtErr.id
}
wrapped := errors.Unwrap(err)
if wrapped != nil {
return getErrorExprID(id, wrapped)
}
return id
}
// matchConstantFormatStringWithListLiteralArgs matches all valid expression nodes for string
// format checking.
func matchConstantFormatStringWithListLiteralArgs(a *ast.AST) ast.ExprMatcher {
return func(e ast.NavigableExpr) bool {
if e.Kind() != ast.CallKind {
return false
}
call := e.AsCall()
if !call.IsMemberFunction() || call.FunctionName() != "format" {
return false
}
overloadIDs := a.GetOverloadIDs(e.ID())
if len(overloadIDs) != 0 {
found := false
for _, overload := range overloadIDs {
if overload == overloads.ExtFormatString {
found = true
break
}
}
if !found {
return false
}
}
formatString := call.Target()
if formatString.Kind() != ast.LiteralKind && formatString.AsLiteral().Type() != cel.StringType {
return false
}
args := call.Args()
if len(args) != 1 {
return false
}
formatArgs := args[0]
return formatArgs.Kind() == ast.ListKind
}
}
// stringFormatChecker implements the formatStringInterpolater interface
type stringFormatChecker struct {
args []ast.Expr
argsRequested int
currArgIndex int64
ast *ast.AST
}
func (c *stringFormatChecker) String(arg ref.Val, locale string) (string, error) {
formatArg := c.args[c.currArgIndex]
valid, badID := c.verifyString(formatArg)
if !valid {
return "", stringFormatError(badID, c.typeOf(badID).TypeName())
}
return "", nil
}
func (c *stringFormatChecker) Decimal(arg ref.Val, locale string) (string, error) {
id := c.args[c.currArgIndex].ID()
valid := c.verifyTypeOneOf(id, types.IntType, types.UintType)
if !valid {
return "", decimalFormatError(id, c.typeOf(id).TypeName())
}
return "", nil
}
func (c *stringFormatChecker) Fixed(precision *int) func(ref.Val, string) (string, error) {
return func(arg ref.Val, locale string) (string, error) {
id := c.args[c.currArgIndex].ID()
// we allow StringType since "NaN", "Infinity", and "-Infinity" are also valid values
valid := c.verifyTypeOneOf(id, types.DoubleType, types.StringType)
if !valid {
return "", fixedPointFormatError(id, c.typeOf(id).TypeName())
}
return "", nil
}
}
func (c *stringFormatChecker) Scientific(precision *int) func(ref.Val, string) (string, error) {
return func(arg ref.Val, locale string) (string, error) {
id := c.args[c.currArgIndex].ID()
valid := c.verifyTypeOneOf(id, types.DoubleType, types.StringType)
if !valid {
return "", scientificFormatError(id, c.typeOf(id).TypeName())
}
return "", nil
}
}
func (c *stringFormatChecker) Binary(arg ref.Val, locale string) (string, error) {
id := c.args[c.currArgIndex].ID()
valid := c.verifyTypeOneOf(id, types.IntType, types.UintType, types.BoolType)
if !valid {
return "", binaryFormatError(id, c.typeOf(id).TypeName())
}
return "", nil
}
func (c *stringFormatChecker) Hex(useUpper bool) func(ref.Val, string) (string, error) {
return func(arg ref.Val, locale string) (string, error) {
id := c.args[c.currArgIndex].ID()
valid := c.verifyTypeOneOf(id, types.IntType, types.UintType, types.StringType, types.BytesType)
if !valid {
return "", hexFormatError(id, c.typeOf(id).TypeName())
}
return "", nil
}
}
func (c *stringFormatChecker) Octal(arg ref.Val, locale string) (string, error) {
id := c.args[c.currArgIndex].ID()
valid := c.verifyTypeOneOf(id, types.IntType, types.UintType)
if !valid {
return "", octalFormatError(id, c.typeOf(id).TypeName())
}
return "", nil
}
func (c *stringFormatChecker) Arg(index int64) (ref.Val, error) {
c.argsRequested++
c.currArgIndex = index
// return a dummy value - this is immediately passed to back to us
// through one of the FormatCallback functions, so anything will do
return types.Int(0), nil
}
func (c *stringFormatChecker) Size() int64 {
return int64(len(c.args))
}
func (c *stringFormatChecker) typeOf(id int64) *cel.Type {
return c.ast.GetType(id)
}
func (c *stringFormatChecker) verifyTypeOneOf(id int64, validTypes ...*cel.Type) bool {
t := c.typeOf(id)
if t == cel.DynType {
return true
}
for _, vt := range validTypes {
// Only check runtime type compatibility without delving deeper into parameterized types
if t.Kind() == vt.Kind() {
return true
}
}
return false
}
func (c *stringFormatChecker) verifyString(sub ast.Expr) (bool, int64) {
paramA := cel.TypeParamType("A")
paramB := cel.TypeParamType("B")
subVerified := c.verifyTypeOneOf(sub.ID(),
cel.ListType(paramA), cel.MapType(paramA, paramB),
cel.IntType, cel.UintType, cel.DoubleType, cel.BoolType, cel.StringType,
cel.TimestampType, cel.BytesType, cel.DurationType, cel.TypeType, cel.NullType)
if !subVerified {
return false, sub.ID()
}
switch sub.Kind() {
case ast.ListKind:
for _, e := range sub.AsList().Elements() {
// recursively verify if we're dealing with a list/map
verified, id := c.verifyString(e)
if !verified {
return false, id
}
}
return true, sub.ID()
case ast.MapKind:
for _, e := range sub.AsMap().Entries() {
// recursively verify if we're dealing with a list/map
entry := e.AsMapEntry()
verified, id := c.verifyString(entry.Key())
if !verified {
return false, id
}
verified, id = c.verifyString(entry.Value())
if !verified {
return false, id
}
}
return true, sub.ID()
default:
return true, sub.ID()
}
}
// helper routines for reporting common errors during string formatting static validation and
// runtime execution.
func binaryFormatError(id int64, badType string) error {
return newFormatError(id, "only integers and bools can be formatted as binary, was given %s", badType)
}
func decimalFormatError(id int64, badType string) error {
return newFormatError(id, "decimal clause can only be used on integers, was given %s", badType)
}
func fixedPointFormatError(id int64, badType string) error {
return newFormatError(id, "fixed-point clause can only be used on doubles, was given %s", badType)
}
func hexFormatError(id int64, badType string) error {
return newFormatError(id, "only integers, byte buffers, and strings can be formatted as hex, was given %s", badType)
}
func octalFormatError(id int64, badType string) error {
return newFormatError(id, "octal clause can only be used on integers, was given %s", badType)
}
func scientificFormatError(id int64, badType string) error {
return newFormatError(id, "scientific clause can only be used on doubles, was given %s", badType)
}
func stringFormatError(id int64, badType string) error {
return newFormatError(id, "string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps, was given %s", badType)
}
type formatError struct {
id int64
msg string
}
func newFormatError(id int64, msg string, args ...any) error {
return formatError{
id: id,
msg: fmt.Sprintf(msg, args...),
}
}
func (e formatError) Error() string {
return e.msg
}
func (e formatError) Is(target error) bool {
return e.msg == target.Error()
}
// stringArgList implements the formatListArgs interface.
type stringArgList struct {
args traits.Lister
}
func (c *stringArgList) Arg(index int64) (ref.Val, error) {
if index >= c.args.Size().Value().(int64) {
return nil, fmt.Errorf("index %d out of range", index)
}
return c.args.Get(types.Int(index)), nil
}
func (c *stringArgList) Size() int64 {
return c.args.Size().Value().(int64)
}
// formatStringInterpolator is an interface that allows user-defined behavior
// for formatting clause implementations, as well as argument retrieval.
// Each function is expected to support the appropriate types as laid out in
// the string.format documentation, and to return an error if given an inappropriate type.
type formatStringInterpolator interface {
// String takes a ref.Val and a string representing the current locale identifier
// and returns the Val formatted as a string, or an error if one occurred.
String(ref.Val, string) (string, error)
// Decimal takes a ref.Val and a string representing the current locale identifier
// and returns the Val formatted as a decimal integer, or an error if one occurred.
Decimal(ref.Val, string) (string, error)
// Fixed takes an int pointer representing precision (or nil if none was given) and
// returns a function operating in a similar manner to String and Decimal, taking a
// ref.Val and locale and returning the appropriate string. A closure is returned
// so precision can be set without needing an additional function call/configuration.
Fixed(*int) func(ref.Val, string) (string, error)
// Scientific functions identically to Fixed, except the string returned from the closure
// is expected to be in scientific notation.
Scientific(*int) func(ref.Val, string) (string, error)
// Binary takes a ref.Val and a string representing the current locale identifier
// and returns the Val formatted as a binary integer, or an error if one occurred.
Binary(ref.Val, string) (string, error)
// Hex takes a boolean that, if true, indicates the hex string output by the returned
// closure should use uppercase letters for A-F.
Hex(bool) func(ref.Val, string) (string, error)
// Octal takes a ref.Val and a string representing the current locale identifier and
// returns the Val formatted in octal, or an error if one occurred.
Octal(ref.Val, string) (string, error)
}
// formatListArgs is an interface that allows user-defined list-like datatypes to be used
// for formatting clause implementations.
type formatListArgs interface {
// Arg returns the ref.Val at the given index, or an error if one occurred.
Arg(int64) (ref.Val, error)
// Size returns the length of the argument list.
Size() int64
}
// parseFormatString formats a string according to the string.format syntax, taking the clause implementations
// from the provided FormatCallback and the args from the given FormatList.
func parseFormatString(formatStr string, callback formatStringInterpolator, list formatListArgs, locale string) (string, error) {
i := 0
argIndex := 0
var builtStr strings.Builder
for i < len(formatStr) {
if formatStr[i] == '%' {
if i+1 < len(formatStr) && formatStr[i+1] == '%' {
err := builtStr.WriteByte('%')
if err != nil {
return "", fmt.Errorf("error writing format string: %w", err)
}
i += 2
continue
} else {
argAny, err := list.Arg(int64(argIndex))
if err != nil {
return "", err
}
if i+1 >= len(formatStr) {
return "", errors.New("unexpected end of string")
}
if int64(argIndex) >= list.Size() {
return "", fmt.Errorf("index %d out of range", argIndex)
}
numRead, val, refErr := parseAndFormatClause(formatStr[i:], argAny, callback, list, locale)
if refErr != nil {
return "", refErr
}
_, err = builtStr.WriteString(val)
if err != nil {
return "", fmt.Errorf("error writing format string: %w", err)
}
i += numRead
argIndex++
}
} else {
err := builtStr.WriteByte(formatStr[i])
if err != nil {
return "", fmt.Errorf("error writing format string: %w", err)
}
i++
}
}
return builtStr.String(), nil
}
// parseAndFormatClause parses the format clause at the start of the given string with val, and returns
// how many characters were consumed and the substituted string form of val, or an error if one occurred.
func parseAndFormatClause(formatStr string, val ref.Val, callback formatStringInterpolator, list formatListArgs, locale string) (int, string, error) {
i := 1
read, formatter, err := parseFormattingClause(formatStr[i:], callback)
i += read
if err != nil {
return -1, "", newParseFormatError("could not parse formatting clause", err)
}
valStr, err := formatter(val, locale)
if err != nil {
return -1, "", newParseFormatError("error during formatting", err)
}
return i, valStr, nil
}
func parseFormattingClause(formatStr string, callback formatStringInterpolator) (int, clauseImpl, error) {
i := 0
read, precision, err := parsePrecision(formatStr[i:])
i += read
if err != nil {
return -1, nil, fmt.Errorf("error while parsing precision: %w", err)
}
r := rune(formatStr[i])
i++
switch r {
case 's':
return i, callback.String, nil
case 'd':
return i, callback.Decimal, nil
case 'f':
return i, callback.Fixed(precision), nil
case 'e':
return i, callback.Scientific(precision), nil
case 'b':
return i, callback.Binary, nil
case 'x', 'X':
return i, callback.Hex(unicode.IsUpper(r)), nil
case 'o':
return i, callback.Octal, nil
default:
return -1, nil, fmt.Errorf("unrecognized formatting clause \"%c\"", r)
}
}
func parsePrecision(formatStr string) (int, *int, error) {
i := 0
if formatStr[i] != '.' {
return i, nil, nil
}
i++
var buffer strings.Builder
for {
if i >= len(formatStr) {
return -1, nil, errors.New("could not find end of precision specifier")
}
if !isASCIIDigit(rune(formatStr[i])) {
break
}
buffer.WriteByte(formatStr[i])
i++
}
precision, err := strconv.Atoi(buffer.String())
if err != nil {
return -1, nil, fmt.Errorf("error while converting precision to integer: %w", err)
}
return i, &precision, nil
}
func isASCIIDigit(r rune) bool {
return r <= unicode.MaxASCII && unicode.IsDigit(r)
}
type parseFormatError struct {
msg string
wrapped error
}
func newParseFormatError(msg string, wrapped error) error {
return parseFormatError{msg: msg, wrapped: wrapped}
}
func (e parseFormatError) Error() string {
return fmt.Sprintf("%s: %s", e.msg, e.wrapped.Error())
}
func (e parseFormatError) Is(target error) bool {
return e.Error() == target.Error()
}
func (e parseFormatError) Unwrap() error {
return e.wrapped
}
const (
runtimeID = int64(-1)
)

View File

@ -15,10 +15,9 @@
package ext
import (
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// function invocation guards for common call signatures within extension functions.
@ -51,10 +50,10 @@ func listStringOrError(strs []string, err error) ref.Val {
return types.DefaultTypeAdapter.NativeToValue(strs)
}
func macroTargetMatchesNamespace(ns string, target *exprpb.Expr) bool {
switch target.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
if target.GetIdentExpr().GetName() != ns {
func macroTargetMatchesNamespace(ns string, target ast.Expr) bool {
switch target.Kind() {
case ast.IdentKind:
if target.AsIdent() != ns {
return false
}
return true

View File

@ -19,11 +19,10 @@ import (
"strings"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// Math returns a cel.EnvOption to configure namespaced math helper macros and
@ -111,9 +110,9 @@ func (mathLib) CompileOptions() []cel.EnvOption {
return []cel.EnvOption{
cel.Macros(
// math.least(num, ...)
cel.NewReceiverVarArgMacro(leastMacro, mathLeast),
cel.ReceiverVarArgMacro(leastMacro, mathLeast),
// math.greatest(num, ...)
cel.NewReceiverVarArgMacro(greatestMacro, mathGreatest),
cel.ReceiverVarArgMacro(greatestMacro, mathGreatest),
),
cel.Function(minFunc,
cel.Overload("math_@min_double", []*cel.Type{cel.DoubleType}, cel.DoubleType,
@ -187,57 +186,57 @@ func (mathLib) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
func mathLeast(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) {
func mathLeast(meh cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
if !macroTargetMatchesNamespace(mathNamespace, target) {
return nil, nil
}
switch len(args) {
case 0:
return nil, meh.NewError(target.GetId(), "math.least() requires at least one argument")
return nil, meh.NewError(target.ID(), "math.least() requires at least one argument")
case 1:
if isListLiteralWithValidArgs(args[0]) || isValidArgType(args[0]) {
return meh.GlobalCall(minFunc, args[0]), nil
return meh.NewCall(minFunc, args[0]), nil
}
return nil, meh.NewError(args[0].GetId(), "math.least() invalid single argument value")
return nil, meh.NewError(args[0].ID(), "math.least() invalid single argument value")
case 2:
err := checkInvalidArgs(meh, "math.least()", args)
if err != nil {
return nil, err
}
return meh.GlobalCall(minFunc, args...), nil
return meh.NewCall(minFunc, args...), nil
default:
err := checkInvalidArgs(meh, "math.least()", args)
if err != nil {
return nil, err
}
return meh.GlobalCall(minFunc, meh.NewList(args...)), nil
return meh.NewCall(minFunc, meh.NewList(args...)), nil
}
}
func mathGreatest(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) {
func mathGreatest(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
if !macroTargetMatchesNamespace(mathNamespace, target) {
return nil, nil
}
switch len(args) {
case 0:
return nil, meh.NewError(target.GetId(), "math.greatest() requires at least one argument")
return nil, mef.NewError(target.ID(), "math.greatest() requires at least one argument")
case 1:
if isListLiteralWithValidArgs(args[0]) || isValidArgType(args[0]) {
return meh.GlobalCall(maxFunc, args[0]), nil
return mef.NewCall(maxFunc, args[0]), nil
}
return nil, meh.NewError(args[0].GetId(), "math.greatest() invalid single argument value")
return nil, mef.NewError(args[0].ID(), "math.greatest() invalid single argument value")
case 2:
err := checkInvalidArgs(meh, "math.greatest()", args)
err := checkInvalidArgs(mef, "math.greatest()", args)
if err != nil {
return nil, err
}
return meh.GlobalCall(maxFunc, args...), nil
return mef.NewCall(maxFunc, args...), nil
default:
err := checkInvalidArgs(meh, "math.greatest()", args)
err := checkInvalidArgs(mef, "math.greatest()", args)
if err != nil {
return nil, err
}
return meh.GlobalCall(maxFunc, meh.NewList(args...)), nil
return mef.NewCall(maxFunc, mef.NewList(args...)), nil
}
}
@ -311,48 +310,48 @@ func maxList(numList ref.Val) ref.Val {
}
}
func checkInvalidArgs(meh cel.MacroExprHelper, funcName string, args []*exprpb.Expr) *cel.Error {
func checkInvalidArgs(meh cel.MacroExprFactory, funcName string, args []ast.Expr) *cel.Error {
for _, arg := range args {
err := checkInvalidArgLiteral(funcName, arg)
if err != nil {
return meh.NewError(arg.GetId(), err.Error())
return meh.NewError(arg.ID(), err.Error())
}
}
return nil
}
func checkInvalidArgLiteral(funcName string, arg *exprpb.Expr) error {
func checkInvalidArgLiteral(funcName string, arg ast.Expr) error {
if !isValidArgType(arg) {
return fmt.Errorf("%s simple literal arguments must be numeric", funcName)
}
return nil
}
func isValidArgType(arg *exprpb.Expr) bool {
switch arg.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
c := arg.GetConstExpr()
switch c.GetConstantKind().(type) {
case *exprpb.Constant_DoubleValue, *exprpb.Constant_Int64Value, *exprpb.Constant_Uint64Value:
func isValidArgType(arg ast.Expr) bool {
switch arg.Kind() {
case ast.LiteralKind:
c := ref.Val(arg.AsLiteral())
switch c.(type) {
case types.Double, types.Int, types.Uint:
return true
default:
return false
}
case *exprpb.Expr_ListExpr, *exprpb.Expr_StructExpr:
case ast.ListKind, ast.MapKind, ast.StructKind:
return false
default:
return true
}
}
func isListLiteralWithValidArgs(arg *exprpb.Expr) bool {
switch arg.GetExprKind().(type) {
case *exprpb.Expr_ListExpr:
list := arg.GetListExpr()
if len(list.GetElements()) == 0 {
func isListLiteralWithValidArgs(arg ast.Expr) bool {
switch arg.Kind() {
case ast.ListKind:
list := arg.AsList()
if list.Size() == 0 {
return false
}
for _, e := range list.GetElements() {
for _, e := range list.Elements() {
if !isValidArgType(e) {
return false
}

View File

@ -96,17 +96,21 @@ func newNativeTypeProvider(adapter types.Adapter, provider types.Provider, refTy
for _, refType := range refTypes {
switch rt := refType.(type) {
case reflect.Type:
t, err := newNativeType(rt)
result, err := newNativeTypes(rt)
if err != nil {
return nil, err
}
nativeTypes[t.TypeName()] = t
for idx := range result {
nativeTypes[result[idx].TypeName()] = result[idx]
}
case reflect.Value:
t, err := newNativeType(rt.Type())
result, err := newNativeTypes(rt.Type())
if err != nil {
return nil, err
}
nativeTypes[t.TypeName()] = t
for idx := range result {
nativeTypes[result[idx].TypeName()] = result[idx]
}
default:
return nil, fmt.Errorf("unsupported native type: %v (%T) must be reflect.Type or reflect.Value", rt, rt)
}
@ -151,6 +155,24 @@ func (tp *nativeTypeProvider) FindStructType(typeName string) (*types.Type, bool
return tp.baseProvider.FindStructType(typeName)
}
// FindStructFieldNames looks up the type definition first from the native types, then from
// the backing provider type set. If found, a set of field names corresponding to the type
// will be returned.
func (tp *nativeTypeProvider) FindStructFieldNames(typeName string) ([]string, bool) {
if t, found := tp.nativeTypes[typeName]; found {
fieldCount := t.refType.NumField()
fields := make([]string, fieldCount)
for i := 0; i < fieldCount; i++ {
fields[i] = t.refType.Field(i).Name
}
return fields, true
}
if celTypeFields, found := tp.baseProvider.FindStructFieldNames(typeName); found {
return celTypeFields, true
}
return tp.baseProvider.FindStructFieldNames(typeName)
}
// FindStructFieldType looks up a native type's field definition, and if the type name is not a native
// type then proxies to the composed types.Provider
func (tp *nativeTypeProvider) FindStructFieldType(typeName, fieldName string) (*types.FieldType, bool) {
@ -447,6 +469,42 @@ func (o *nativeObj) Value() any {
return o.val
}
func newNativeTypes(rawType reflect.Type) ([]*nativeType, error) {
nt, err := newNativeType(rawType)
if err != nil {
return nil, err
}
result := []*nativeType{nt}
alreadySeen := make(map[string]struct{})
var iterateStructMembers func(reflect.Type)
iterateStructMembers = func(t reflect.Type) {
if k := t.Kind(); k == reflect.Pointer || k == reflect.Slice || k == reflect.Array || k == reflect.Map {
t = t.Elem()
}
if t.Kind() != reflect.Struct {
return
}
if _, seen := alreadySeen[t.String()]; seen {
return
}
alreadySeen[t.String()] = struct{}{}
nt, ntErr := newNativeType(t)
if ntErr != nil {
err = ntErr
return
}
result = append(result, nt)
for idx := 0; idx < t.NumField(); idx++ {
iterateStructMembers(t.Field(idx).Type)
}
}
iterateStructMembers(rawType)
return result, err
}
func newNativeType(rawType reflect.Type) (*nativeType, error) {
refType := rawType
if refType.Kind() == reflect.Pointer {

View File

@ -16,8 +16,7 @@ package ext
import (
"github.com/google/cel-go/cel"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"github.com/google/cel-go/common/ast"
)
// Protos returns a cel.EnvOption to configure extended macros and functions for
@ -72,9 +71,9 @@ func (protoLib) CompileOptions() []cel.EnvOption {
return []cel.EnvOption{
cel.Macros(
// proto.getExt(msg, select_expression)
cel.NewReceiverMacro(getExtension, 2, getProtoExt),
cel.ReceiverMacro(getExtension, 2, getProtoExt),
// proto.hasExt(msg, select_expression)
cel.NewReceiverMacro(hasExtension, 2, hasProtoExt),
cel.ReceiverMacro(hasExtension, 2, hasProtoExt),
),
}
}
@ -85,56 +84,56 @@ func (protoLib) ProgramOptions() []cel.ProgramOption {
}
// hasProtoExt generates a test-only select expression for a fully-qualified extension name on a protobuf message.
func hasProtoExt(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) {
func hasProtoExt(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
if !macroTargetMatchesNamespace(protoNamespace, target) {
return nil, nil
}
extensionField, err := getExtFieldName(meh, args[1])
extensionField, err := getExtFieldName(mef, args[1])
if err != nil {
return nil, err
}
return meh.PresenceTest(args[0], extensionField), nil
return mef.NewPresenceTest(args[0], extensionField), nil
}
// getProtoExt generates a select expression for a fully-qualified extension name on a protobuf message.
func getProtoExt(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) {
func getProtoExt(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
if !macroTargetMatchesNamespace(protoNamespace, target) {
return nil, nil
}
extFieldName, err := getExtFieldName(meh, args[1])
extFieldName, err := getExtFieldName(mef, args[1])
if err != nil {
return nil, err
}
return meh.Select(args[0], extFieldName), nil
return mef.NewSelect(args[0], extFieldName), nil
}
func getExtFieldName(meh cel.MacroExprHelper, expr *exprpb.Expr) (string, *cel.Error) {
func getExtFieldName(mef cel.MacroExprFactory, expr ast.Expr) (string, *cel.Error) {
isValid := false
extensionField := ""
switch expr.GetExprKind().(type) {
case *exprpb.Expr_SelectExpr:
switch expr.Kind() {
case ast.SelectKind:
extensionField, isValid = validateIdentifier(expr)
}
if !isValid {
return "", meh.NewError(expr.GetId(), "invalid extension field")
return "", mef.NewError(expr.ID(), "invalid extension field")
}
return extensionField, nil
}
func validateIdentifier(expr *exprpb.Expr) (string, bool) {
switch expr.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
return expr.GetIdentExpr().GetName(), true
case *exprpb.Expr_SelectExpr:
sel := expr.GetSelectExpr()
if sel.GetTestOnly() {
func validateIdentifier(expr ast.Expr) (string, bool) {
switch expr.Kind() {
case ast.IdentKind:
return expr.AsIdent(), true
case ast.SelectKind:
sel := expr.AsSelect()
if sel.IsTestOnly() {
return "", false
}
opStr, isIdent := validateIdentifier(sel.GetOperand())
opStr, isIdent := validateIdentifier(sel.Operand())
if !isIdent {
return "", false
}
return opStr + "." + sel.GetField(), true
return opStr + "." + sel.FieldName(), true
default:
return "", false
}

View File

@ -19,6 +19,8 @@ import (
"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
@ -119,6 +121,68 @@ func (setsLib) ProgramOptions() []cel.ProgramOption {
}
}
// NewSetMembershipOptimizer rewrites set membership tests using the `in` operator against a list
// of constant values of enum, int, uint, string, or boolean type into a set membership test against
// a map where the map keys are the elements of the list.
func NewSetMembershipOptimizer() (cel.ASTOptimizer, error) {
return setsLib{}, nil
}
func (setsLib) Optimize(ctx *cel.OptimizerContext, a *ast.AST) *ast.AST {
root := ast.NavigateAST(a)
matches := ast.MatchDescendants(root, matchInConstantList(a))
for _, match := range matches {
call := match.AsCall()
listArg := call.Args()[1]
entries := make([]ast.EntryExpr, len(listArg.AsList().Elements()))
for i, elem := range listArg.AsList().Elements() {
var entry ast.EntryExpr
if r, found := a.ReferenceMap()[elem.ID()]; found && r.Value != nil {
entry = ctx.NewMapEntry(ctx.NewLiteral(r.Value), ctx.NewLiteral(types.True), false)
} else {
entry = ctx.NewMapEntry(elem, ctx.NewLiteral(types.True), false)
}
entries[i] = entry
}
mapArg := ctx.NewMap(entries)
ctx.UpdateExpr(listArg, mapArg)
}
return a
}
func matchInConstantList(a *ast.AST) ast.ExprMatcher {
return func(e ast.NavigableExpr) bool {
if e.Kind() != ast.CallKind {
return false
}
call := e.AsCall()
if call.FunctionName() != operators.In {
return false
}
aggregateVal := call.Args()[1]
if aggregateVal.Kind() != ast.ListKind {
return false
}
listVal := aggregateVal.AsList()
for _, elem := range listVal.Elements() {
if r, found := a.ReferenceMap()[elem.ID()]; found {
if r.Value != nil {
continue
}
}
if elem.Kind() != ast.LiteralKind {
return false
}
lit := elem.AsLiteral()
if !(lit.Type() == cel.StringType || lit.Type() == cel.IntType ||
lit.Type() == cel.UintType || lit.Type() == cel.BoolType) {
return false
}
}
return true
}
}
func setsIntersects(listA, listB ref.Val) ref.Val {
lA := listA.(traits.Lister)
lB := listB.(traits.Lister)

View File

@ -21,19 +21,16 @@ import (
"fmt"
"math"
"reflect"
"sort"
"strings"
"unicode"
"unicode/utf8"
"golang.org/x/text/language"
"golang.org/x/text/message"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/interpreter"
)
const (
@ -99,7 +96,7 @@ const (
// "a map inside a list: %s".format([[1, 2, 3, {"a": "x", "b": "y", "c": "z"}]]) // returns "a map inside a list: [1, 2, 3, {"a":"x", "b":"y", "c":"d"}]"
// "true bool: %s - false bool: %s\nbinary bool: %b".format([true, false, true]) // returns "true bool: true - false bool: false\nbinary bool: 1"
//
// Passing an incorrect type (an integer to `%s`) is considered an error, as well as attempting
// Passing an incorrect type (a string to `%b`) is considered an error, as well as attempting
// to use more formatting clauses than there are arguments (`%d %d %d` while passing two ints, for instance).
// If compile-time checking is enabled, and the formatting string is a constant, and the argument list is a literal,
// then letting any arguments go unused/unformatted is also considered an error.
@ -205,6 +202,8 @@ const (
// 'hello hello'.replace('he', 'we', -1) // returns 'wello wello'
// 'hello hello'.replace('he', 'we', 1) // returns 'wello hello'
// 'hello hello'.replace('he', 'we', 0) // returns 'hello hello'
// 'hello hello'.replace('', '_') // returns '_h_e_l_l_o_ _h_e_l_l_o_'
// 'hello hello'.replace('h', '') // returns 'ello ello'
//
// # Split
//
@ -270,8 +269,26 @@ const (
//
// 'TacoCat'.upperAscii() // returns 'TACOCAT'
// 'TacoCÆt Xii'.upperAscii() // returns 'TACOCÆT XII'
//
// # Reverse
//
// Introduced at version: 3
//
// Returns a new string whose characters are the same as the target string, only formatted in
// reverse order.
// This function relies on converting strings to rune arrays in order to reverse
//
// <string>.reverse() -> <string>
//
// Examples:
//
// 'gums'.reverse() // returns 'smug'
// 'John Smith'.reverse() // returns 'htimS nhoJ'
func Strings(options ...StringsOption) cel.EnvOption {
s := &stringLib{version: math.MaxUint32}
s := &stringLib{
version: math.MaxUint32,
validateFormat: true,
}
for _, o := range options {
s = o(s)
}
@ -279,8 +296,9 @@ func Strings(options ...StringsOption) cel.EnvOption {
}
type stringLib struct {
locale string
version uint32
locale string
version uint32
validateFormat bool
}
// LibraryName implements the SingletonLibrary interface method.
@ -317,6 +335,17 @@ func StringsVersion(version uint32) StringsOption {
}
}
// StringsValidateFormatCalls validates type-checked ASTs to ensure that string.format() calls have
// valid formatting clauses and valid argument types for each clause.
//
// Enabled by default.
func StringsValidateFormatCalls(value bool) StringsOption {
return func(s *stringLib) *stringLib {
s.validateFormat = value
return s
}
}
// CompileOptions implements the Library interface method.
func (lib *stringLib) CompileOptions() []cel.EnvOption {
formatLocale := "en_US"
@ -440,13 +469,15 @@ func (lib *stringLib) CompileOptions() []cel.EnvOption {
cel.FunctionBinding(func(args ...ref.Val) ref.Val {
s := string(args[0].(types.String))
formatArgs := args[1].(traits.Lister)
return stringOrError(interpreter.ParseFormatString(s, &stringFormatter{}, &stringArgList{formatArgs}, formatLocale))
return stringOrError(parseFormatString(s, &stringFormatter{}, &stringArgList{formatArgs}, formatLocale))
}))),
cel.Function("strings.quote", cel.Overload("strings_quote", []*cel.Type{cel.StringType}, cel.StringType,
cel.UnaryBinding(func(str ref.Val) ref.Val {
s := str.(types.String)
return stringOrError(quote(string(s)))
}))))
}))),
cel.ASTValidators(stringFormatValidator{}))
}
if lib.version >= 2 {
@ -471,7 +502,7 @@ func (lib *stringLib) CompileOptions() []cel.EnvOption {
cel.UnaryBinding(func(list ref.Val) ref.Val {
l, err := list.ConvertToNative(stringListType)
if err != nil {
return types.NewErr(err.Error())
return types.WrapErr(err)
}
return stringOrError(join(l.([]string)))
})),
@ -479,13 +510,26 @@ func (lib *stringLib) CompileOptions() []cel.EnvOption {
cel.BinaryBinding(func(list, delim ref.Val) ref.Val {
l, err := list.ConvertToNative(stringListType)
if err != nil {
return types.NewErr(err.Error())
return types.WrapErr(err)
}
d := delim.(types.String)
return stringOrError(joinSeparator(l.([]string), string(d)))
}))),
)
}
if lib.version >= 3 {
opts = append(opts,
cel.Function("reverse",
cel.MemberOverload("reverse", []*cel.Type{cel.StringType}, cel.StringType,
cel.UnaryBinding(func(str ref.Val) ref.Val {
s := str.(types.String)
return stringOrError(reverse(string(s)))
}))),
)
}
if lib.validateFormat {
opts = append(opts, cel.ASTValidators(stringFormatValidator{}))
}
return opts
}
@ -636,6 +680,14 @@ func upperASCII(str string) (string, error) {
return string(runes), nil
}
func reverse(str string) (string, error) {
chars := []rune(str)
for i, j := 0, len(chars)-1; i < j; i, j = i+1, j-1 {
chars[i], chars[j] = chars[j], chars[i]
}
return string(chars), nil
}
func joinSeparator(strs []string, separator string) (string, error) {
return strings.Join(strs, separator), nil
}
@ -661,238 +713,6 @@ func joinValSeparator(strs traits.Lister, separator string) (string, error) {
return sb.String(), nil
}
type clauseImpl func(ref.Val, string) (string, error)
func clauseForType(argType ref.Type) (clauseImpl, error) {
switch argType {
case types.IntType, types.UintType:
return formatDecimal, nil
case types.StringType, types.BytesType, types.BoolType, types.NullType, types.TypeType:
return FormatString, nil
case types.TimestampType, types.DurationType:
// special case to ensure timestamps/durations get printed as CEL literals
return func(arg ref.Val, locale string) (string, error) {
argStrVal := arg.ConvertToType(types.StringType)
argStr := argStrVal.Value().(string)
if arg.Type() == types.TimestampType {
return fmt.Sprintf("timestamp(%q)", argStr), nil
}
if arg.Type() == types.DurationType {
return fmt.Sprintf("duration(%q)", argStr), nil
}
return "", fmt.Errorf("cannot convert argument of type %s to timestamp/duration", arg.Type().TypeName())
}, nil
case types.ListType:
return formatList, nil
case types.MapType:
return formatMap, nil
case types.DoubleType:
// avoid formatFixed so we can output a period as the decimal separator in order
// to always be a valid CEL literal
return func(arg ref.Val, locale string) (string, error) {
argDouble, ok := arg.Value().(float64)
if !ok {
return "", fmt.Errorf("couldn't convert %s to float64", arg.Type().TypeName())
}
fmtStr := fmt.Sprintf("%%.%df", defaultPrecision)
return fmt.Sprintf(fmtStr, argDouble), nil
}, nil
case types.TypeType:
return func(arg ref.Val, locale string) (string, error) {
return fmt.Sprintf("type(%s)", arg.Value().(string)), nil
}, nil
default:
return nil, fmt.Errorf("no formatting function for %s", argType.TypeName())
}
}
func formatList(arg ref.Val, locale string) (string, error) {
argList := arg.(traits.Lister)
argIterator := argList.Iterator()
var listStrBuilder strings.Builder
_, err := listStrBuilder.WriteRune('[')
if err != nil {
return "", fmt.Errorf("error writing to list string: %w", err)
}
for argIterator.HasNext() == types.True {
member := argIterator.Next()
memberFormat, err := clauseForType(member.Type())
if err != nil {
return "", err
}
unquotedStr, err := memberFormat(member, locale)
if err != nil {
return "", err
}
str := quoteForCEL(member, unquotedStr)
_, err = listStrBuilder.WriteString(str)
if err != nil {
return "", fmt.Errorf("error writing to list string: %w", err)
}
if argIterator.HasNext() == types.True {
_, err = listStrBuilder.WriteString(", ")
if err != nil {
return "", fmt.Errorf("error writing to list string: %w", err)
}
}
}
_, err = listStrBuilder.WriteRune(']')
if err != nil {
return "", fmt.Errorf("error writing to list string: %w", err)
}
return listStrBuilder.String(), nil
}
func formatMap(arg ref.Val, locale string) (string, error) {
argMap := arg.(traits.Mapper)
argIterator := argMap.Iterator()
type mapPair struct {
key string
value string
}
argPairs := make([]mapPair, argMap.Size().Value().(int64))
i := 0
for argIterator.HasNext() == types.True {
key := argIterator.Next()
var keyFormat clauseImpl
switch key.Type() {
case types.StringType, types.BoolType:
keyFormat = FormatString
case types.IntType, types.UintType:
keyFormat = formatDecimal
default:
return "", fmt.Errorf("no formatting function for map key of type %s", key.Type().TypeName())
}
unquotedKeyStr, err := keyFormat(key, locale)
if err != nil {
return "", err
}
keyStr := quoteForCEL(key, unquotedKeyStr)
value, found := argMap.Find(key)
if !found {
return "", fmt.Errorf("could not find key: %q", key)
}
valueFormat, err := clauseForType(value.Type())
if err != nil {
return "", err
}
unquotedValueStr, err := valueFormat(value, locale)
if err != nil {
return "", err
}
valueStr := quoteForCEL(value, unquotedValueStr)
argPairs[i] = mapPair{keyStr, valueStr}
i++
}
sort.SliceStable(argPairs, func(x, y int) bool {
return argPairs[x].key < argPairs[y].key
})
var mapStrBuilder strings.Builder
_, err := mapStrBuilder.WriteRune('{')
if err != nil {
return "", fmt.Errorf("error writing to map string: %w", err)
}
for i, entry := range argPairs {
_, err = mapStrBuilder.WriteString(fmt.Sprintf("%s:%s", entry.key, entry.value))
if err != nil {
return "", fmt.Errorf("error writing to map string: %w", err)
}
if i < len(argPairs)-1 {
_, err = mapStrBuilder.WriteString(", ")
if err != nil {
return "", fmt.Errorf("error writing to map string: %w", err)
}
}
}
_, err = mapStrBuilder.WriteRune('}')
if err != nil {
return "", fmt.Errorf("error writing to map string: %w", err)
}
return mapStrBuilder.String(), nil
}
// quoteForCEL takes a formatted, unquoted value and quotes it in a manner
// suitable for embedding directly in CEL.
func quoteForCEL(refVal ref.Val, unquotedValue string) string {
switch refVal.Type() {
case types.StringType:
return fmt.Sprintf("%q", unquotedValue)
case types.BytesType:
return fmt.Sprintf("b%q", unquotedValue)
case types.DoubleType:
// special case to handle infinity/NaN
num := refVal.Value().(float64)
if math.IsInf(num, 1) || math.IsInf(num, -1) || math.IsNaN(num) {
return fmt.Sprintf("%q", unquotedValue)
}
return unquotedValue
default:
return unquotedValue
}
}
// FormatString returns the string representation of a CEL value.
// It is used to implement the %s specifier in the (string).format() extension
// function.
func FormatString(arg ref.Val, locale string) (string, error) {
switch arg.Type() {
case types.ListType:
return formatList(arg, locale)
case types.MapType:
return formatMap(arg, locale)
case types.IntType, types.UintType, types.DoubleType,
types.BoolType, types.StringType, types.TimestampType, types.BytesType, types.DurationType, types.TypeType:
argStrVal := arg.ConvertToType(types.StringType)
argStr, ok := argStrVal.Value().(string)
if !ok {
return "", fmt.Errorf("could not convert argument %q to string", argStrVal)
}
return argStr, nil
case types.NullType:
return "null", nil
default:
return "", fmt.Errorf("string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps, was given %s", arg.Type().TypeName())
}
}
func formatDecimal(arg ref.Val, locale string) (string, error) {
switch arg.Type() {
case types.IntType:
argInt, ok := arg.ConvertToType(types.IntType).Value().(int64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to int64", arg.Value())
}
return fmt.Sprintf("%d", argInt), nil
case types.UintType:
argInt, ok := arg.ConvertToType(types.UintType).Value().(uint64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to uint64", arg.Value())
}
return fmt.Sprintf("%d", argInt), nil
default:
return "", fmt.Errorf("decimal clause can only be used on integers, was given %s", arg.Type().TypeName())
}
}
func matchLanguage(locale string) (language.Tag, error) {
matcher, err := makeMatcher(locale)
if err != nil {
return language.Und, err
}
tag, _ := language.MatchStrings(matcher, locale)
return tag, nil
}
func makeMatcher(locale string) (language.Matcher, error) {
tags := make([]language.Tag, 0)
tag, err := language.Parse(locale)
if err != nil {
return nil, err
}
tags = append(tags, tag)
return language.NewMatcher(tags), nil
}
// quote implements a string quoting function. The string will be wrapped in
// double quotes, and all valid CEL escape sequences will be escaped to show up
// literally if printed. If the input contains any invalid UTF-8, the invalid runes
@ -940,156 +760,6 @@ func sanitize(s string) string {
return sanitizedStringBuilder.String()
}
type stringFormatter struct{}
func (c *stringFormatter) String(arg ref.Val, locale string) (string, error) {
return FormatString(arg, locale)
}
func (c *stringFormatter) Decimal(arg ref.Val, locale string) (string, error) {
return formatDecimal(arg, locale)
}
func (c *stringFormatter) Fixed(precision *int) func(ref.Val, string) (string, error) {
if precision == nil {
precision = new(int)
*precision = defaultPrecision
}
return func(arg ref.Val, locale string) (string, error) {
strException := false
if arg.Type() == types.StringType {
argStr := arg.Value().(string)
if argStr == "NaN" || argStr == "Infinity" || argStr == "-Infinity" {
strException = true
}
}
if arg.Type() != types.DoubleType && !strException {
return "", fmt.Errorf("fixed-point clause can only be used on doubles, was given %s", arg.Type().TypeName())
}
argFloatVal := arg.ConvertToType(types.DoubleType)
argFloat, ok := argFloatVal.Value().(float64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to float64", argFloatVal.Value())
}
fmtStr := fmt.Sprintf("%%.%df", *precision)
matchedLocale, err := matchLanguage(locale)
if err != nil {
return "", fmt.Errorf("error matching locale: %w", err)
}
return message.NewPrinter(matchedLocale).Sprintf(fmtStr, argFloat), nil
}
}
func (c *stringFormatter) Scientific(precision *int) func(ref.Val, string) (string, error) {
if precision == nil {
precision = new(int)
*precision = defaultPrecision
}
return func(arg ref.Val, locale string) (string, error) {
strException := false
if arg.Type() == types.StringType {
argStr := arg.Value().(string)
if argStr == "NaN" || argStr == "Infinity" || argStr == "-Infinity" {
strException = true
}
}
if arg.Type() != types.DoubleType && !strException {
return "", fmt.Errorf("scientific clause can only be used on doubles, was given %s", arg.Type().TypeName())
}
argFloatVal := arg.ConvertToType(types.DoubleType)
argFloat, ok := argFloatVal.Value().(float64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to float64", argFloatVal.Value())
}
matchedLocale, err := matchLanguage(locale)
if err != nil {
return "", fmt.Errorf("error matching locale: %w", err)
}
fmtStr := fmt.Sprintf("%%%de", *precision)
return message.NewPrinter(matchedLocale).Sprintf(fmtStr, argFloat), nil
}
}
func (c *stringFormatter) Binary(arg ref.Val, locale string) (string, error) {
switch arg.Type() {
case types.IntType:
argInt := arg.Value().(int64)
// locale is intentionally unused as integers formatted as binary
// strings are locale-independent
return fmt.Sprintf("%b", argInt), nil
case types.UintType:
argInt := arg.Value().(uint64)
return fmt.Sprintf("%b", argInt), nil
case types.BoolType:
argBool := arg.Value().(bool)
if argBool {
return "1", nil
}
return "0", nil
default:
return "", fmt.Errorf("only integers and bools can be formatted as binary, was given %s", arg.Type().TypeName())
}
}
func (c *stringFormatter) Hex(useUpper bool) func(ref.Val, string) (string, error) {
return func(arg ref.Val, locale string) (string, error) {
fmtStr := "%x"
if useUpper {
fmtStr = "%X"
}
switch arg.Type() {
case types.StringType, types.BytesType:
if arg.Type() == types.BytesType {
return fmt.Sprintf(fmtStr, arg.Value().([]byte)), nil
}
return fmt.Sprintf(fmtStr, arg.Value().(string)), nil
case types.IntType:
argInt, ok := arg.Value().(int64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to int64", arg.Value())
}
return fmt.Sprintf(fmtStr, argInt), nil
case types.UintType:
argInt, ok := arg.Value().(uint64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to uint64", arg.Value())
}
return fmt.Sprintf(fmtStr, argInt), nil
default:
return "", fmt.Errorf("only integers, byte buffers, and strings can be formatted as hex, was given %s", arg.Type().TypeName())
}
}
}
func (c *stringFormatter) Octal(arg ref.Val, locale string) (string, error) {
switch arg.Type() {
case types.IntType:
argInt := arg.Value().(int64)
return fmt.Sprintf("%o", argInt), nil
case types.UintType:
argInt := arg.Value().(uint64)
return fmt.Sprintf("%o", argInt), nil
default:
return "", fmt.Errorf("octal clause can only be used on integers, was given %s", arg.Type().TypeName())
}
}
type stringArgList struct {
args traits.Lister
}
func (c *stringArgList) Arg(index int64) (ref.Val, error) {
if index >= c.args.Size().Value().(int64) {
return nil, fmt.Errorf("index %d out of range", index)
}
return c.args.Get(types.Int(index)), nil
}
func (c *stringArgList) ArgSize() int64 {
return c.args.Size().Value().(int64)
}
var (
stringListType = reflect.TypeOf([]string{})
)

View File

@ -14,7 +14,6 @@ go_library(
"decorators.go",
"dispatcher.go",
"evalstate.go",
"formatting.go",
"interpretable.go",
"interpreter.go",
"optimizations.go",

View File

@ -287,6 +287,9 @@ func (a *absoluteAttribute) Resolve(vars Activation) (any, error) {
// determine whether the type is unknown before returning.
obj, found := vars.ResolveName(nm)
if found {
if celErr, ok := obj.(*types.Err); ok {
return nil, celErr.Unwrap()
}
obj, isOpt, err := applyQualifiers(vars, obj, a.qualifiers)
if err != nil {
return nil, err

View File

@ -1,383 +0,0 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package interpreter
import (
"errors"
"fmt"
"strconv"
"strings"
"unicode"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)
type typeVerifier func(int64, ...ref.Type) (bool, error)
// InterpolateFormattedString checks the syntax and cardinality of any string.format calls present in the expression and reports
// any errors at compile time.
func InterpolateFormattedString(verifier typeVerifier) InterpretableDecorator {
return func(inter Interpretable) (Interpretable, error) {
call, ok := inter.(InterpretableCall)
if !ok {
return inter, nil
}
if call.OverloadID() != "string_format" {
return inter, nil
}
args := call.Args()
if len(args) != 2 {
return nil, fmt.Errorf("wrong number of arguments to string.format (expected 2, got %d)", len(args))
}
fmtStrInter, ok := args[0].(InterpretableConst)
if !ok {
return inter, nil
}
var fmtArgsInter InterpretableConstructor
fmtArgsInter, ok = args[1].(InterpretableConstructor)
if !ok {
return inter, nil
}
if fmtArgsInter.Type() != types.ListType {
// don't necessarily return an error since the list may be DynType
return inter, nil
}
formatStr := fmtStrInter.Value().Value().(string)
initVals := fmtArgsInter.InitVals()
formatCheck := &formatCheck{
args: initVals,
verifier: verifier,
}
// use a placeholder locale, since locale doesn't affect syntax
_, err := ParseFormatString(formatStr, formatCheck, formatCheck, "en_US")
if err != nil {
return nil, err
}
seenArgs := formatCheck.argsRequested
if len(initVals) > seenArgs {
return nil, fmt.Errorf("too many arguments supplied to string.format (expected %d, got %d)", seenArgs, len(initVals))
}
return inter, nil
}
}
type formatCheck struct {
args []Interpretable
argsRequested int
curArgIndex int64
enableCheckArgTypes bool
verifier typeVerifier
}
func (c *formatCheck) String(arg ref.Val, locale string) (string, error) {
valid, err := verifyString(c.args[c.curArgIndex], c.verifier)
if err != nil {
return "", err
}
if !valid {
return "", errors.New("string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps")
}
return "", nil
}
func (c *formatCheck) Decimal(arg ref.Val, locale string) (string, error) {
id := c.args[c.curArgIndex].ID()
valid, err := c.verifier(id, types.IntType, types.UintType)
if err != nil {
return "", err
}
if !valid {
return "", errors.New("integer clause can only be used on integers")
}
return "", nil
}
func (c *formatCheck) Fixed(precision *int) func(ref.Val, string) (string, error) {
return func(arg ref.Val, locale string) (string, error) {
id := c.args[c.curArgIndex].ID()
// we allow StringType since "NaN", "Infinity", and "-Infinity" are also valid values
valid, err := c.verifier(id, types.DoubleType, types.StringType)
if err != nil {
return "", err
}
if !valid {
return "", errors.New("fixed-point clause can only be used on doubles")
}
return "", nil
}
}
func (c *formatCheck) Scientific(precision *int) func(ref.Val, string) (string, error) {
return func(arg ref.Val, locale string) (string, error) {
id := c.args[c.curArgIndex].ID()
valid, err := c.verifier(id, types.DoubleType, types.StringType)
if err != nil {
return "", err
}
if !valid {
return "", errors.New("scientific clause can only be used on doubles")
}
return "", nil
}
}
func (c *formatCheck) Binary(arg ref.Val, locale string) (string, error) {
id := c.args[c.curArgIndex].ID()
valid, err := c.verifier(id, types.IntType, types.UintType, types.BoolType)
if err != nil {
return "", err
}
if !valid {
return "", errors.New("only integers and bools can be formatted as binary")
}
return "", nil
}
func (c *formatCheck) Hex(useUpper bool) func(ref.Val, string) (string, error) {
return func(arg ref.Val, locale string) (string, error) {
id := c.args[c.curArgIndex].ID()
valid, err := c.verifier(id, types.IntType, types.UintType, types.StringType, types.BytesType)
if err != nil {
return "", err
}
if !valid {
return "", errors.New("only integers, byte buffers, and strings can be formatted as hex")
}
return "", nil
}
}
func (c *formatCheck) Octal(arg ref.Val, locale string) (string, error) {
id := c.args[c.curArgIndex].ID()
valid, err := c.verifier(id, types.IntType, types.UintType)
if err != nil {
return "", err
}
if !valid {
return "", errors.New("octal clause can only be used on integers")
}
return "", nil
}
func (c *formatCheck) Arg(index int64) (ref.Val, error) {
c.argsRequested++
c.curArgIndex = index
// return a dummy value - this is immediately passed to back to us
// through one of the FormatCallback functions, so anything will do
return types.Int(0), nil
}
func (c *formatCheck) ArgSize() int64 {
return int64(len(c.args))
}
func verifyString(sub Interpretable, verifier typeVerifier) (bool, error) {
subVerified, err := verifier(sub.ID(),
types.ListType, types.MapType, types.IntType, types.UintType, types.DoubleType,
types.BoolType, types.StringType, types.TimestampType, types.BytesType, types.DurationType, types.TypeType, types.NullType)
if err != nil {
return false, err
}
if !subVerified {
return false, nil
}
con, ok := sub.(InterpretableConstructor)
if ok {
members := con.InitVals()
for _, m := range members {
// recursively verify if we're dealing with a list/map
verified, err := verifyString(m, verifier)
if err != nil {
return false, err
}
if !verified {
return false, nil
}
}
}
return true, nil
}
// FormatStringInterpolator is an interface that allows user-defined behavior
// for formatting clause implementations, as well as argument retrieval.
// Each function is expected to support the appropriate types as laid out in
// the string.format documentation, and to return an error if given an inappropriate type.
type FormatStringInterpolator interface {
// String takes a ref.Val and a string representing the current locale identifier
// and returns the Val formatted as a string, or an error if one occurred.
String(ref.Val, string) (string, error)
// Decimal takes a ref.Val and a string representing the current locale identifier
// and returns the Val formatted as a decimal integer, or an error if one occurred.
Decimal(ref.Val, string) (string, error)
// Fixed takes an int pointer representing precision (or nil if none was given) and
// returns a function operating in a similar manner to String and Decimal, taking a
// ref.Val and locale and returning the appropriate string. A closure is returned
// so precision can be set without needing an additional function call/configuration.
Fixed(*int) func(ref.Val, string) (string, error)
// Scientific functions identically to Fixed, except the string returned from the closure
// is expected to be in scientific notation.
Scientific(*int) func(ref.Val, string) (string, error)
// Binary takes a ref.Val and a string representing the current locale identifier
// and returns the Val formatted as a binary integer, or an error if one occurred.
Binary(ref.Val, string) (string, error)
// Hex takes a boolean that, if true, indicates the hex string output by the returned
// closure should use uppercase letters for A-F.
Hex(bool) func(ref.Val, string) (string, error)
// Octal takes a ref.Val and a string representing the current locale identifier and
// returns the Val formatted in octal, or an error if one occurred.
Octal(ref.Val, string) (string, error)
}
// FormatList is an interface that allows user-defined list-like datatypes to be used
// for formatting clause implementations.
type FormatList interface {
// Arg returns the ref.Val at the given index, or an error if one occurred.
Arg(int64) (ref.Val, error)
// ArgSize returns the length of the argument list.
ArgSize() int64
}
type clauseImpl func(ref.Val, string) (string, error)
// ParseFormatString formats a string according to the string.format syntax, taking the clause implementations
// from the provided FormatCallback and the args from the given FormatList.
func ParseFormatString(formatStr string, callback FormatStringInterpolator, list FormatList, locale string) (string, error) {
i := 0
argIndex := 0
var builtStr strings.Builder
for i < len(formatStr) {
if formatStr[i] == '%' {
if i+1 < len(formatStr) && formatStr[i+1] == '%' {
err := builtStr.WriteByte('%')
if err != nil {
return "", fmt.Errorf("error writing format string: %w", err)
}
i += 2
continue
} else {
argAny, err := list.Arg(int64(argIndex))
if err != nil {
return "", err
}
if i+1 >= len(formatStr) {
return "", errors.New("unexpected end of string")
}
if int64(argIndex) >= list.ArgSize() {
return "", fmt.Errorf("index %d out of range", argIndex)
}
numRead, val, refErr := parseAndFormatClause(formatStr[i:], argAny, callback, list, locale)
if refErr != nil {
return "", refErr
}
_, err = builtStr.WriteString(val)
if err != nil {
return "", fmt.Errorf("error writing format string: %w", err)
}
i += numRead
argIndex++
}
} else {
err := builtStr.WriteByte(formatStr[i])
if err != nil {
return "", fmt.Errorf("error writing format string: %w", err)
}
i++
}
}
return builtStr.String(), nil
}
// parseAndFormatClause parses the format clause at the start of the given string with val, and returns
// how many characters were consumed and the substituted string form of val, or an error if one occurred.
func parseAndFormatClause(formatStr string, val ref.Val, callback FormatStringInterpolator, list FormatList, locale string) (int, string, error) {
i := 1
read, formatter, err := parseFormattingClause(formatStr[i:], callback)
i += read
if err != nil {
return -1, "", fmt.Errorf("could not parse formatting clause: %s", err)
}
valStr, err := formatter(val, locale)
if err != nil {
return -1, "", fmt.Errorf("error during formatting: %s", err)
}
return i, valStr, nil
}
func parseFormattingClause(formatStr string, callback FormatStringInterpolator) (int, clauseImpl, error) {
i := 0
read, precision, err := parsePrecision(formatStr[i:])
i += read
if err != nil {
return -1, nil, fmt.Errorf("error while parsing precision: %w", err)
}
r := rune(formatStr[i])
i++
switch r {
case 's':
return i, callback.String, nil
case 'd':
return i, callback.Decimal, nil
case 'f':
return i, callback.Fixed(precision), nil
case 'e':
return i, callback.Scientific(precision), nil
case 'b':
return i, callback.Binary, nil
case 'x', 'X':
return i, callback.Hex(unicode.IsUpper(r)), nil
case 'o':
return i, callback.Octal, nil
default:
return -1, nil, fmt.Errorf("unrecognized formatting clause \"%c\"", r)
}
}
func parsePrecision(formatStr string) (int, *int, error) {
i := 0
if formatStr[i] != '.' {
return i, nil, nil
}
i++
var buffer strings.Builder
for {
if i >= len(formatStr) {
return -1, nil, errors.New("could not find end of precision specifier")
}
if !isASCIIDigit(rune(formatStr[i])) {
break
}
buffer.WriteByte(formatStr[i])
i++
}
precision, err := strconv.Atoi(buffer.String())
if err != nil {
return -1, nil, fmt.Errorf("error while converting precision to integer: %w", err)
}
return i, &precision, nil
}
func isASCIIDigit(r rune) bool {
return r <= unicode.MaxASCII && unicode.IsDigit(r)
}

View File

@ -125,7 +125,7 @@ func (test *evalTestOnly) Eval(ctx Activation) ref.Val {
val, err := test.Resolve(ctx)
// Return an error if the resolve step fails
if err != nil {
return types.WrapErr(err)
return types.LabelErrNode(test.id, types.WrapErr(err))
}
if optVal, isOpt := val.(*types.Optional); isOpt {
return types.Bool(optVal.HasValue())
@ -231,6 +231,7 @@ func (or *evalOr) Eval(ctx Activation) ref.Val {
} else {
err = types.MaybeNoSuchOverloadErr(val)
}
err = types.LabelErrNode(or.id, err)
}
}
}
@ -273,6 +274,7 @@ func (and *evalAnd) Eval(ctx Activation) ref.Val {
} else {
err = types.MaybeNoSuchOverloadErr(val)
}
err = types.LabelErrNode(and.id, err)
}
}
}
@ -377,7 +379,7 @@ func (zero *evalZeroArity) ID() int64 {
// Eval implements the Interpretable interface method.
func (zero *evalZeroArity) Eval(ctx Activation) ref.Val {
return zero.impl()
return types.LabelErrNode(zero.id, zero.impl())
}
// Function implements the InterpretableCall interface method.
@ -421,14 +423,14 @@ func (un *evalUnary) Eval(ctx Activation) ref.Val {
// If the implementation is bound and the argument value has the right traits required to
// invoke it, then call the implementation.
if un.impl != nil && (un.trait == 0 || (!strict && types.IsUnknownOrError(argVal)) || argVal.Type().HasTrait(un.trait)) {
return un.impl(argVal)
return types.LabelErrNode(un.id, un.impl(argVal))
}
// Otherwise, if the argument is a ReceiverType attempt to invoke the receiver method on the
// operand (arg0).
if argVal.Type().HasTrait(traits.ReceiverType) {
return argVal.(traits.Receiver).Receive(un.function, un.overload, []ref.Val{})
return types.LabelErrNode(un.id, argVal.(traits.Receiver).Receive(un.function, un.overload, []ref.Val{}))
}
return types.NewErr("no such overload: %s", un.function)
return types.NewErrWithNodeID(un.id, "no such overload: %s", un.function)
}
// Function implements the InterpretableCall interface method.
@ -479,14 +481,14 @@ func (bin *evalBinary) Eval(ctx Activation) ref.Val {
// If the implementation is bound and the argument value has the right traits required to
// invoke it, then call the implementation.
if bin.impl != nil && (bin.trait == 0 || (!strict && types.IsUnknownOrError(lVal)) || lVal.Type().HasTrait(bin.trait)) {
return bin.impl(lVal, rVal)
return types.LabelErrNode(bin.id, bin.impl(lVal, rVal))
}
// Otherwise, if the argument is a ReceiverType attempt to invoke the receiver method on the
// operand (arg0).
if lVal.Type().HasTrait(traits.ReceiverType) {
return lVal.(traits.Receiver).Receive(bin.function, bin.overload, []ref.Val{rVal})
return types.LabelErrNode(bin.id, lVal.(traits.Receiver).Receive(bin.function, bin.overload, []ref.Val{rVal}))
}
return types.NewErr("no such overload: %s", bin.function)
return types.NewErrWithNodeID(bin.id, "no such overload: %s", bin.function)
}
// Function implements the InterpretableCall interface method.
@ -545,14 +547,14 @@ func (fn *evalVarArgs) Eval(ctx Activation) ref.Val {
// invoke it, then call the implementation.
arg0 := argVals[0]
if fn.impl != nil && (fn.trait == 0 || (!strict && types.IsUnknownOrError(arg0)) || arg0.Type().HasTrait(fn.trait)) {
return fn.impl(argVals...)
return types.LabelErrNode(fn.id, fn.impl(argVals...))
}
// Otherwise, if the argument is a ReceiverType attempt to invoke the receiver method on the
// operand (arg0).
if arg0.Type().HasTrait(traits.ReceiverType) {
return arg0.(traits.Receiver).Receive(fn.function, fn.overload, argVals[1:])
return types.LabelErrNode(fn.id, arg0.(traits.Receiver).Receive(fn.function, fn.overload, argVals[1:]))
}
return types.NewErr("no such overload: %s", fn.function)
return types.NewErrWithNodeID(fn.id, "no such overload: %s %d", fn.function, fn.id)
}
// Function implements the InterpretableCall interface method.
@ -595,7 +597,7 @@ func (l *evalList) Eval(ctx Activation) ref.Val {
if l.hasOptionals && l.optionals[i] {
optVal, ok := elemVal.(*types.Optional)
if !ok {
return invalidOptionalElementInit(elemVal)
return types.LabelErrNode(l.id, invalidOptionalElementInit(elemVal))
}
if !optVal.HasValue() {
continue
@ -645,7 +647,7 @@ func (m *evalMap) Eval(ctx Activation) ref.Val {
if m.hasOptionals && m.optionals[i] {
optVal, ok := valVal.(*types.Optional)
if !ok {
return invalidOptionalEntryInit(keyVal, valVal)
return types.LabelErrNode(m.id, invalidOptionalEntryInit(keyVal, valVal))
}
if !optVal.HasValue() {
delete(entries, keyVal)
@ -705,7 +707,7 @@ func (o *evalObj) Eval(ctx Activation) ref.Val {
if o.hasOptionals && o.optionals[i] {
optVal, ok := val.(*types.Optional)
if !ok {
return invalidOptionalEntryInit(field, val)
return types.LabelErrNode(o.id, invalidOptionalEntryInit(field, val))
}
if !optVal.HasValue() {
delete(fieldVals, field)
@ -715,7 +717,7 @@ func (o *evalObj) Eval(ctx Activation) ref.Val {
}
fieldVals[field] = val
}
return o.provider.NewValue(o.typeName, fieldVals)
return types.LabelErrNode(o.id, o.provider.NewValue(o.typeName, fieldVals))
}
func (o *evalObj) InitVals() []Interpretable {
@ -921,7 +923,7 @@ func (e *evalWatchConstQual) Qualify(vars Activation, obj any) (any, error) {
out, err := e.ConstantQualifier.Qualify(vars, obj)
var val ref.Val
if err != nil {
val = types.WrapErr(err)
val = types.LabelErrNode(e.ID(), types.WrapErr(err))
} else {
val = e.adapter.NativeToValue(out)
}
@ -934,7 +936,7 @@ func (e *evalWatchConstQual) QualifyIfPresent(vars Activation, obj any, presence
out, present, err := e.ConstantQualifier.QualifyIfPresent(vars, obj, presenceOnly)
var val ref.Val
if err != nil {
val = types.WrapErr(err)
val = types.LabelErrNode(e.ID(), types.WrapErr(err))
} else if out != nil {
val = e.adapter.NativeToValue(out)
} else if presenceOnly {
@ -964,7 +966,7 @@ func (e *evalWatchAttrQual) Qualify(vars Activation, obj any) (any, error) {
out, err := e.Attribute.Qualify(vars, obj)
var val ref.Val
if err != nil {
val = types.WrapErr(err)
val = types.LabelErrNode(e.ID(), types.WrapErr(err))
} else {
val = e.adapter.NativeToValue(out)
}
@ -977,7 +979,7 @@ func (e *evalWatchAttrQual) QualifyIfPresent(vars Activation, obj any, presenceO
out, present, err := e.Attribute.QualifyIfPresent(vars, obj, presenceOnly)
var val ref.Val
if err != nil {
val = types.WrapErr(err)
val = types.LabelErrNode(e.ID(), types.WrapErr(err))
} else if out != nil {
val = e.adapter.NativeToValue(out)
} else if presenceOnly {
@ -1001,7 +1003,7 @@ func (e *evalWatchQual) Qualify(vars Activation, obj any) (any, error) {
out, err := e.Qualifier.Qualify(vars, obj)
var val ref.Val
if err != nil {
val = types.WrapErr(err)
val = types.LabelErrNode(e.ID(), types.WrapErr(err))
} else {
val = e.adapter.NativeToValue(out)
}
@ -1014,7 +1016,7 @@ func (e *evalWatchQual) QualifyIfPresent(vars Activation, obj any, presenceOnly
out, present, err := e.Qualifier.QualifyIfPresent(vars, obj, presenceOnly)
var val ref.Val
if err != nil {
val = types.WrapErr(err)
val = types.LabelErrNode(e.ID(), types.WrapErr(err))
} else if out != nil {
val = e.adapter.NativeToValue(out)
} else if presenceOnly {
@ -1157,12 +1159,12 @@ func (cond *evalExhaustiveConditional) Eval(ctx Activation) ref.Val {
}
if cBool {
if tErr != nil {
return types.WrapErr(tErr)
return types.LabelErrNode(cond.id, types.WrapErr(tErr))
}
return cond.adapter.NativeToValue(tVal)
}
if fErr != nil {
return types.WrapErr(fErr)
return types.LabelErrNode(cond.id, types.WrapErr(fErr))
}
return cond.adapter.NativeToValue(fVal)
}
@ -1202,7 +1204,7 @@ func (a *evalAttr) Adapter() types.Adapter {
func (a *evalAttr) Eval(ctx Activation) ref.Val {
v, err := a.attr.Resolve(ctx)
if err != nil {
return types.WrapErr(err)
return types.LabelErrNode(a.ID(), types.WrapErr(err))
}
return a.adapter.NativeToValue(v)
}

View File

@ -22,19 +22,13 @@ import (
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// Interpreter generates a new Interpretable from a checked or unchecked expression.
type Interpreter interface {
// NewInterpretable creates an Interpretable from a checked expression and an
// optional list of InterpretableDecorator values.
NewInterpretable(checked *ast.CheckedAST, decorators ...InterpretableDecorator) (Interpretable, error)
// NewUncheckedInterpretable returns an Interpretable from a parsed expression
// and an optional list of InterpretableDecorator values.
NewUncheckedInterpretable(expr *exprpb.Expr, decorators ...InterpretableDecorator) (Interpretable, error)
NewInterpretable(exprAST *ast.AST, decorators ...InterpretableDecorator) (Interpretable, error)
}
// EvalObserver is a functional interface that accepts an expression id and an observed value.
@ -177,7 +171,7 @@ func NewInterpreter(dispatcher Dispatcher,
// NewIntepretable implements the Interpreter interface method.
func (i *exprInterpreter) NewInterpretable(
checked *ast.CheckedAST,
checked *ast.AST,
decorators ...InterpretableDecorator) (Interpretable, error) {
p := newPlanner(
i.dispatcher,
@ -187,19 +181,5 @@ func (i *exprInterpreter) NewInterpretable(
i.container,
checked,
decorators...)
return p.Plan(checked.Expr)
}
// NewUncheckedIntepretable implements the Interpreter interface method.
func (i *exprInterpreter) NewUncheckedInterpretable(
expr *exprpb.Expr,
decorators ...InterpretableDecorator) (Interpretable, error) {
p := newUncheckedPlanner(
i.dispatcher,
i.provider,
i.adapter,
i.attrFactory,
i.container,
decorators...)
return p.Plan(expr)
return p.Plan(checked.Expr())
}

View File

@ -23,15 +23,12 @@ import (
"github.com/google/cel-go/common/functions"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// interpretablePlanner creates an Interpretable evaluation plan from a proto Expr value.
type interpretablePlanner interface {
// Plan generates an Interpretable value (or error) from the input proto Expr.
Plan(expr *exprpb.Expr) (Interpretable, error)
Plan(expr ast.Expr) (Interpretable, error)
}
// newPlanner creates an interpretablePlanner which references a Dispatcher, TypeProvider,
@ -43,7 +40,7 @@ func newPlanner(disp Dispatcher,
adapter types.Adapter,
attrFactory AttributeFactory,
cont *containers.Container,
checked *ast.CheckedAST,
exprAST *ast.AST,
decorators ...InterpretableDecorator) interpretablePlanner {
return &planner{
disp: disp,
@ -51,29 +48,8 @@ func newPlanner(disp Dispatcher,
adapter: adapter,
attrFactory: attrFactory,
container: cont,
refMap: checked.ReferenceMap,
typeMap: checked.TypeMap,
decorators: decorators,
}
}
// newUncheckedPlanner creates an interpretablePlanner which references a Dispatcher, TypeProvider,
// TypeAdapter, and Container to resolve functions and types at plan time. Namespaces present in
// Select expressions are resolved lazily at evaluation time.
func newUncheckedPlanner(disp Dispatcher,
provider types.Provider,
adapter types.Adapter,
attrFactory AttributeFactory,
cont *containers.Container,
decorators ...InterpretableDecorator) interpretablePlanner {
return &planner{
disp: disp,
provider: provider,
adapter: adapter,
attrFactory: attrFactory,
container: cont,
refMap: make(map[int64]*ast.ReferenceInfo),
typeMap: make(map[int64]*types.Type),
refMap: exprAST.ReferenceMap(),
typeMap: exprAST.TypeMap(),
decorators: decorators,
}
}
@ -95,22 +71,24 @@ type planner struct {
// useful for layering functionality into the evaluation that is not natively understood by CEL,
// such as state-tracking, expression re-write, and possibly efficient thread-safe memoization of
// repeated expressions.
func (p *planner) Plan(expr *exprpb.Expr) (Interpretable, error) {
switch expr.GetExprKind().(type) {
case *exprpb.Expr_CallExpr:
func (p *planner) Plan(expr ast.Expr) (Interpretable, error) {
switch expr.Kind() {
case ast.CallKind:
return p.decorate(p.planCall(expr))
case *exprpb.Expr_IdentExpr:
case ast.IdentKind:
return p.decorate(p.planIdent(expr))
case *exprpb.Expr_SelectExpr:
return p.decorate(p.planSelect(expr))
case *exprpb.Expr_ListExpr:
return p.decorate(p.planCreateList(expr))
case *exprpb.Expr_StructExpr:
return p.decorate(p.planCreateStruct(expr))
case *exprpb.Expr_ComprehensionExpr:
return p.decorate(p.planComprehension(expr))
case *exprpb.Expr_ConstExpr:
case ast.LiteralKind:
return p.decorate(p.planConst(expr))
case ast.SelectKind:
return p.decorate(p.planSelect(expr))
case ast.ListKind:
return p.decorate(p.planCreateList(expr))
case ast.MapKind:
return p.decorate(p.planCreateMap(expr))
case ast.StructKind:
return p.decorate(p.planCreateStruct(expr))
case ast.ComprehensionKind:
return p.decorate(p.planComprehension(expr))
}
return nil, fmt.Errorf("unsupported expr: %v", expr)
}
@ -132,16 +110,16 @@ func (p *planner) decorate(i Interpretable, err error) (Interpretable, error) {
}
// planIdent creates an Interpretable that resolves an identifier from an Activation.
func (p *planner) planIdent(expr *exprpb.Expr) (Interpretable, error) {
func (p *planner) planIdent(expr ast.Expr) (Interpretable, error) {
// Establish whether the identifier is in the reference map.
if identRef, found := p.refMap[expr.GetId()]; found {
return p.planCheckedIdent(expr.GetId(), identRef)
if identRef, found := p.refMap[expr.ID()]; found {
return p.planCheckedIdent(expr.ID(), identRef)
}
// Create the possible attribute list for the unresolved reference.
ident := expr.GetIdentExpr()
ident := expr.AsIdent()
return &evalAttr{
adapter: p.adapter,
attr: p.attrFactory.MaybeAttribute(expr.GetId(), ident.Name),
attr: p.attrFactory.MaybeAttribute(expr.ID(), ident),
}, nil
}
@ -174,20 +152,20 @@ func (p *planner) planCheckedIdent(id int64, identRef *ast.ReferenceInfo) (Inter
// a) selects a field from a map or proto.
// b) creates a field presence test for a select within a has() macro.
// c) resolves the select expression to a namespaced identifier.
func (p *planner) planSelect(expr *exprpb.Expr) (Interpretable, error) {
func (p *planner) planSelect(expr ast.Expr) (Interpretable, error) {
// If the Select id appears in the reference map from the CheckedExpr proto then it is either
// a namespaced identifier or enum value.
if identRef, found := p.refMap[expr.GetId()]; found {
return p.planCheckedIdent(expr.GetId(), identRef)
if identRef, found := p.refMap[expr.ID()]; found {
return p.planCheckedIdent(expr.ID(), identRef)
}
sel := expr.GetSelectExpr()
sel := expr.AsSelect()
// Plan the operand evaluation.
op, err := p.Plan(sel.GetOperand())
op, err := p.Plan(sel.Operand())
if err != nil {
return nil, err
}
opType := p.typeMap[sel.GetOperand().GetId()]
opType := p.typeMap[sel.Operand().ID()]
// If the Select was marked TestOnly, this is a presence test.
//
@ -211,14 +189,14 @@ func (p *planner) planSelect(expr *exprpb.Expr) (Interpretable, error) {
}
// Build a qualifier for the attribute.
qual, err := p.attrFactory.NewQualifier(opType, expr.GetId(), sel.GetField(), false)
qual, err := p.attrFactory.NewQualifier(opType, expr.ID(), sel.FieldName(), false)
if err != nil {
return nil, err
}
// Modify the attribute to be test-only.
if sel.GetTestOnly() {
if sel.IsTestOnly() {
attr = &evalTestOnly{
id: expr.GetId(),
id: expr.ID(),
InterpretableAttribute: attr,
}
}
@ -230,10 +208,10 @@ func (p *planner) planSelect(expr *exprpb.Expr) (Interpretable, error) {
// planCall creates a callable Interpretable while specializing for common functions and invocation
// patterns. Specifically, conditional operators &&, ||, ?:, and (in)equality functions result in
// optimized Interpretable values.
func (p *planner) planCall(expr *exprpb.Expr) (Interpretable, error) {
call := expr.GetCallExpr()
func (p *planner) planCall(expr ast.Expr) (Interpretable, error) {
call := expr.AsCall()
target, fnName, oName := p.resolveFunction(expr)
argCount := len(call.GetArgs())
argCount := len(call.Args())
var offset int
if target != nil {
argCount++
@ -248,7 +226,7 @@ func (p *planner) planCall(expr *exprpb.Expr) (Interpretable, error) {
}
args[0] = arg
}
for i, argExpr := range call.GetArgs() {
for i, argExpr := range call.Args() {
arg, err := p.Plan(argExpr)
if err != nil {
return nil, err
@ -307,7 +285,7 @@ func (p *planner) planCall(expr *exprpb.Expr) (Interpretable, error) {
}
// planCallZero generates a zero-arity callable Interpretable.
func (p *planner) planCallZero(expr *exprpb.Expr,
func (p *planner) planCallZero(expr ast.Expr,
function string,
overload string,
impl *functions.Overload) (Interpretable, error) {
@ -315,7 +293,7 @@ func (p *planner) planCallZero(expr *exprpb.Expr,
return nil, fmt.Errorf("no such overload: %s()", function)
}
return &evalZeroArity{
id: expr.GetId(),
id: expr.ID(),
function: function,
overload: overload,
impl: impl.Function,
@ -323,7 +301,7 @@ func (p *planner) planCallZero(expr *exprpb.Expr,
}
// planCallUnary generates a unary callable Interpretable.
func (p *planner) planCallUnary(expr *exprpb.Expr,
func (p *planner) planCallUnary(expr ast.Expr,
function string,
overload string,
impl *functions.Overload,
@ -340,7 +318,7 @@ func (p *planner) planCallUnary(expr *exprpb.Expr,
nonStrict = impl.NonStrict
}
return &evalUnary{
id: expr.GetId(),
id: expr.ID(),
function: function,
overload: overload,
arg: args[0],
@ -351,7 +329,7 @@ func (p *planner) planCallUnary(expr *exprpb.Expr,
}
// planCallBinary generates a binary callable Interpretable.
func (p *planner) planCallBinary(expr *exprpb.Expr,
func (p *planner) planCallBinary(expr ast.Expr,
function string,
overload string,
impl *functions.Overload,
@ -368,7 +346,7 @@ func (p *planner) planCallBinary(expr *exprpb.Expr,
nonStrict = impl.NonStrict
}
return &evalBinary{
id: expr.GetId(),
id: expr.ID(),
function: function,
overload: overload,
lhs: args[0],
@ -380,7 +358,7 @@ func (p *planner) planCallBinary(expr *exprpb.Expr,
}
// planCallVarArgs generates a variable argument callable Interpretable.
func (p *planner) planCallVarArgs(expr *exprpb.Expr,
func (p *planner) planCallVarArgs(expr ast.Expr,
function string,
overload string,
impl *functions.Overload,
@ -397,7 +375,7 @@ func (p *planner) planCallVarArgs(expr *exprpb.Expr,
nonStrict = impl.NonStrict
}
return &evalVarArgs{
id: expr.GetId(),
id: expr.ID(),
function: function,
overload: overload,
args: args,
@ -408,41 +386,41 @@ func (p *planner) planCallVarArgs(expr *exprpb.Expr,
}
// planCallEqual generates an equals (==) Interpretable.
func (p *planner) planCallEqual(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) {
func (p *planner) planCallEqual(expr ast.Expr, args []Interpretable) (Interpretable, error) {
return &evalEq{
id: expr.GetId(),
id: expr.ID(),
lhs: args[0],
rhs: args[1],
}, nil
}
// planCallNotEqual generates a not equals (!=) Interpretable.
func (p *planner) planCallNotEqual(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) {
func (p *planner) planCallNotEqual(expr ast.Expr, args []Interpretable) (Interpretable, error) {
return &evalNe{
id: expr.GetId(),
id: expr.ID(),
lhs: args[0],
rhs: args[1],
}, nil
}
// planCallLogicalAnd generates a logical and (&&) Interpretable.
func (p *planner) planCallLogicalAnd(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) {
func (p *planner) planCallLogicalAnd(expr ast.Expr, args []Interpretable) (Interpretable, error) {
return &evalAnd{
id: expr.GetId(),
id: expr.ID(),
terms: args,
}, nil
}
// planCallLogicalOr generates a logical or (||) Interpretable.
func (p *planner) planCallLogicalOr(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) {
func (p *planner) planCallLogicalOr(expr ast.Expr, args []Interpretable) (Interpretable, error) {
return &evalOr{
id: expr.GetId(),
id: expr.ID(),
terms: args,
}, nil
}
// planCallConditional generates a conditional / ternary (c ? t : f) Interpretable.
func (p *planner) planCallConditional(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) {
func (p *planner) planCallConditional(expr ast.Expr, args []Interpretable) (Interpretable, error) {
cond := args[0]
t := args[1]
var tAttr Attribute
@ -464,13 +442,13 @@ func (p *planner) planCallConditional(expr *exprpb.Expr, args []Interpretable) (
return &evalAttr{
adapter: p.adapter,
attr: p.attrFactory.ConditionalAttribute(expr.GetId(), cond, tAttr, fAttr),
attr: p.attrFactory.ConditionalAttribute(expr.ID(), cond, tAttr, fAttr),
}, nil
}
// planCallIndex either extends an attribute with the argument to the index operation, or creates
// a relative attribute based on the return of a function call or operation.
func (p *planner) planCallIndex(expr *exprpb.Expr, args []Interpretable, optional bool) (Interpretable, error) {
func (p *planner) planCallIndex(expr ast.Expr, args []Interpretable, optional bool) (Interpretable, error) {
op := args[0]
ind := args[1]
opType := p.typeMap[op.ID()]
@ -489,11 +467,11 @@ func (p *planner) planCallIndex(expr *exprpb.Expr, args []Interpretable, optiona
var qual Qualifier
switch ind := ind.(type) {
case InterpretableConst:
qual, err = p.attrFactory.NewQualifier(opType, expr.GetId(), ind.Value(), optional)
qual, err = p.attrFactory.NewQualifier(opType, expr.ID(), ind.Value(), optional)
case InterpretableAttribute:
qual, err = p.attrFactory.NewQualifier(opType, expr.GetId(), ind, optional)
qual, err = p.attrFactory.NewQualifier(opType, expr.ID(), ind, optional)
default:
qual, err = p.relativeAttr(expr.GetId(), ind, optional)
qual, err = p.relativeAttr(expr.ID(), ind, optional)
}
if err != nil {
return nil, err
@ -505,10 +483,10 @@ func (p *planner) planCallIndex(expr *exprpb.Expr, args []Interpretable, optiona
}
// planCreateList generates a list construction Interpretable.
func (p *planner) planCreateList(expr *exprpb.Expr) (Interpretable, error) {
list := expr.GetListExpr()
optionalIndices := list.GetOptionalIndices()
elements := list.GetElements()
func (p *planner) planCreateList(expr ast.Expr) (Interpretable, error) {
list := expr.AsList()
optionalIndices := list.OptionalIndices()
elements := list.Elements()
optionals := make([]bool, len(elements))
for _, index := range optionalIndices {
if index < 0 || index >= int32(len(elements)) {
@ -525,7 +503,7 @@ func (p *planner) planCreateList(expr *exprpb.Expr) (Interpretable, error) {
elems[i] = elemVal
}
return &evalList{
id: expr.GetId(),
id: expr.ID(),
elems: elems,
optionals: optionals,
hasOptionals: len(optionals) != 0,
@ -534,31 +512,29 @@ func (p *planner) planCreateList(expr *exprpb.Expr) (Interpretable, error) {
}
// planCreateStruct generates a map or object construction Interpretable.
func (p *planner) planCreateStruct(expr *exprpb.Expr) (Interpretable, error) {
str := expr.GetStructExpr()
if len(str.MessageName) != 0 {
return p.planCreateObj(expr)
}
entries := str.GetEntries()
func (p *planner) planCreateMap(expr ast.Expr) (Interpretable, error) {
m := expr.AsMap()
entries := m.Entries()
optionals := make([]bool, len(entries))
keys := make([]Interpretable, len(entries))
vals := make([]Interpretable, len(entries))
for i, entry := range entries {
keyVal, err := p.Plan(entry.GetMapKey())
for i, e := range entries {
entry := e.AsMapEntry()
keyVal, err := p.Plan(entry.Key())
if err != nil {
return nil, err
}
keys[i] = keyVal
valVal, err := p.Plan(entry.GetValue())
valVal, err := p.Plan(entry.Value())
if err != nil {
return nil, err
}
vals[i] = valVal
optionals[i] = entry.GetOptionalEntry()
optionals[i] = entry.IsOptional()
}
return &evalMap{
id: expr.GetId(),
id: expr.ID(),
keys: keys,
vals: vals,
optionals: optionals,
@ -568,27 +544,28 @@ func (p *planner) planCreateStruct(expr *exprpb.Expr) (Interpretable, error) {
}
// planCreateObj generates an object construction Interpretable.
func (p *planner) planCreateObj(expr *exprpb.Expr) (Interpretable, error) {
obj := expr.GetStructExpr()
typeName, defined := p.resolveTypeName(obj.GetMessageName())
func (p *planner) planCreateStruct(expr ast.Expr) (Interpretable, error) {
obj := expr.AsStruct()
typeName, defined := p.resolveTypeName(obj.TypeName())
if !defined {
return nil, fmt.Errorf("unknown type: %s", obj.GetMessageName())
return nil, fmt.Errorf("unknown type: %s", obj.TypeName())
}
entries := obj.GetEntries()
optionals := make([]bool, len(entries))
fields := make([]string, len(entries))
vals := make([]Interpretable, len(entries))
for i, entry := range entries {
fields[i] = entry.GetFieldKey()
val, err := p.Plan(entry.GetValue())
objFields := obj.Fields()
optionals := make([]bool, len(objFields))
fields := make([]string, len(objFields))
vals := make([]Interpretable, len(objFields))
for i, f := range objFields {
field := f.AsStructField()
fields[i] = field.Name()
val, err := p.Plan(field.Value())
if err != nil {
return nil, err
}
vals[i] = val
optionals[i] = entry.GetOptionalEntry()
optionals[i] = field.IsOptional()
}
return &evalObj{
id: expr.GetId(),
id: expr.ID(),
typeName: typeName,
fields: fields,
vals: vals,
@ -599,33 +576,33 @@ func (p *planner) planCreateObj(expr *exprpb.Expr) (Interpretable, error) {
}
// planComprehension generates an Interpretable fold operation.
func (p *planner) planComprehension(expr *exprpb.Expr) (Interpretable, error) {
fold := expr.GetComprehensionExpr()
accu, err := p.Plan(fold.GetAccuInit())
func (p *planner) planComprehension(expr ast.Expr) (Interpretable, error) {
fold := expr.AsComprehension()
accu, err := p.Plan(fold.AccuInit())
if err != nil {
return nil, err
}
iterRange, err := p.Plan(fold.GetIterRange())
iterRange, err := p.Plan(fold.IterRange())
if err != nil {
return nil, err
}
cond, err := p.Plan(fold.GetLoopCondition())
cond, err := p.Plan(fold.LoopCondition())
if err != nil {
return nil, err
}
step, err := p.Plan(fold.GetLoopStep())
step, err := p.Plan(fold.LoopStep())
if err != nil {
return nil, err
}
result, err := p.Plan(fold.GetResult())
result, err := p.Plan(fold.Result())
if err != nil {
return nil, err
}
return &evalFold{
id: expr.GetId(),
accuVar: fold.AccuVar,
id: expr.ID(),
accuVar: fold.AccuVar(),
accu: accu,
iterVar: fold.IterVar,
iterVar: fold.IterVar(),
iterRange: iterRange,
cond: cond,
step: step,
@ -635,37 +612,8 @@ func (p *planner) planComprehension(expr *exprpb.Expr) (Interpretable, error) {
}
// planConst generates a constant valued Interpretable.
func (p *planner) planConst(expr *exprpb.Expr) (Interpretable, error) {
val, err := p.constValue(expr.GetConstExpr())
if err != nil {
return nil, err
}
return NewConstValue(expr.GetId(), val), nil
}
// constValue converts a proto Constant value to a ref.Val.
func (p *planner) constValue(c *exprpb.Constant) (ref.Val, error) {
switch c.GetConstantKind().(type) {
case *exprpb.Constant_BoolValue:
return p.adapter.NativeToValue(c.GetBoolValue()), nil
case *exprpb.Constant_BytesValue:
return p.adapter.NativeToValue(c.GetBytesValue()), nil
case *exprpb.Constant_DoubleValue:
return p.adapter.NativeToValue(c.GetDoubleValue()), nil
case *exprpb.Constant_DurationValue:
return p.adapter.NativeToValue(c.GetDurationValue().AsDuration()), nil
case *exprpb.Constant_Int64Value:
return p.adapter.NativeToValue(c.GetInt64Value()), nil
case *exprpb.Constant_NullValue:
return p.adapter.NativeToValue(c.GetNullValue()), nil
case *exprpb.Constant_StringValue:
return p.adapter.NativeToValue(c.GetStringValue()), nil
case *exprpb.Constant_TimestampValue:
return p.adapter.NativeToValue(c.GetTimestampValue().AsTime()), nil
case *exprpb.Constant_Uint64Value:
return p.adapter.NativeToValue(c.GetUint64Value()), nil
}
return nil, fmt.Errorf("unknown constant type: %v", c)
func (p *planner) planConst(expr ast.Expr) (Interpretable, error) {
return NewConstValue(expr.ID(), expr.AsLiteral()), nil
}
// resolveTypeName takes a qualified string constructed at parse time, applies the proto
@ -687,17 +635,20 @@ func (p *planner) resolveTypeName(typeName string) (string, bool) {
// - The target expression may only consist of ident and select expressions.
// - The function is declared in the environment using its fully-qualified name.
// - The fully-qualified function name matches the string serialized target value.
func (p *planner) resolveFunction(expr *exprpb.Expr) (*exprpb.Expr, string, string) {
func (p *planner) resolveFunction(expr ast.Expr) (ast.Expr, string, string) {
// Note: similar logic exists within the `checker/checker.go`. If making changes here
// please consider the impact on checker.go and consolidate implementations or mirror code
// as appropriate.
call := expr.GetCallExpr()
target := call.GetTarget()
fnName := call.GetFunction()
call := expr.AsCall()
var target ast.Expr = nil
if call.IsMemberFunction() {
target = call.Target()
}
fnName := call.FunctionName()
// Checked expressions always have a reference map entry, and _should_ have the fully qualified
// function name as the fnName value.
oRef, hasOverload := p.refMap[expr.GetId()]
oRef, hasOverload := p.refMap[expr.ID()]
if hasOverload {
if len(oRef.OverloadIDs) == 1 {
return target, fnName, oRef.OverloadIDs[0]
@ -771,16 +722,30 @@ func (p *planner) relativeAttr(id int64, eval Interpretable, opt bool) (Interpre
// toQualifiedName converts an expression AST into a qualified name if possible, with a boolean
// 'found' value that indicates if the conversion is successful.
func (p *planner) toQualifiedName(operand *exprpb.Expr) (string, bool) {
func (p *planner) toQualifiedName(operand ast.Expr) (string, bool) {
// If the checker identified the expression as an attribute by the type-checker, then it can't
// possibly be part of qualified name in a namespace.
_, isAttr := p.refMap[operand.GetId()]
_, isAttr := p.refMap[operand.ID()]
if isAttr {
return "", false
}
// Since functions cannot be both namespaced and receiver functions, if the operand is not an
// qualified variable name, return the (possibly) qualified name given the expressions.
return containers.ToQualifiedName(operand)
switch operand.Kind() {
case ast.IdentKind:
id := operand.AsIdent()
return id, true
case ast.SelectKind:
sel := operand.AsSelect()
// Test only expressions are not valid as qualified names.
if sel.IsTestOnly() {
return "", false
}
if qual, found := p.toQualifiedName(sel.Operand()); found {
return qual + "." + sel.FieldName(), true
}
}
return "", false
}
func stripLeadingDot(name string) string {

View File

@ -15,19 +15,18 @@
package interpreter
import (
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
structpb "google.golang.org/protobuf/types/known/structpb"
)
type astPruner struct {
expr *exprpb.Expr
macroCalls map[int64]*exprpb.Expr
ast.ExprFactory
expr ast.Expr
macroCalls map[int64]ast.Expr
state EvalState
nextExprID int64
}
@ -67,84 +66,44 @@ type astPruner struct {
// compiled and constant folded expressions, but is not willing to constant
// fold(and thus cache results of) some external calls, then they can prepare
// the overloads accordingly.
func PruneAst(expr *exprpb.Expr, macroCalls map[int64]*exprpb.Expr, state EvalState) *exprpb.ParsedExpr {
func PruneAst(expr ast.Expr, macroCalls map[int64]ast.Expr, state EvalState) *ast.AST {
pruneState := NewEvalState()
for _, id := range state.IDs() {
v, _ := state.Value(id)
pruneState.SetValue(id, v)
}
pruner := &astPruner{
expr: expr,
macroCalls: macroCalls,
state: pruneState,
nextExprID: getMaxID(expr)}
ExprFactory: ast.NewExprFactory(),
expr: expr,
macroCalls: macroCalls,
state: pruneState,
nextExprID: getMaxID(expr)}
newExpr, _ := pruner.maybePrune(expr)
return &exprpb.ParsedExpr{
Expr: newExpr,
SourceInfo: &exprpb.SourceInfo{MacroCalls: pruner.macroCalls},
newInfo := ast.NewSourceInfo(nil)
for id, call := range pruner.macroCalls {
newInfo.SetMacroCall(id, call)
}
return ast.NewAST(newExpr, newInfo)
}
func (p *astPruner) createLiteral(id int64, val *exprpb.Constant) *exprpb.Expr {
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_ConstExpr{
ConstExpr: val,
},
}
}
func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, bool) {
func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (ast.Expr, bool) {
switch v := val.(type) {
case types.Bool:
case types.Bool, types.Bytes, types.Double, types.Int, types.Null, types.String, types.Uint:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: bool(v)}}), true
case types.Bytes:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: []byte(v)}}), true
case types.Double:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: float64(v)}}), true
return p.NewLiteral(id, val), true
case types.Duration:
p.state.SetValue(id, val)
durationString := string(v.ConvertToType(types.StringType).(types.String))
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_CallExpr{
CallExpr: &exprpb.Expr_Call{
Function: overloads.TypeConvertDuration,
Args: []*exprpb.Expr{
p.createLiteral(p.nextID(),
&exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: durationString}}),
},
},
},
}, true
case types.Int:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: int64(v)}}), true
case types.Uint:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: uint64(v)}}), true
case types.String:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: string(v)}}), true
case types.Null:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: v.Value().(structpb.NullValue)}}), true
durationString := v.ConvertToType(types.StringType).(types.String)
return p.NewCall(id, overloads.TypeConvertDuration, p.NewLiteral(p.nextID(), durationString)), true
case types.Timestamp:
timestampString := v.ConvertToType(types.StringType).(types.String)
return p.NewCall(id, overloads.TypeConvertTimestamp, p.NewLiteral(p.nextID(), timestampString)), true
}
// Attempt to build a list literal.
if list, isList := val.(traits.Lister); isList {
sz := list.Size().(types.Int)
elemExprs := make([]*exprpb.Expr, sz)
elemExprs := make([]ast.Expr, sz)
for i := types.Int(0); i < sz; i++ {
elem := list.Get(i)
if types.IsUnknownOrError(elem) {
@ -157,20 +116,13 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo
elemExprs[i] = elemExpr
}
p.state.SetValue(id, val)
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_ListExpr{
ListExpr: &exprpb.Expr_CreateList{
Elements: elemExprs,
},
},
}, true
return p.NewList(id, elemExprs, []int32{}), true
}
// Create a map literal if possible.
if mp, isMap := val.(traits.Mapper); isMap {
it := mp.Iterator()
entries := make([]*exprpb.Expr_CreateStruct_Entry, mp.Size().(types.Int))
entries := make([]ast.EntryExpr, mp.Size().(types.Int))
i := 0
for it.HasNext() != types.False {
key := it.Next()
@ -186,25 +138,12 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo
if !ok {
return nil, false
}
entry := &exprpb.Expr_CreateStruct_Entry{
Id: p.nextID(),
KeyKind: &exprpb.Expr_CreateStruct_Entry_MapKey{
MapKey: keyExpr,
},
Value: valExpr,
}
entry := p.NewMapEntry(p.nextID(), keyExpr, valExpr, false)
entries[i] = entry
i++
}
p.state.SetValue(id, val)
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_StructExpr{
StructExpr: &exprpb.Expr_CreateStruct{
Entries: entries,
},
},
}, true
return p.NewMap(id, entries), true
}
// TODO(issues/377) To construct message literals, the type provider will need to support
@ -212,215 +151,206 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo
return nil, false
}
func (p *astPruner) maybePruneOptional(elem *exprpb.Expr) (*exprpb.Expr, bool) {
elemVal, found := p.value(elem.GetId())
func (p *astPruner) maybePruneOptional(elem ast.Expr) (ast.Expr, bool) {
elemVal, found := p.value(elem.ID())
if found && elemVal.Type() == types.OptionalType {
opt := elemVal.(*types.Optional)
if !opt.HasValue() {
return nil, true
}
if newElem, pruned := p.maybeCreateLiteral(elem.GetId(), opt.GetValue()); pruned {
if newElem, pruned := p.maybeCreateLiteral(elem.ID(), opt.GetValue()); pruned {
return newElem, true
}
}
return elem, false
}
func (p *astPruner) maybePruneIn(node *exprpb.Expr) (*exprpb.Expr, bool) {
func (p *astPruner) maybePruneIn(node ast.Expr) (ast.Expr, bool) {
// elem in list
call := node.GetCallExpr()
val, exists := p.maybeValue(call.GetArgs()[1].GetId())
call := node.AsCall()
val, exists := p.maybeValue(call.Args()[1].ID())
if !exists {
return nil, false
}
if sz, ok := val.(traits.Sizer); ok && sz.Size() == types.IntZero {
return p.maybeCreateLiteral(node.GetId(), types.False)
return p.maybeCreateLiteral(node.ID(), types.False)
}
return nil, false
}
func (p *astPruner) maybePruneLogicalNot(node *exprpb.Expr) (*exprpb.Expr, bool) {
call := node.GetCallExpr()
arg := call.GetArgs()[0]
val, exists := p.maybeValue(arg.GetId())
func (p *astPruner) maybePruneLogicalNot(node ast.Expr) (ast.Expr, bool) {
call := node.AsCall()
arg := call.Args()[0]
val, exists := p.maybeValue(arg.ID())
if !exists {
return nil, false
}
if b, ok := val.(types.Bool); ok {
return p.maybeCreateLiteral(node.GetId(), !b)
return p.maybeCreateLiteral(node.ID(), !b)
}
return nil, false
}
func (p *astPruner) maybePruneOr(node *exprpb.Expr) (*exprpb.Expr, bool) {
call := node.GetCallExpr()
func (p *astPruner) maybePruneOr(node ast.Expr) (ast.Expr, bool) {
call := node.AsCall()
// We know result is unknown, so we have at least one unknown arg
// and if one side is a known value, we know we can ignore it.
if v, exists := p.maybeValue(call.GetArgs()[0].GetId()); exists {
if v, exists := p.maybeValue(call.Args()[0].ID()); exists {
if v == types.True {
return p.maybeCreateLiteral(node.GetId(), types.True)
return p.maybeCreateLiteral(node.ID(), types.True)
}
return call.GetArgs()[1], true
return call.Args()[1], true
}
if v, exists := p.maybeValue(call.GetArgs()[1].GetId()); exists {
if v, exists := p.maybeValue(call.Args()[1].ID()); exists {
if v == types.True {
return p.maybeCreateLiteral(node.GetId(), types.True)
return p.maybeCreateLiteral(node.ID(), types.True)
}
return call.GetArgs()[0], true
return call.Args()[0], true
}
return nil, false
}
func (p *astPruner) maybePruneAnd(node *exprpb.Expr) (*exprpb.Expr, bool) {
call := node.GetCallExpr()
func (p *astPruner) maybePruneAnd(node ast.Expr) (ast.Expr, bool) {
call := node.AsCall()
// We know result is unknown, so we have at least one unknown arg
// and if one side is a known value, we know we can ignore it.
if v, exists := p.maybeValue(call.GetArgs()[0].GetId()); exists {
if v, exists := p.maybeValue(call.Args()[0].ID()); exists {
if v == types.False {
return p.maybeCreateLiteral(node.GetId(), types.False)
return p.maybeCreateLiteral(node.ID(), types.False)
}
return call.GetArgs()[1], true
return call.Args()[1], true
}
if v, exists := p.maybeValue(call.GetArgs()[1].GetId()); exists {
if v, exists := p.maybeValue(call.Args()[1].ID()); exists {
if v == types.False {
return p.maybeCreateLiteral(node.GetId(), types.False)
return p.maybeCreateLiteral(node.ID(), types.False)
}
return call.GetArgs()[0], true
return call.Args()[0], true
}
return nil, false
}
func (p *astPruner) maybePruneConditional(node *exprpb.Expr) (*exprpb.Expr, bool) {
call := node.GetCallExpr()
cond, exists := p.maybeValue(call.GetArgs()[0].GetId())
func (p *astPruner) maybePruneConditional(node ast.Expr) (ast.Expr, bool) {
call := node.AsCall()
cond, exists := p.maybeValue(call.Args()[0].ID())
if !exists {
return nil, false
}
if cond.Value().(bool) {
return call.GetArgs()[1], true
return call.Args()[1], true
}
return call.GetArgs()[2], true
return call.Args()[2], true
}
func (p *astPruner) maybePruneFunction(node *exprpb.Expr) (*exprpb.Expr, bool) {
if _, exists := p.value(node.GetId()); !exists {
func (p *astPruner) maybePruneFunction(node ast.Expr) (ast.Expr, bool) {
if _, exists := p.value(node.ID()); !exists {
return nil, false
}
call := node.GetCallExpr()
if call.Function == operators.LogicalOr {
call := node.AsCall()
if call.FunctionName() == operators.LogicalOr {
return p.maybePruneOr(node)
}
if call.Function == operators.LogicalAnd {
if call.FunctionName() == operators.LogicalAnd {
return p.maybePruneAnd(node)
}
if call.Function == operators.Conditional {
if call.FunctionName() == operators.Conditional {
return p.maybePruneConditional(node)
}
if call.Function == operators.In {
if call.FunctionName() == operators.In {
return p.maybePruneIn(node)
}
if call.Function == operators.LogicalNot {
if call.FunctionName() == operators.LogicalNot {
return p.maybePruneLogicalNot(node)
}
return nil, false
}
func (p *astPruner) maybePrune(node *exprpb.Expr) (*exprpb.Expr, bool) {
func (p *astPruner) maybePrune(node ast.Expr) (ast.Expr, bool) {
return p.prune(node)
}
func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
func (p *astPruner) prune(node ast.Expr) (ast.Expr, bool) {
if node == nil {
return node, false
}
val, valueExists := p.maybeValue(node.GetId())
val, valueExists := p.maybeValue(node.ID())
if valueExists {
if newNode, ok := p.maybeCreateLiteral(node.GetId(), val); ok {
delete(p.macroCalls, node.GetId())
if newNode, ok := p.maybeCreateLiteral(node.ID(), val); ok {
delete(p.macroCalls, node.ID())
return newNode, true
}
}
if macro, found := p.macroCalls[node.GetId()]; found {
if macro, found := p.macroCalls[node.ID()]; found {
// Ensure that intermediate values for the comprehension are cleared during pruning
compre := node.GetComprehensionExpr()
if compre != nil {
visit(macro, clearIterVarVisitor(compre.IterVar, p.state))
if node.Kind() == ast.ComprehensionKind {
compre := node.AsComprehension()
visit(macro, clearIterVarVisitor(compre.IterVar(), p.state))
}
// prune the expression in terms of the macro call instead of the expanded form.
if newMacro, pruned := p.prune(macro); pruned {
p.macroCalls[node.GetId()] = newMacro
p.macroCalls[node.ID()] = newMacro
}
}
// We have either an unknown/error value, or something we don't want to
// transform, or expression was not evaluated. If possible, drill down
// more.
switch node.GetExprKind().(type) {
case *exprpb.Expr_SelectExpr:
if operand, pruned := p.maybePrune(node.GetSelectExpr().GetOperand()); pruned {
return &exprpb.Expr{
Id: node.GetId(),
ExprKind: &exprpb.Expr_SelectExpr{
SelectExpr: &exprpb.Expr_Select{
Operand: operand,
Field: node.GetSelectExpr().GetField(),
TestOnly: node.GetSelectExpr().GetTestOnly(),
},
},
}, true
switch node.Kind() {
case ast.SelectKind:
sel := node.AsSelect()
if operand, isPruned := p.maybePrune(sel.Operand()); isPruned {
if sel.IsTestOnly() {
return p.NewPresenceTest(node.ID(), operand, sel.FieldName()), true
}
return p.NewSelect(node.ID(), operand, sel.FieldName()), true
}
case *exprpb.Expr_CallExpr:
var prunedCall bool
call := node.GetCallExpr()
args := call.GetArgs()
newArgs := make([]*exprpb.Expr, len(args))
newCall := &exprpb.Expr_Call{
Function: call.GetFunction(),
Target: call.GetTarget(),
Args: newArgs,
}
for i, arg := range args {
newArgs[i] = arg
if newArg, prunedArg := p.maybePrune(arg); prunedArg {
prunedCall = true
newArgs[i] = newArg
case ast.CallKind:
argsPruned := false
call := node.AsCall()
args := call.Args()
newArgs := make([]ast.Expr, len(args))
for i, a := range args {
newArgs[i] = a
if arg, isPruned := p.maybePrune(a); isPruned {
argsPruned = true
newArgs[i] = arg
}
}
if newTarget, prunedTarget := p.maybePrune(call.GetTarget()); prunedTarget {
prunedCall = true
newCall.Target = newTarget
if !call.IsMemberFunction() {
newCall := p.NewCall(node.ID(), call.FunctionName(), newArgs...)
if prunedCall, isPruned := p.maybePruneFunction(newCall); isPruned {
return prunedCall, true
}
return newCall, argsPruned
}
newNode := &exprpb.Expr{
Id: node.GetId(),
ExprKind: &exprpb.Expr_CallExpr{
CallExpr: newCall,
},
newTarget := call.Target()
targetPruned := false
if prunedTarget, isPruned := p.maybePrune(call.Target()); isPruned {
targetPruned = true
newTarget = prunedTarget
}
if newExpr, pruned := p.maybePruneFunction(newNode); pruned {
newExpr, _ = p.maybePrune(newExpr)
return newExpr, true
newCall := p.NewMemberCall(node.ID(), call.FunctionName(), newTarget, newArgs...)
if prunedCall, isPruned := p.maybePruneFunction(newCall); isPruned {
return prunedCall, true
}
if prunedCall {
return newNode, true
}
case *exprpb.Expr_ListExpr:
elems := node.GetListExpr().GetElements()
optIndices := node.GetListExpr().GetOptionalIndices()
return newCall, targetPruned || argsPruned
case ast.ListKind:
l := node.AsList()
elems := l.Elements()
optIndices := l.OptionalIndices()
optIndexMap := map[int32]bool{}
for _, i := range optIndices {
optIndexMap[i] = true
}
newOptIndexMap := make(map[int32]bool, len(optIndexMap))
newElems := make([]*exprpb.Expr, 0, len(elems))
var prunedList bool
newElems := make([]ast.Expr, 0, len(elems))
var listPruned bool
prunedIdx := 0
for i, elem := range elems {
_, isOpt := optIndexMap[int32(i)]
if isOpt {
newElem, pruned := p.maybePruneOptional(elem)
if pruned {
prunedList = true
listPruned = true
if newElem != nil {
newElems = append(newElems, newElem)
prunedIdx++
@ -431,7 +361,7 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
}
if newElem, prunedElem := p.maybePrune(elem); prunedElem {
newElems = append(newElems, newElem)
prunedList = true
listPruned = true
} else {
newElems = append(newElems, elem)
}
@ -443,76 +373,64 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
optIndices[idx] = i
idx++
}
if prunedList {
return &exprpb.Expr{
Id: node.GetId(),
ExprKind: &exprpb.Expr_ListExpr{
ListExpr: &exprpb.Expr_CreateList{
Elements: newElems,
OptionalIndices: optIndices,
},
},
}, true
if listPruned {
return p.NewList(node.ID(), newElems, optIndices), true
}
case *exprpb.Expr_StructExpr:
var prunedStruct bool
entries := node.GetStructExpr().GetEntries()
messageType := node.GetStructExpr().GetMessageName()
newEntries := make([]*exprpb.Expr_CreateStruct_Entry, len(entries))
case ast.MapKind:
var mapPruned bool
m := node.AsMap()
entries := m.Entries()
newEntries := make([]ast.EntryExpr, len(entries))
for i, entry := range entries {
newEntries[i] = entry
newKey, prunedKey := p.maybePrune(entry.GetMapKey())
newValue, prunedValue := p.maybePrune(entry.GetValue())
if !prunedKey && !prunedValue {
e := entry.AsMapEntry()
newKey, keyPruned := p.maybePrune(e.Key())
newValue, valuePruned := p.maybePrune(e.Value())
if !keyPruned && !valuePruned {
continue
}
prunedStruct = true
newEntry := &exprpb.Expr_CreateStruct_Entry{
Value: newValue,
}
if messageType != "" {
newEntry.KeyKind = &exprpb.Expr_CreateStruct_Entry_FieldKey{
FieldKey: entry.GetFieldKey(),
}
} else {
newEntry.KeyKind = &exprpb.Expr_CreateStruct_Entry_MapKey{
MapKey: newKey,
}
}
newEntry.OptionalEntry = entry.GetOptionalEntry()
mapPruned = true
newEntry := p.NewMapEntry(entry.ID(), newKey, newValue, e.IsOptional())
newEntries[i] = newEntry
}
if prunedStruct {
return &exprpb.Expr{
Id: node.GetId(),
ExprKind: &exprpb.Expr_StructExpr{
StructExpr: &exprpb.Expr_CreateStruct{
MessageName: messageType,
Entries: newEntries,
},
},
}, true
if mapPruned {
return p.NewMap(node.ID(), newEntries), true
}
case *exprpb.Expr_ComprehensionExpr:
compre := node.GetComprehensionExpr()
case ast.StructKind:
var structPruned bool
obj := node.AsStruct()
fields := obj.Fields()
newFields := make([]ast.EntryExpr, len(fields))
for i, field := range fields {
newFields[i] = field
f := field.AsStructField()
newValue, prunedValue := p.maybePrune(f.Value())
if !prunedValue {
continue
}
structPruned = true
newEntry := p.NewStructField(field.ID(), f.Name(), newValue, f.IsOptional())
newFields[i] = newEntry
}
if structPruned {
return p.NewStruct(node.ID(), obj.TypeName(), newFields), true
}
case ast.ComprehensionKind:
compre := node.AsComprehension()
// Only the range of the comprehension is pruned since the state tracking only records
// the last iteration of the comprehension and not each step in the evaluation which
// means that the any residuals computed in between might be inaccurate.
if newRange, pruned := p.maybePrune(compre.GetIterRange()); pruned {
return &exprpb.Expr{
Id: node.GetId(),
ExprKind: &exprpb.Expr_ComprehensionExpr{
ComprehensionExpr: &exprpb.Expr_Comprehension{
IterVar: compre.GetIterVar(),
IterRange: newRange,
AccuVar: compre.GetAccuVar(),
AccuInit: compre.GetAccuInit(),
LoopCondition: compre.GetLoopCondition(),
LoopStep: compre.GetLoopStep(),
Result: compre.GetResult(),
},
},
}, true
if newRange, pruned := p.maybePrune(compre.IterRange()); pruned {
return p.NewComprehension(
node.ID(),
newRange,
compre.IterVar(),
compre.AccuVar(),
compre.AccuInit(),
compre.LoopCondition(),
compre.LoopStep(),
compre.Result(),
), true
}
}
return node, false
@ -539,12 +457,12 @@ func (p *astPruner) nextID() int64 {
type astVisitor struct {
// visitEntry is called on every expr node, including those within a map/struct entry.
visitExpr func(expr *exprpb.Expr)
visitExpr func(expr ast.Expr)
// visitEntry is called before entering the key, value of a map/struct entry.
visitEntry func(entry *exprpb.Expr_CreateStruct_Entry)
visitEntry func(entry ast.EntryExpr)
}
func getMaxID(expr *exprpb.Expr) int64 {
func getMaxID(expr ast.Expr) int64 {
maxID := int64(1)
visit(expr, maxIDVisitor(&maxID))
return maxID
@ -552,10 +470,9 @@ func getMaxID(expr *exprpb.Expr) int64 {
func clearIterVarVisitor(varName string, state EvalState) astVisitor {
return astVisitor{
visitExpr: func(e *exprpb.Expr) {
ident := e.GetIdentExpr()
if ident != nil && ident.GetName() == varName {
state.SetValue(e.GetId(), nil)
visitExpr: func(e ast.Expr) {
if e.Kind() == ast.IdentKind && e.AsIdent() == varName {
state.SetValue(e.ID(), nil)
}
},
}
@ -563,56 +480,63 @@ func clearIterVarVisitor(varName string, state EvalState) astVisitor {
func maxIDVisitor(maxID *int64) astVisitor {
return astVisitor{
visitExpr: func(e *exprpb.Expr) {
if e.GetId() >= *maxID {
*maxID = e.GetId() + 1
visitExpr: func(e ast.Expr) {
if e.ID() >= *maxID {
*maxID = e.ID() + 1
}
},
visitEntry: func(e *exprpb.Expr_CreateStruct_Entry) {
if e.GetId() >= *maxID {
*maxID = e.GetId() + 1
visitEntry: func(e ast.EntryExpr) {
if e.ID() >= *maxID {
*maxID = e.ID() + 1
}
},
}
}
func visit(expr *exprpb.Expr, visitor astVisitor) {
exprs := []*exprpb.Expr{expr}
func visit(expr ast.Expr, visitor astVisitor) {
exprs := []ast.Expr{expr}
for len(exprs) != 0 {
e := exprs[0]
if visitor.visitExpr != nil {
visitor.visitExpr(e)
}
exprs = exprs[1:]
switch e.GetExprKind().(type) {
case *exprpb.Expr_SelectExpr:
exprs = append(exprs, e.GetSelectExpr().GetOperand())
case *exprpb.Expr_CallExpr:
call := e.GetCallExpr()
if call.GetTarget() != nil {
exprs = append(exprs, call.GetTarget())
switch e.Kind() {
case ast.SelectKind:
exprs = append(exprs, e.AsSelect().Operand())
case ast.CallKind:
call := e.AsCall()
if call.Target() != nil {
exprs = append(exprs, call.Target())
}
exprs = append(exprs, call.GetArgs()...)
case *exprpb.Expr_ComprehensionExpr:
compre := e.GetComprehensionExpr()
exprs = append(exprs, call.Args()...)
case ast.ComprehensionKind:
compre := e.AsComprehension()
exprs = append(exprs,
compre.GetIterRange(),
compre.GetAccuInit(),
compre.GetLoopCondition(),
compre.GetLoopStep(),
compre.GetResult())
case *exprpb.Expr_ListExpr:
list := e.GetListExpr()
exprs = append(exprs, list.GetElements()...)
case *exprpb.Expr_StructExpr:
for _, entry := range e.GetStructExpr().GetEntries() {
compre.IterRange(),
compre.AccuInit(),
compre.LoopCondition(),
compre.LoopStep(),
compre.Result())
case ast.ListKind:
list := e.AsList()
exprs = append(exprs, list.Elements()...)
case ast.MapKind:
for _, entry := range e.AsMap().Entries() {
e := entry.AsMapEntry()
if visitor.visitEntry != nil {
visitor.visitEntry(entry)
}
if entry.GetMapKey() != nil {
exprs = append(exprs, entry.GetMapKey())
exprs = append(exprs, e.Key())
exprs = append(exprs, e.Value())
}
case ast.StructKind:
for _, entry := range e.AsStruct().Fields() {
f := entry.AsStructField()
if visitor.visitEntry != nil {
visitor.visitEntry(entry)
}
exprs = append(exprs, entry.GetValue())
exprs = append(exprs, f.Value())
}
}
}

View File

@ -20,10 +20,13 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//common:go_default_library",
"//common/ast:go_default_library",
"//common/operators:go_default_library",
"//common/runes:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"//parser/gen:go_default_library",
"@com_github_antlr_antlr4_runtime_go_antlr_v4//:go_default_library",
"@com_github_antlr4_go_antlr_v4//:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//types/known/structpb:go_default_library",
@ -43,10 +46,12 @@ go_test(
":go_default_library",
],
deps = [
"//common/ast:go_default_library",
"//common/debug:go_default_library",
"//common/types:go_default_library",
"//parser/gen:go_default_library",
"//test:go_default_library",
"@com_github_antlr_antlr4_runtime_go_antlr_v4//:go_default_library",
"@com_github_antlr4_go_antlr_v4//:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//testing/protocmp:go_default_library",
],

View File

@ -21,6 +21,6 @@ go_library(
],
importpath = "github.com/google/cel-go/parser/gen",
deps = [
"@com_github_antlr_antlr4_runtime_go_antlr_v4//:go_default_library",
"@com_github_antlr4_go_antlr_v4//:go_default_library",
],
)

View File

@ -1,7 +1,7 @@
// Code generated from /Users/tswadell/go/src/github.com/google/cel-go/parser/gen/CEL.g4 by ANTLR 4.12.0. DO NOT EDIT.
// Code generated from /usr/local/google/home/tswadell/go/src/github.com/google/cel-go/parser/gen/CEL.g4 by ANTLR 4.13.1. DO NOT EDIT.
package gen // CEL
import "github.com/antlr/antlr4/runtime/Go/antlr/v4"
import "github.com/antlr4-go/antlr/v4"
// BaseCELListener is a complete listener for a parse tree produced by CELParser.
type BaseCELListener struct{}

View File

@ -1,7 +1,8 @@
// Code generated from /Users/tswadell/go/src/github.com/google/cel-go/parser/gen/CEL.g4 by ANTLR 4.12.0. DO NOT EDIT.
// Code generated from /usr/local/google/home/tswadell/go/src/github.com/google/cel-go/parser/gen/CEL.g4 by ANTLR 4.13.1. DO NOT EDIT.
package gen // CEL
import "github.com/antlr/antlr4/runtime/Go/antlr/v4"
import "github.com/antlr4-go/antlr/v4"
type BaseCELVisitor struct {
*antlr.BaseParseTreeVisitor

View File

@ -1,280 +1,278 @@
// Code generated from /Users/tswadell/go/src/github.com/google/cel-go/parser/gen/CEL.g4 by ANTLR 4.12.0. DO NOT EDIT.
// Code generated from /usr/local/google/home/tswadell/go/src/github.com/google/cel-go/parser/gen/CEL.g4 by ANTLR 4.13.1. DO NOT EDIT.
package gen
import (
"fmt"
"sync"
"sync"
"unicode"
"github.com/antlr/antlr4/runtime/Go/antlr/v4"
"github.com/antlr4-go/antlr/v4"
)
// Suppress unused import error
var _ = fmt.Printf
var _ = sync.Once{}
var _ = unicode.IsLetter
type CELLexer struct {
*antlr.BaseLexer
channelNames []string
modeNames []string
modeNames []string
// TODO: EOF string
}
var cellexerLexerStaticData struct {
once sync.Once
serializedATN []int32
channelNames []string
modeNames []string
literalNames []string
symbolicNames []string
ruleNames []string
predictionContextCache *antlr.PredictionContextCache
atn *antlr.ATN
decisionToDFA []*antlr.DFA
var CELLexerLexerStaticData struct {
once sync.Once
serializedATN []int32
ChannelNames []string
ModeNames []string
LiteralNames []string
SymbolicNames []string
RuleNames []string
PredictionContextCache *antlr.PredictionContextCache
atn *antlr.ATN
decisionToDFA []*antlr.DFA
}
func cellexerLexerInit() {
staticData := &cellexerLexerStaticData
staticData.channelNames = []string{
"DEFAULT_TOKEN_CHANNEL", "HIDDEN",
}
staticData.modeNames = []string{
"DEFAULT_MODE",
}
staticData.literalNames = []string{
"", "'=='", "'!='", "'in'", "'<'", "'<='", "'>='", "'>'", "'&&'", "'||'",
"'['", "']'", "'{'", "'}'", "'('", "')'", "'.'", "','", "'-'", "'!'",
"'?'", "':'", "'+'", "'*'", "'/'", "'%'", "'true'", "'false'", "'null'",
}
staticData.symbolicNames = []string{
"", "EQUALS", "NOT_EQUALS", "IN", "LESS", "LESS_EQUALS", "GREATER_EQUALS",
"GREATER", "LOGICAL_AND", "LOGICAL_OR", "LBRACKET", "RPRACKET", "LBRACE",
"RBRACE", "LPAREN", "RPAREN", "DOT", "COMMA", "MINUS", "EXCLAM", "QUESTIONMARK",
"COLON", "PLUS", "STAR", "SLASH", "PERCENT", "CEL_TRUE", "CEL_FALSE",
"NUL", "WHITESPACE", "COMMENT", "NUM_FLOAT", "NUM_INT", "NUM_UINT",
"STRING", "BYTES", "IDENTIFIER",
}
staticData.ruleNames = []string{
"EQUALS", "NOT_EQUALS", "IN", "LESS", "LESS_EQUALS", "GREATER_EQUALS",
"GREATER", "LOGICAL_AND", "LOGICAL_OR", "LBRACKET", "RPRACKET", "LBRACE",
"RBRACE", "LPAREN", "RPAREN", "DOT", "COMMA", "MINUS", "EXCLAM", "QUESTIONMARK",
"COLON", "PLUS", "STAR", "SLASH", "PERCENT", "CEL_TRUE", "CEL_FALSE",
"NUL", "BACKSLASH", "LETTER", "DIGIT", "EXPONENT", "HEXDIGIT", "RAW",
"ESC_SEQ", "ESC_CHAR_SEQ", "ESC_OCT_SEQ", "ESC_BYTE_SEQ", "ESC_UNI_SEQ",
"WHITESPACE", "COMMENT", "NUM_FLOAT", "NUM_INT", "NUM_UINT", "STRING",
"BYTES", "IDENTIFIER",
}
staticData.predictionContextCache = antlr.NewPredictionContextCache()
staticData.serializedATN = []int32{
4, 0, 36, 423, 6, -1, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2,
4, 7, 4, 2, 5, 7, 5, 2, 6, 7, 6, 2, 7, 7, 7, 2, 8, 7, 8, 2, 9, 7, 9, 2,
10, 7, 10, 2, 11, 7, 11, 2, 12, 7, 12, 2, 13, 7, 13, 2, 14, 7, 14, 2, 15,
7, 15, 2, 16, 7, 16, 2, 17, 7, 17, 2, 18, 7, 18, 2, 19, 7, 19, 2, 20, 7,
20, 2, 21, 7, 21, 2, 22, 7, 22, 2, 23, 7, 23, 2, 24, 7, 24, 2, 25, 7, 25,
2, 26, 7, 26, 2, 27, 7, 27, 2, 28, 7, 28, 2, 29, 7, 29, 2, 30, 7, 30, 2,
31, 7, 31, 2, 32, 7, 32, 2, 33, 7, 33, 2, 34, 7, 34, 2, 35, 7, 35, 2, 36,
7, 36, 2, 37, 7, 37, 2, 38, 7, 38, 2, 39, 7, 39, 2, 40, 7, 40, 2, 41, 7,
41, 2, 42, 7, 42, 2, 43, 7, 43, 2, 44, 7, 44, 2, 45, 7, 45, 2, 46, 7, 46,
1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 2, 1, 3, 1, 3, 1, 4,
1, 4, 1, 4, 1, 5, 1, 5, 1, 5, 1, 6, 1, 6, 1, 7, 1, 7, 1, 7, 1, 8, 1, 8,
1, 8, 1, 9, 1, 9, 1, 10, 1, 10, 1, 11, 1, 11, 1, 12, 1, 12, 1, 13, 1, 13,
1, 14, 1, 14, 1, 15, 1, 15, 1, 16, 1, 16, 1, 17, 1, 17, 1, 18, 1, 18, 1,
19, 1, 19, 1, 20, 1, 20, 1, 21, 1, 21, 1, 22, 1, 22, 1, 23, 1, 23, 1, 24,
1, 24, 1, 25, 1, 25, 1, 25, 1, 25, 1, 25, 1, 26, 1, 26, 1, 26, 1, 26, 1,
26, 1, 26, 1, 27, 1, 27, 1, 27, 1, 27, 1, 27, 1, 28, 1, 28, 1, 29, 1, 29,
1, 30, 1, 30, 1, 31, 1, 31, 3, 31, 177, 8, 31, 1, 31, 4, 31, 180, 8, 31,
11, 31, 12, 31, 181, 1, 32, 1, 32, 1, 33, 1, 33, 1, 34, 1, 34, 1, 34, 1,
34, 3, 34, 192, 8, 34, 1, 35, 1, 35, 1, 35, 1, 36, 1, 36, 1, 36, 1, 36,
1, 36, 1, 37, 1, 37, 1, 37, 1, 37, 1, 37, 1, 38, 1, 38, 1, 38, 1, 38, 1,
38, 1, 38, 1, 38, 1, 38, 1, 38, 1, 38, 1, 38, 1, 38, 1, 38, 1, 38, 1, 38,
1, 38, 1, 38, 1, 38, 3, 38, 225, 8, 38, 1, 39, 4, 39, 228, 8, 39, 11, 39,
12, 39, 229, 1, 39, 1, 39, 1, 40, 1, 40, 1, 40, 1, 40, 5, 40, 238, 8, 40,
10, 40, 12, 40, 241, 9, 40, 1, 40, 1, 40, 1, 41, 4, 41, 246, 8, 41, 11,
41, 12, 41, 247, 1, 41, 1, 41, 4, 41, 252, 8, 41, 11, 41, 12, 41, 253,
1, 41, 3, 41, 257, 8, 41, 1, 41, 4, 41, 260, 8, 41, 11, 41, 12, 41, 261,
1, 41, 1, 41, 1, 41, 1, 41, 4, 41, 268, 8, 41, 11, 41, 12, 41, 269, 1,
41, 3, 41, 273, 8, 41, 3, 41, 275, 8, 41, 1, 42, 4, 42, 278, 8, 42, 11,
42, 12, 42, 279, 1, 42, 1, 42, 1, 42, 1, 42, 4, 42, 286, 8, 42, 11, 42,
12, 42, 287, 3, 42, 290, 8, 42, 1, 43, 4, 43, 293, 8, 43, 11, 43, 12, 43,
294, 1, 43, 1, 43, 1, 43, 1, 43, 1, 43, 1, 43, 4, 43, 303, 8, 43, 11, 43,
12, 43, 304, 1, 43, 1, 43, 3, 43, 309, 8, 43, 1, 44, 1, 44, 1, 44, 5, 44,
314, 8, 44, 10, 44, 12, 44, 317, 9, 44, 1, 44, 1, 44, 1, 44, 1, 44, 5,
44, 323, 8, 44, 10, 44, 12, 44, 326, 9, 44, 1, 44, 1, 44, 1, 44, 1, 44,
1, 44, 1, 44, 1, 44, 5, 44, 335, 8, 44, 10, 44, 12, 44, 338, 9, 44, 1,
44, 1, 44, 1, 44, 1, 44, 1, 44, 1, 44, 1, 44, 1, 44, 1, 44, 5, 44, 349,
8, 44, 10, 44, 12, 44, 352, 9, 44, 1, 44, 1, 44, 1, 44, 1, 44, 1, 44, 1,
44, 5, 44, 360, 8, 44, 10, 44, 12, 44, 363, 9, 44, 1, 44, 1, 44, 1, 44,
1, 44, 1, 44, 5, 44, 370, 8, 44, 10, 44, 12, 44, 373, 9, 44, 1, 44, 1,
44, 1, 44, 1, 44, 1, 44, 1, 44, 1, 44, 1, 44, 5, 44, 383, 8, 44, 10, 44,
12, 44, 386, 9, 44, 1, 44, 1, 44, 1, 44, 1, 44, 1, 44, 1, 44, 1, 44, 1,
44, 1, 44, 1, 44, 5, 44, 398, 8, 44, 10, 44, 12, 44, 401, 9, 44, 1, 44,
1, 44, 1, 44, 1, 44, 3, 44, 407, 8, 44, 1, 45, 1, 45, 1, 45, 1, 46, 1,
46, 3, 46, 414, 8, 46, 1, 46, 1, 46, 1, 46, 5, 46, 419, 8, 46, 10, 46,
12, 46, 422, 9, 46, 4, 336, 350, 384, 399, 0, 47, 1, 1, 3, 2, 5, 3, 7,
4, 9, 5, 11, 6, 13, 7, 15, 8, 17, 9, 19, 10, 21, 11, 23, 12, 25, 13, 27,
14, 29, 15, 31, 16, 33, 17, 35, 18, 37, 19, 39, 20, 41, 21, 43, 22, 45,
23, 47, 24, 49, 25, 51, 26, 53, 27, 55, 28, 57, 0, 59, 0, 61, 0, 63, 0,
65, 0, 67, 0, 69, 0, 71, 0, 73, 0, 75, 0, 77, 0, 79, 29, 81, 30, 83, 31,
85, 32, 87, 33, 89, 34, 91, 35, 93, 36, 1, 0, 16, 2, 0, 65, 90, 97, 122,
2, 0, 69, 69, 101, 101, 2, 0, 43, 43, 45, 45, 3, 0, 48, 57, 65, 70, 97,
102, 2, 0, 82, 82, 114, 114, 10, 0, 34, 34, 39, 39, 63, 63, 92, 92, 96,
98, 102, 102, 110, 110, 114, 114, 116, 116, 118, 118, 2, 0, 88, 88, 120,
120, 3, 0, 9, 10, 12, 13, 32, 32, 1, 0, 10, 10, 2, 0, 85, 85, 117, 117,
4, 0, 10, 10, 13, 13, 34, 34, 92, 92, 4, 0, 10, 10, 13, 13, 39, 39, 92,
92, 1, 0, 92, 92, 3, 0, 10, 10, 13, 13, 34, 34, 3, 0, 10, 10, 13, 13, 39,
39, 2, 0, 66, 66, 98, 98, 456, 0, 1, 1, 0, 0, 0, 0, 3, 1, 0, 0, 0, 0, 5,
1, 0, 0, 0, 0, 7, 1, 0, 0, 0, 0, 9, 1, 0, 0, 0, 0, 11, 1, 0, 0, 0, 0, 13,
1, 0, 0, 0, 0, 15, 1, 0, 0, 0, 0, 17, 1, 0, 0, 0, 0, 19, 1, 0, 0, 0, 0,
21, 1, 0, 0, 0, 0, 23, 1, 0, 0, 0, 0, 25, 1, 0, 0, 0, 0, 27, 1, 0, 0, 0,
0, 29, 1, 0, 0, 0, 0, 31, 1, 0, 0, 0, 0, 33, 1, 0, 0, 0, 0, 35, 1, 0, 0,
0, 0, 37, 1, 0, 0, 0, 0, 39, 1, 0, 0, 0, 0, 41, 1, 0, 0, 0, 0, 43, 1, 0,
0, 0, 0, 45, 1, 0, 0, 0, 0, 47, 1, 0, 0, 0, 0, 49, 1, 0, 0, 0, 0, 51, 1,
0, 0, 0, 0, 53, 1, 0, 0, 0, 0, 55, 1, 0, 0, 0, 0, 79, 1, 0, 0, 0, 0, 81,
1, 0, 0, 0, 0, 83, 1, 0, 0, 0, 0, 85, 1, 0, 0, 0, 0, 87, 1, 0, 0, 0, 0,
89, 1, 0, 0, 0, 0, 91, 1, 0, 0, 0, 0, 93, 1, 0, 0, 0, 1, 95, 1, 0, 0, 0,
3, 98, 1, 0, 0, 0, 5, 101, 1, 0, 0, 0, 7, 104, 1, 0, 0, 0, 9, 106, 1, 0,
0, 0, 11, 109, 1, 0, 0, 0, 13, 112, 1, 0, 0, 0, 15, 114, 1, 0, 0, 0, 17,
117, 1, 0, 0, 0, 19, 120, 1, 0, 0, 0, 21, 122, 1, 0, 0, 0, 23, 124, 1,
0, 0, 0, 25, 126, 1, 0, 0, 0, 27, 128, 1, 0, 0, 0, 29, 130, 1, 0, 0, 0,
31, 132, 1, 0, 0, 0, 33, 134, 1, 0, 0, 0, 35, 136, 1, 0, 0, 0, 37, 138,
1, 0, 0, 0, 39, 140, 1, 0, 0, 0, 41, 142, 1, 0, 0, 0, 43, 144, 1, 0, 0,
0, 45, 146, 1, 0, 0, 0, 47, 148, 1, 0, 0, 0, 49, 150, 1, 0, 0, 0, 51, 152,
1, 0, 0, 0, 53, 157, 1, 0, 0, 0, 55, 163, 1, 0, 0, 0, 57, 168, 1, 0, 0,
0, 59, 170, 1, 0, 0, 0, 61, 172, 1, 0, 0, 0, 63, 174, 1, 0, 0, 0, 65, 183,
1, 0, 0, 0, 67, 185, 1, 0, 0, 0, 69, 191, 1, 0, 0, 0, 71, 193, 1, 0, 0,
0, 73, 196, 1, 0, 0, 0, 75, 201, 1, 0, 0, 0, 77, 224, 1, 0, 0, 0, 79, 227,
1, 0, 0, 0, 81, 233, 1, 0, 0, 0, 83, 274, 1, 0, 0, 0, 85, 289, 1, 0, 0,
0, 87, 308, 1, 0, 0, 0, 89, 406, 1, 0, 0, 0, 91, 408, 1, 0, 0, 0, 93, 413,
1, 0, 0, 0, 95, 96, 5, 61, 0, 0, 96, 97, 5, 61, 0, 0, 97, 2, 1, 0, 0, 0,
98, 99, 5, 33, 0, 0, 99, 100, 5, 61, 0, 0, 100, 4, 1, 0, 0, 0, 101, 102,
5, 105, 0, 0, 102, 103, 5, 110, 0, 0, 103, 6, 1, 0, 0, 0, 104, 105, 5,
60, 0, 0, 105, 8, 1, 0, 0, 0, 106, 107, 5, 60, 0, 0, 107, 108, 5, 61, 0,
0, 108, 10, 1, 0, 0, 0, 109, 110, 5, 62, 0, 0, 110, 111, 5, 61, 0, 0, 111,
12, 1, 0, 0, 0, 112, 113, 5, 62, 0, 0, 113, 14, 1, 0, 0, 0, 114, 115, 5,
38, 0, 0, 115, 116, 5, 38, 0, 0, 116, 16, 1, 0, 0, 0, 117, 118, 5, 124,
0, 0, 118, 119, 5, 124, 0, 0, 119, 18, 1, 0, 0, 0, 120, 121, 5, 91, 0,
0, 121, 20, 1, 0, 0, 0, 122, 123, 5, 93, 0, 0, 123, 22, 1, 0, 0, 0, 124,
125, 5, 123, 0, 0, 125, 24, 1, 0, 0, 0, 126, 127, 5, 125, 0, 0, 127, 26,
1, 0, 0, 0, 128, 129, 5, 40, 0, 0, 129, 28, 1, 0, 0, 0, 130, 131, 5, 41,
0, 0, 131, 30, 1, 0, 0, 0, 132, 133, 5, 46, 0, 0, 133, 32, 1, 0, 0, 0,
134, 135, 5, 44, 0, 0, 135, 34, 1, 0, 0, 0, 136, 137, 5, 45, 0, 0, 137,
36, 1, 0, 0, 0, 138, 139, 5, 33, 0, 0, 139, 38, 1, 0, 0, 0, 140, 141, 5,
63, 0, 0, 141, 40, 1, 0, 0, 0, 142, 143, 5, 58, 0, 0, 143, 42, 1, 0, 0,
0, 144, 145, 5, 43, 0, 0, 145, 44, 1, 0, 0, 0, 146, 147, 5, 42, 0, 0, 147,
46, 1, 0, 0, 0, 148, 149, 5, 47, 0, 0, 149, 48, 1, 0, 0, 0, 150, 151, 5,
37, 0, 0, 151, 50, 1, 0, 0, 0, 152, 153, 5, 116, 0, 0, 153, 154, 5, 114,
0, 0, 154, 155, 5, 117, 0, 0, 155, 156, 5, 101, 0, 0, 156, 52, 1, 0, 0,
0, 157, 158, 5, 102, 0, 0, 158, 159, 5, 97, 0, 0, 159, 160, 5, 108, 0,
0, 160, 161, 5, 115, 0, 0, 161, 162, 5, 101, 0, 0, 162, 54, 1, 0, 0, 0,
163, 164, 5, 110, 0, 0, 164, 165, 5, 117, 0, 0, 165, 166, 5, 108, 0, 0,
166, 167, 5, 108, 0, 0, 167, 56, 1, 0, 0, 0, 168, 169, 5, 92, 0, 0, 169,
58, 1, 0, 0, 0, 170, 171, 7, 0, 0, 0, 171, 60, 1, 0, 0, 0, 172, 173, 2,
48, 57, 0, 173, 62, 1, 0, 0, 0, 174, 176, 7, 1, 0, 0, 175, 177, 7, 2, 0,
0, 176, 175, 1, 0, 0, 0, 176, 177, 1, 0, 0, 0, 177, 179, 1, 0, 0, 0, 178,
180, 3, 61, 30, 0, 179, 178, 1, 0, 0, 0, 180, 181, 1, 0, 0, 0, 181, 179,
1, 0, 0, 0, 181, 182, 1, 0, 0, 0, 182, 64, 1, 0, 0, 0, 183, 184, 7, 3,
0, 0, 184, 66, 1, 0, 0, 0, 185, 186, 7, 4, 0, 0, 186, 68, 1, 0, 0, 0, 187,
192, 3, 71, 35, 0, 188, 192, 3, 75, 37, 0, 189, 192, 3, 77, 38, 0, 190,
192, 3, 73, 36, 0, 191, 187, 1, 0, 0, 0, 191, 188, 1, 0, 0, 0, 191, 189,
1, 0, 0, 0, 191, 190, 1, 0, 0, 0, 192, 70, 1, 0, 0, 0, 193, 194, 3, 57,
28, 0, 194, 195, 7, 5, 0, 0, 195, 72, 1, 0, 0, 0, 196, 197, 3, 57, 28,
0, 197, 198, 2, 48, 51, 0, 198, 199, 2, 48, 55, 0, 199, 200, 2, 48, 55,
0, 200, 74, 1, 0, 0, 0, 201, 202, 3, 57, 28, 0, 202, 203, 7, 6, 0, 0, 203,
204, 3, 65, 32, 0, 204, 205, 3, 65, 32, 0, 205, 76, 1, 0, 0, 0, 206, 207,
3, 57, 28, 0, 207, 208, 5, 117, 0, 0, 208, 209, 3, 65, 32, 0, 209, 210,
3, 65, 32, 0, 210, 211, 3, 65, 32, 0, 211, 212, 3, 65, 32, 0, 212, 225,
1, 0, 0, 0, 213, 214, 3, 57, 28, 0, 214, 215, 5, 85, 0, 0, 215, 216, 3,
65, 32, 0, 216, 217, 3, 65, 32, 0, 217, 218, 3, 65, 32, 0, 218, 219, 3,
65, 32, 0, 219, 220, 3, 65, 32, 0, 220, 221, 3, 65, 32, 0, 221, 222, 3,
65, 32, 0, 222, 223, 3, 65, 32, 0, 223, 225, 1, 0, 0, 0, 224, 206, 1, 0,
0, 0, 224, 213, 1, 0, 0, 0, 225, 78, 1, 0, 0, 0, 226, 228, 7, 7, 0, 0,
227, 226, 1, 0, 0, 0, 228, 229, 1, 0, 0, 0, 229, 227, 1, 0, 0, 0, 229,
230, 1, 0, 0, 0, 230, 231, 1, 0, 0, 0, 231, 232, 6, 39, 0, 0, 232, 80,
1, 0, 0, 0, 233, 234, 5, 47, 0, 0, 234, 235, 5, 47, 0, 0, 235, 239, 1,
0, 0, 0, 236, 238, 8, 8, 0, 0, 237, 236, 1, 0, 0, 0, 238, 241, 1, 0, 0,
0, 239, 237, 1, 0, 0, 0, 239, 240, 1, 0, 0, 0, 240, 242, 1, 0, 0, 0, 241,
239, 1, 0, 0, 0, 242, 243, 6, 40, 0, 0, 243, 82, 1, 0, 0, 0, 244, 246,
3, 61, 30, 0, 245, 244, 1, 0, 0, 0, 246, 247, 1, 0, 0, 0, 247, 245, 1,
0, 0, 0, 247, 248, 1, 0, 0, 0, 248, 249, 1, 0, 0, 0, 249, 251, 5, 46, 0,
0, 250, 252, 3, 61, 30, 0, 251, 250, 1, 0, 0, 0, 252, 253, 1, 0, 0, 0,
253, 251, 1, 0, 0, 0, 253, 254, 1, 0, 0, 0, 254, 256, 1, 0, 0, 0, 255,
257, 3, 63, 31, 0, 256, 255, 1, 0, 0, 0, 256, 257, 1, 0, 0, 0, 257, 275,
1, 0, 0, 0, 258, 260, 3, 61, 30, 0, 259, 258, 1, 0, 0, 0, 260, 261, 1,
0, 0, 0, 261, 259, 1, 0, 0, 0, 261, 262, 1, 0, 0, 0, 262, 263, 1, 0, 0,
0, 263, 264, 3, 63, 31, 0, 264, 275, 1, 0, 0, 0, 265, 267, 5, 46, 0, 0,
266, 268, 3, 61, 30, 0, 267, 266, 1, 0, 0, 0, 268, 269, 1, 0, 0, 0, 269,
267, 1, 0, 0, 0, 269, 270, 1, 0, 0, 0, 270, 272, 1, 0, 0, 0, 271, 273,
3, 63, 31, 0, 272, 271, 1, 0, 0, 0, 272, 273, 1, 0, 0, 0, 273, 275, 1,
0, 0, 0, 274, 245, 1, 0, 0, 0, 274, 259, 1, 0, 0, 0, 274, 265, 1, 0, 0,
0, 275, 84, 1, 0, 0, 0, 276, 278, 3, 61, 30, 0, 277, 276, 1, 0, 0, 0, 278,
279, 1, 0, 0, 0, 279, 277, 1, 0, 0, 0, 279, 280, 1, 0, 0, 0, 280, 290,
1, 0, 0, 0, 281, 282, 5, 48, 0, 0, 282, 283, 5, 120, 0, 0, 283, 285, 1,
0, 0, 0, 284, 286, 3, 65, 32, 0, 285, 284, 1, 0, 0, 0, 286, 287, 1, 0,
0, 0, 287, 285, 1, 0, 0, 0, 287, 288, 1, 0, 0, 0, 288, 290, 1, 0, 0, 0,
289, 277, 1, 0, 0, 0, 289, 281, 1, 0, 0, 0, 290, 86, 1, 0, 0, 0, 291, 293,
3, 61, 30, 0, 292, 291, 1, 0, 0, 0, 293, 294, 1, 0, 0, 0, 294, 292, 1,
0, 0, 0, 294, 295, 1, 0, 0, 0, 295, 296, 1, 0, 0, 0, 296, 297, 7, 9, 0,
0, 297, 309, 1, 0, 0, 0, 298, 299, 5, 48, 0, 0, 299, 300, 5, 120, 0, 0,
300, 302, 1, 0, 0, 0, 301, 303, 3, 65, 32, 0, 302, 301, 1, 0, 0, 0, 303,
304, 1, 0, 0, 0, 304, 302, 1, 0, 0, 0, 304, 305, 1, 0, 0, 0, 305, 306,
1, 0, 0, 0, 306, 307, 7, 9, 0, 0, 307, 309, 1, 0, 0, 0, 308, 292, 1, 0,
0, 0, 308, 298, 1, 0, 0, 0, 309, 88, 1, 0, 0, 0, 310, 315, 5, 34, 0, 0,
311, 314, 3, 69, 34, 0, 312, 314, 8, 10, 0, 0, 313, 311, 1, 0, 0, 0, 313,
312, 1, 0, 0, 0, 314, 317, 1, 0, 0, 0, 315, 313, 1, 0, 0, 0, 315, 316,
1, 0, 0, 0, 316, 318, 1, 0, 0, 0, 317, 315, 1, 0, 0, 0, 318, 407, 5, 34,
0, 0, 319, 324, 5, 39, 0, 0, 320, 323, 3, 69, 34, 0, 321, 323, 8, 11, 0,
0, 322, 320, 1, 0, 0, 0, 322, 321, 1, 0, 0, 0, 323, 326, 1, 0, 0, 0, 324,
322, 1, 0, 0, 0, 324, 325, 1, 0, 0, 0, 325, 327, 1, 0, 0, 0, 326, 324,
1, 0, 0, 0, 327, 407, 5, 39, 0, 0, 328, 329, 5, 34, 0, 0, 329, 330, 5,
34, 0, 0, 330, 331, 5, 34, 0, 0, 331, 336, 1, 0, 0, 0, 332, 335, 3, 69,
34, 0, 333, 335, 8, 12, 0, 0, 334, 332, 1, 0, 0, 0, 334, 333, 1, 0, 0,
0, 335, 338, 1, 0, 0, 0, 336, 337, 1, 0, 0, 0, 336, 334, 1, 0, 0, 0, 337,
339, 1, 0, 0, 0, 338, 336, 1, 0, 0, 0, 339, 340, 5, 34, 0, 0, 340, 341,
5, 34, 0, 0, 341, 407, 5, 34, 0, 0, 342, 343, 5, 39, 0, 0, 343, 344, 5,
39, 0, 0, 344, 345, 5, 39, 0, 0, 345, 350, 1, 0, 0, 0, 346, 349, 3, 69,
34, 0, 347, 349, 8, 12, 0, 0, 348, 346, 1, 0, 0, 0, 348, 347, 1, 0, 0,
0, 349, 352, 1, 0, 0, 0, 350, 351, 1, 0, 0, 0, 350, 348, 1, 0, 0, 0, 351,
353, 1, 0, 0, 0, 352, 350, 1, 0, 0, 0, 353, 354, 5, 39, 0, 0, 354, 355,
5, 39, 0, 0, 355, 407, 5, 39, 0, 0, 356, 357, 3, 67, 33, 0, 357, 361, 5,
34, 0, 0, 358, 360, 8, 13, 0, 0, 359, 358, 1, 0, 0, 0, 360, 363, 1, 0,
0, 0, 361, 359, 1, 0, 0, 0, 361, 362, 1, 0, 0, 0, 362, 364, 1, 0, 0, 0,
363, 361, 1, 0, 0, 0, 364, 365, 5, 34, 0, 0, 365, 407, 1, 0, 0, 0, 366,
367, 3, 67, 33, 0, 367, 371, 5, 39, 0, 0, 368, 370, 8, 14, 0, 0, 369, 368,
1, 0, 0, 0, 370, 373, 1, 0, 0, 0, 371, 369, 1, 0, 0, 0, 371, 372, 1, 0,
0, 0, 372, 374, 1, 0, 0, 0, 373, 371, 1, 0, 0, 0, 374, 375, 5, 39, 0, 0,
375, 407, 1, 0, 0, 0, 376, 377, 3, 67, 33, 0, 377, 378, 5, 34, 0, 0, 378,
379, 5, 34, 0, 0, 379, 380, 5, 34, 0, 0, 380, 384, 1, 0, 0, 0, 381, 383,
9, 0, 0, 0, 382, 381, 1, 0, 0, 0, 383, 386, 1, 0, 0, 0, 384, 385, 1, 0,
0, 0, 384, 382, 1, 0, 0, 0, 385, 387, 1, 0, 0, 0, 386, 384, 1, 0, 0, 0,
387, 388, 5, 34, 0, 0, 388, 389, 5, 34, 0, 0, 389, 390, 5, 34, 0, 0, 390,
407, 1, 0, 0, 0, 391, 392, 3, 67, 33, 0, 392, 393, 5, 39, 0, 0, 393, 394,
5, 39, 0, 0, 394, 395, 5, 39, 0, 0, 395, 399, 1, 0, 0, 0, 396, 398, 9,
0, 0, 0, 397, 396, 1, 0, 0, 0, 398, 401, 1, 0, 0, 0, 399, 400, 1, 0, 0,
0, 399, 397, 1, 0, 0, 0, 400, 402, 1, 0, 0, 0, 401, 399, 1, 0, 0, 0, 402,
403, 5, 39, 0, 0, 403, 404, 5, 39, 0, 0, 404, 405, 5, 39, 0, 0, 405, 407,
1, 0, 0, 0, 406, 310, 1, 0, 0, 0, 406, 319, 1, 0, 0, 0, 406, 328, 1, 0,
0, 0, 406, 342, 1, 0, 0, 0, 406, 356, 1, 0, 0, 0, 406, 366, 1, 0, 0, 0,
406, 376, 1, 0, 0, 0, 406, 391, 1, 0, 0, 0, 407, 90, 1, 0, 0, 0, 408, 409,
7, 15, 0, 0, 409, 410, 3, 89, 44, 0, 410, 92, 1, 0, 0, 0, 411, 414, 3,
59, 29, 0, 412, 414, 5, 95, 0, 0, 413, 411, 1, 0, 0, 0, 413, 412, 1, 0,
0, 0, 414, 420, 1, 0, 0, 0, 415, 419, 3, 59, 29, 0, 416, 419, 3, 61, 30,
0, 417, 419, 5, 95, 0, 0, 418, 415, 1, 0, 0, 0, 418, 416, 1, 0, 0, 0, 418,
417, 1, 0, 0, 0, 419, 422, 1, 0, 0, 0, 420, 418, 1, 0, 0, 0, 420, 421,
1, 0, 0, 0, 421, 94, 1, 0, 0, 0, 422, 420, 1, 0, 0, 0, 36, 0, 176, 181,
191, 224, 229, 239, 247, 253, 256, 261, 269, 272, 274, 279, 287, 289, 294,
304, 308, 313, 315, 322, 324, 334, 336, 348, 350, 361, 371, 384, 399, 406,
413, 418, 420, 1, 0, 1, 0,
}
deserializer := antlr.NewATNDeserializer(nil)
staticData.atn = deserializer.Deserialize(staticData.serializedATN)
atn := staticData.atn
staticData.decisionToDFA = make([]*antlr.DFA, len(atn.DecisionToState))
decisionToDFA := staticData.decisionToDFA
for index, state := range atn.DecisionToState {
decisionToDFA[index] = antlr.NewDFA(state, index)
}
staticData := &CELLexerLexerStaticData
staticData.ChannelNames = []string{
"DEFAULT_TOKEN_CHANNEL", "HIDDEN",
}
staticData.ModeNames = []string{
"DEFAULT_MODE",
}
staticData.LiteralNames = []string{
"", "'=='", "'!='", "'in'", "'<'", "'<='", "'>='", "'>'", "'&&'", "'||'",
"'['", "']'", "'{'", "'}'", "'('", "')'", "'.'", "','", "'-'", "'!'",
"'?'", "':'", "'+'", "'*'", "'/'", "'%'", "'true'", "'false'", "'null'",
}
staticData.SymbolicNames = []string{
"", "EQUALS", "NOT_EQUALS", "IN", "LESS", "LESS_EQUALS", "GREATER_EQUALS",
"GREATER", "LOGICAL_AND", "LOGICAL_OR", "LBRACKET", "RPRACKET", "LBRACE",
"RBRACE", "LPAREN", "RPAREN", "DOT", "COMMA", "MINUS", "EXCLAM", "QUESTIONMARK",
"COLON", "PLUS", "STAR", "SLASH", "PERCENT", "CEL_TRUE", "CEL_FALSE",
"NUL", "WHITESPACE", "COMMENT", "NUM_FLOAT", "NUM_INT", "NUM_UINT",
"STRING", "BYTES", "IDENTIFIER",
}
staticData.RuleNames = []string{
"EQUALS", "NOT_EQUALS", "IN", "LESS", "LESS_EQUALS", "GREATER_EQUALS",
"GREATER", "LOGICAL_AND", "LOGICAL_OR", "LBRACKET", "RPRACKET", "LBRACE",
"RBRACE", "LPAREN", "RPAREN", "DOT", "COMMA", "MINUS", "EXCLAM", "QUESTIONMARK",
"COLON", "PLUS", "STAR", "SLASH", "PERCENT", "CEL_TRUE", "CEL_FALSE",
"NUL", "BACKSLASH", "LETTER", "DIGIT", "EXPONENT", "HEXDIGIT", "RAW",
"ESC_SEQ", "ESC_CHAR_SEQ", "ESC_OCT_SEQ", "ESC_BYTE_SEQ", "ESC_UNI_SEQ",
"WHITESPACE", "COMMENT", "NUM_FLOAT", "NUM_INT", "NUM_UINT", "STRING",
"BYTES", "IDENTIFIER",
}
staticData.PredictionContextCache = antlr.NewPredictionContextCache()
staticData.serializedATN = []int32{
4, 0, 36, 423, 6, -1, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2,
4, 7, 4, 2, 5, 7, 5, 2, 6, 7, 6, 2, 7, 7, 7, 2, 8, 7, 8, 2, 9, 7, 9, 2,
10, 7, 10, 2, 11, 7, 11, 2, 12, 7, 12, 2, 13, 7, 13, 2, 14, 7, 14, 2, 15,
7, 15, 2, 16, 7, 16, 2, 17, 7, 17, 2, 18, 7, 18, 2, 19, 7, 19, 2, 20, 7,
20, 2, 21, 7, 21, 2, 22, 7, 22, 2, 23, 7, 23, 2, 24, 7, 24, 2, 25, 7, 25,
2, 26, 7, 26, 2, 27, 7, 27, 2, 28, 7, 28, 2, 29, 7, 29, 2, 30, 7, 30, 2,
31, 7, 31, 2, 32, 7, 32, 2, 33, 7, 33, 2, 34, 7, 34, 2, 35, 7, 35, 2, 36,
7, 36, 2, 37, 7, 37, 2, 38, 7, 38, 2, 39, 7, 39, 2, 40, 7, 40, 2, 41, 7,
41, 2, 42, 7, 42, 2, 43, 7, 43, 2, 44, 7, 44, 2, 45, 7, 45, 2, 46, 7, 46,
1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 2, 1, 3, 1, 3, 1, 4,
1, 4, 1, 4, 1, 5, 1, 5, 1, 5, 1, 6, 1, 6, 1, 7, 1, 7, 1, 7, 1, 8, 1, 8,
1, 8, 1, 9, 1, 9, 1, 10, 1, 10, 1, 11, 1, 11, 1, 12, 1, 12, 1, 13, 1, 13,
1, 14, 1, 14, 1, 15, 1, 15, 1, 16, 1, 16, 1, 17, 1, 17, 1, 18, 1, 18, 1,
19, 1, 19, 1, 20, 1, 20, 1, 21, 1, 21, 1, 22, 1, 22, 1, 23, 1, 23, 1, 24,
1, 24, 1, 25, 1, 25, 1, 25, 1, 25, 1, 25, 1, 26, 1, 26, 1, 26, 1, 26, 1,
26, 1, 26, 1, 27, 1, 27, 1, 27, 1, 27, 1, 27, 1, 28, 1, 28, 1, 29, 1, 29,
1, 30, 1, 30, 1, 31, 1, 31, 3, 31, 177, 8, 31, 1, 31, 4, 31, 180, 8, 31,
11, 31, 12, 31, 181, 1, 32, 1, 32, 1, 33, 1, 33, 1, 34, 1, 34, 1, 34, 1,
34, 3, 34, 192, 8, 34, 1, 35, 1, 35, 1, 35, 1, 36, 1, 36, 1, 36, 1, 36,
1, 36, 1, 37, 1, 37, 1, 37, 1, 37, 1, 37, 1, 38, 1, 38, 1, 38, 1, 38, 1,
38, 1, 38, 1, 38, 1, 38, 1, 38, 1, 38, 1, 38, 1, 38, 1, 38, 1, 38, 1, 38,
1, 38, 1, 38, 1, 38, 3, 38, 225, 8, 38, 1, 39, 4, 39, 228, 8, 39, 11, 39,
12, 39, 229, 1, 39, 1, 39, 1, 40, 1, 40, 1, 40, 1, 40, 5, 40, 238, 8, 40,
10, 40, 12, 40, 241, 9, 40, 1, 40, 1, 40, 1, 41, 4, 41, 246, 8, 41, 11,
41, 12, 41, 247, 1, 41, 1, 41, 4, 41, 252, 8, 41, 11, 41, 12, 41, 253,
1, 41, 3, 41, 257, 8, 41, 1, 41, 4, 41, 260, 8, 41, 11, 41, 12, 41, 261,
1, 41, 1, 41, 1, 41, 1, 41, 4, 41, 268, 8, 41, 11, 41, 12, 41, 269, 1,
41, 3, 41, 273, 8, 41, 3, 41, 275, 8, 41, 1, 42, 4, 42, 278, 8, 42, 11,
42, 12, 42, 279, 1, 42, 1, 42, 1, 42, 1, 42, 4, 42, 286, 8, 42, 11, 42,
12, 42, 287, 3, 42, 290, 8, 42, 1, 43, 4, 43, 293, 8, 43, 11, 43, 12, 43,
294, 1, 43, 1, 43, 1, 43, 1, 43, 1, 43, 1, 43, 4, 43, 303, 8, 43, 11, 43,
12, 43, 304, 1, 43, 1, 43, 3, 43, 309, 8, 43, 1, 44, 1, 44, 1, 44, 5, 44,
314, 8, 44, 10, 44, 12, 44, 317, 9, 44, 1, 44, 1, 44, 1, 44, 1, 44, 5,
44, 323, 8, 44, 10, 44, 12, 44, 326, 9, 44, 1, 44, 1, 44, 1, 44, 1, 44,
1, 44, 1, 44, 1, 44, 5, 44, 335, 8, 44, 10, 44, 12, 44, 338, 9, 44, 1,
44, 1, 44, 1, 44, 1, 44, 1, 44, 1, 44, 1, 44, 1, 44, 1, 44, 5, 44, 349,
8, 44, 10, 44, 12, 44, 352, 9, 44, 1, 44, 1, 44, 1, 44, 1, 44, 1, 44, 1,
44, 5, 44, 360, 8, 44, 10, 44, 12, 44, 363, 9, 44, 1, 44, 1, 44, 1, 44,
1, 44, 1, 44, 5, 44, 370, 8, 44, 10, 44, 12, 44, 373, 9, 44, 1, 44, 1,
44, 1, 44, 1, 44, 1, 44, 1, 44, 1, 44, 1, 44, 5, 44, 383, 8, 44, 10, 44,
12, 44, 386, 9, 44, 1, 44, 1, 44, 1, 44, 1, 44, 1, 44, 1, 44, 1, 44, 1,
44, 1, 44, 1, 44, 5, 44, 398, 8, 44, 10, 44, 12, 44, 401, 9, 44, 1, 44,
1, 44, 1, 44, 1, 44, 3, 44, 407, 8, 44, 1, 45, 1, 45, 1, 45, 1, 46, 1,
46, 3, 46, 414, 8, 46, 1, 46, 1, 46, 1, 46, 5, 46, 419, 8, 46, 10, 46,
12, 46, 422, 9, 46, 4, 336, 350, 384, 399, 0, 47, 1, 1, 3, 2, 5, 3, 7,
4, 9, 5, 11, 6, 13, 7, 15, 8, 17, 9, 19, 10, 21, 11, 23, 12, 25, 13, 27,
14, 29, 15, 31, 16, 33, 17, 35, 18, 37, 19, 39, 20, 41, 21, 43, 22, 45,
23, 47, 24, 49, 25, 51, 26, 53, 27, 55, 28, 57, 0, 59, 0, 61, 0, 63, 0,
65, 0, 67, 0, 69, 0, 71, 0, 73, 0, 75, 0, 77, 0, 79, 29, 81, 30, 83, 31,
85, 32, 87, 33, 89, 34, 91, 35, 93, 36, 1, 0, 16, 2, 0, 65, 90, 97, 122,
2, 0, 69, 69, 101, 101, 2, 0, 43, 43, 45, 45, 3, 0, 48, 57, 65, 70, 97,
102, 2, 0, 82, 82, 114, 114, 10, 0, 34, 34, 39, 39, 63, 63, 92, 92, 96,
98, 102, 102, 110, 110, 114, 114, 116, 116, 118, 118, 2, 0, 88, 88, 120,
120, 3, 0, 9, 10, 12, 13, 32, 32, 1, 0, 10, 10, 2, 0, 85, 85, 117, 117,
4, 0, 10, 10, 13, 13, 34, 34, 92, 92, 4, 0, 10, 10, 13, 13, 39, 39, 92,
92, 1, 0, 92, 92, 3, 0, 10, 10, 13, 13, 34, 34, 3, 0, 10, 10, 13, 13, 39,
39, 2, 0, 66, 66, 98, 98, 456, 0, 1, 1, 0, 0, 0, 0, 3, 1, 0, 0, 0, 0, 5,
1, 0, 0, 0, 0, 7, 1, 0, 0, 0, 0, 9, 1, 0, 0, 0, 0, 11, 1, 0, 0, 0, 0, 13,
1, 0, 0, 0, 0, 15, 1, 0, 0, 0, 0, 17, 1, 0, 0, 0, 0, 19, 1, 0, 0, 0, 0,
21, 1, 0, 0, 0, 0, 23, 1, 0, 0, 0, 0, 25, 1, 0, 0, 0, 0, 27, 1, 0, 0, 0,
0, 29, 1, 0, 0, 0, 0, 31, 1, 0, 0, 0, 0, 33, 1, 0, 0, 0, 0, 35, 1, 0, 0,
0, 0, 37, 1, 0, 0, 0, 0, 39, 1, 0, 0, 0, 0, 41, 1, 0, 0, 0, 0, 43, 1, 0,
0, 0, 0, 45, 1, 0, 0, 0, 0, 47, 1, 0, 0, 0, 0, 49, 1, 0, 0, 0, 0, 51, 1,
0, 0, 0, 0, 53, 1, 0, 0, 0, 0, 55, 1, 0, 0, 0, 0, 79, 1, 0, 0, 0, 0, 81,
1, 0, 0, 0, 0, 83, 1, 0, 0, 0, 0, 85, 1, 0, 0, 0, 0, 87, 1, 0, 0, 0, 0,
89, 1, 0, 0, 0, 0, 91, 1, 0, 0, 0, 0, 93, 1, 0, 0, 0, 1, 95, 1, 0, 0, 0,
3, 98, 1, 0, 0, 0, 5, 101, 1, 0, 0, 0, 7, 104, 1, 0, 0, 0, 9, 106, 1, 0,
0, 0, 11, 109, 1, 0, 0, 0, 13, 112, 1, 0, 0, 0, 15, 114, 1, 0, 0, 0, 17,
117, 1, 0, 0, 0, 19, 120, 1, 0, 0, 0, 21, 122, 1, 0, 0, 0, 23, 124, 1,
0, 0, 0, 25, 126, 1, 0, 0, 0, 27, 128, 1, 0, 0, 0, 29, 130, 1, 0, 0, 0,
31, 132, 1, 0, 0, 0, 33, 134, 1, 0, 0, 0, 35, 136, 1, 0, 0, 0, 37, 138,
1, 0, 0, 0, 39, 140, 1, 0, 0, 0, 41, 142, 1, 0, 0, 0, 43, 144, 1, 0, 0,
0, 45, 146, 1, 0, 0, 0, 47, 148, 1, 0, 0, 0, 49, 150, 1, 0, 0, 0, 51, 152,
1, 0, 0, 0, 53, 157, 1, 0, 0, 0, 55, 163, 1, 0, 0, 0, 57, 168, 1, 0, 0,
0, 59, 170, 1, 0, 0, 0, 61, 172, 1, 0, 0, 0, 63, 174, 1, 0, 0, 0, 65, 183,
1, 0, 0, 0, 67, 185, 1, 0, 0, 0, 69, 191, 1, 0, 0, 0, 71, 193, 1, 0, 0,
0, 73, 196, 1, 0, 0, 0, 75, 201, 1, 0, 0, 0, 77, 224, 1, 0, 0, 0, 79, 227,
1, 0, 0, 0, 81, 233, 1, 0, 0, 0, 83, 274, 1, 0, 0, 0, 85, 289, 1, 0, 0,
0, 87, 308, 1, 0, 0, 0, 89, 406, 1, 0, 0, 0, 91, 408, 1, 0, 0, 0, 93, 413,
1, 0, 0, 0, 95, 96, 5, 61, 0, 0, 96, 97, 5, 61, 0, 0, 97, 2, 1, 0, 0, 0,
98, 99, 5, 33, 0, 0, 99, 100, 5, 61, 0, 0, 100, 4, 1, 0, 0, 0, 101, 102,
5, 105, 0, 0, 102, 103, 5, 110, 0, 0, 103, 6, 1, 0, 0, 0, 104, 105, 5,
60, 0, 0, 105, 8, 1, 0, 0, 0, 106, 107, 5, 60, 0, 0, 107, 108, 5, 61, 0,
0, 108, 10, 1, 0, 0, 0, 109, 110, 5, 62, 0, 0, 110, 111, 5, 61, 0, 0, 111,
12, 1, 0, 0, 0, 112, 113, 5, 62, 0, 0, 113, 14, 1, 0, 0, 0, 114, 115, 5,
38, 0, 0, 115, 116, 5, 38, 0, 0, 116, 16, 1, 0, 0, 0, 117, 118, 5, 124,
0, 0, 118, 119, 5, 124, 0, 0, 119, 18, 1, 0, 0, 0, 120, 121, 5, 91, 0,
0, 121, 20, 1, 0, 0, 0, 122, 123, 5, 93, 0, 0, 123, 22, 1, 0, 0, 0, 124,
125, 5, 123, 0, 0, 125, 24, 1, 0, 0, 0, 126, 127, 5, 125, 0, 0, 127, 26,
1, 0, 0, 0, 128, 129, 5, 40, 0, 0, 129, 28, 1, 0, 0, 0, 130, 131, 5, 41,
0, 0, 131, 30, 1, 0, 0, 0, 132, 133, 5, 46, 0, 0, 133, 32, 1, 0, 0, 0,
134, 135, 5, 44, 0, 0, 135, 34, 1, 0, 0, 0, 136, 137, 5, 45, 0, 0, 137,
36, 1, 0, 0, 0, 138, 139, 5, 33, 0, 0, 139, 38, 1, 0, 0, 0, 140, 141, 5,
63, 0, 0, 141, 40, 1, 0, 0, 0, 142, 143, 5, 58, 0, 0, 143, 42, 1, 0, 0,
0, 144, 145, 5, 43, 0, 0, 145, 44, 1, 0, 0, 0, 146, 147, 5, 42, 0, 0, 147,
46, 1, 0, 0, 0, 148, 149, 5, 47, 0, 0, 149, 48, 1, 0, 0, 0, 150, 151, 5,
37, 0, 0, 151, 50, 1, 0, 0, 0, 152, 153, 5, 116, 0, 0, 153, 154, 5, 114,
0, 0, 154, 155, 5, 117, 0, 0, 155, 156, 5, 101, 0, 0, 156, 52, 1, 0, 0,
0, 157, 158, 5, 102, 0, 0, 158, 159, 5, 97, 0, 0, 159, 160, 5, 108, 0,
0, 160, 161, 5, 115, 0, 0, 161, 162, 5, 101, 0, 0, 162, 54, 1, 0, 0, 0,
163, 164, 5, 110, 0, 0, 164, 165, 5, 117, 0, 0, 165, 166, 5, 108, 0, 0,
166, 167, 5, 108, 0, 0, 167, 56, 1, 0, 0, 0, 168, 169, 5, 92, 0, 0, 169,
58, 1, 0, 0, 0, 170, 171, 7, 0, 0, 0, 171, 60, 1, 0, 0, 0, 172, 173, 2,
48, 57, 0, 173, 62, 1, 0, 0, 0, 174, 176, 7, 1, 0, 0, 175, 177, 7, 2, 0,
0, 176, 175, 1, 0, 0, 0, 176, 177, 1, 0, 0, 0, 177, 179, 1, 0, 0, 0, 178,
180, 3, 61, 30, 0, 179, 178, 1, 0, 0, 0, 180, 181, 1, 0, 0, 0, 181, 179,
1, 0, 0, 0, 181, 182, 1, 0, 0, 0, 182, 64, 1, 0, 0, 0, 183, 184, 7, 3,
0, 0, 184, 66, 1, 0, 0, 0, 185, 186, 7, 4, 0, 0, 186, 68, 1, 0, 0, 0, 187,
192, 3, 71, 35, 0, 188, 192, 3, 75, 37, 0, 189, 192, 3, 77, 38, 0, 190,
192, 3, 73, 36, 0, 191, 187, 1, 0, 0, 0, 191, 188, 1, 0, 0, 0, 191, 189,
1, 0, 0, 0, 191, 190, 1, 0, 0, 0, 192, 70, 1, 0, 0, 0, 193, 194, 3, 57,
28, 0, 194, 195, 7, 5, 0, 0, 195, 72, 1, 0, 0, 0, 196, 197, 3, 57, 28,
0, 197, 198, 2, 48, 51, 0, 198, 199, 2, 48, 55, 0, 199, 200, 2, 48, 55,
0, 200, 74, 1, 0, 0, 0, 201, 202, 3, 57, 28, 0, 202, 203, 7, 6, 0, 0, 203,
204, 3, 65, 32, 0, 204, 205, 3, 65, 32, 0, 205, 76, 1, 0, 0, 0, 206, 207,
3, 57, 28, 0, 207, 208, 5, 117, 0, 0, 208, 209, 3, 65, 32, 0, 209, 210,
3, 65, 32, 0, 210, 211, 3, 65, 32, 0, 211, 212, 3, 65, 32, 0, 212, 225,
1, 0, 0, 0, 213, 214, 3, 57, 28, 0, 214, 215, 5, 85, 0, 0, 215, 216, 3,
65, 32, 0, 216, 217, 3, 65, 32, 0, 217, 218, 3, 65, 32, 0, 218, 219, 3,
65, 32, 0, 219, 220, 3, 65, 32, 0, 220, 221, 3, 65, 32, 0, 221, 222, 3,
65, 32, 0, 222, 223, 3, 65, 32, 0, 223, 225, 1, 0, 0, 0, 224, 206, 1, 0,
0, 0, 224, 213, 1, 0, 0, 0, 225, 78, 1, 0, 0, 0, 226, 228, 7, 7, 0, 0,
227, 226, 1, 0, 0, 0, 228, 229, 1, 0, 0, 0, 229, 227, 1, 0, 0, 0, 229,
230, 1, 0, 0, 0, 230, 231, 1, 0, 0, 0, 231, 232, 6, 39, 0, 0, 232, 80,
1, 0, 0, 0, 233, 234, 5, 47, 0, 0, 234, 235, 5, 47, 0, 0, 235, 239, 1,
0, 0, 0, 236, 238, 8, 8, 0, 0, 237, 236, 1, 0, 0, 0, 238, 241, 1, 0, 0,
0, 239, 237, 1, 0, 0, 0, 239, 240, 1, 0, 0, 0, 240, 242, 1, 0, 0, 0, 241,
239, 1, 0, 0, 0, 242, 243, 6, 40, 0, 0, 243, 82, 1, 0, 0, 0, 244, 246,
3, 61, 30, 0, 245, 244, 1, 0, 0, 0, 246, 247, 1, 0, 0, 0, 247, 245, 1,
0, 0, 0, 247, 248, 1, 0, 0, 0, 248, 249, 1, 0, 0, 0, 249, 251, 5, 46, 0,
0, 250, 252, 3, 61, 30, 0, 251, 250, 1, 0, 0, 0, 252, 253, 1, 0, 0, 0,
253, 251, 1, 0, 0, 0, 253, 254, 1, 0, 0, 0, 254, 256, 1, 0, 0, 0, 255,
257, 3, 63, 31, 0, 256, 255, 1, 0, 0, 0, 256, 257, 1, 0, 0, 0, 257, 275,
1, 0, 0, 0, 258, 260, 3, 61, 30, 0, 259, 258, 1, 0, 0, 0, 260, 261, 1,
0, 0, 0, 261, 259, 1, 0, 0, 0, 261, 262, 1, 0, 0, 0, 262, 263, 1, 0, 0,
0, 263, 264, 3, 63, 31, 0, 264, 275, 1, 0, 0, 0, 265, 267, 5, 46, 0, 0,
266, 268, 3, 61, 30, 0, 267, 266, 1, 0, 0, 0, 268, 269, 1, 0, 0, 0, 269,
267, 1, 0, 0, 0, 269, 270, 1, 0, 0, 0, 270, 272, 1, 0, 0, 0, 271, 273,
3, 63, 31, 0, 272, 271, 1, 0, 0, 0, 272, 273, 1, 0, 0, 0, 273, 275, 1,
0, 0, 0, 274, 245, 1, 0, 0, 0, 274, 259, 1, 0, 0, 0, 274, 265, 1, 0, 0,
0, 275, 84, 1, 0, 0, 0, 276, 278, 3, 61, 30, 0, 277, 276, 1, 0, 0, 0, 278,
279, 1, 0, 0, 0, 279, 277, 1, 0, 0, 0, 279, 280, 1, 0, 0, 0, 280, 290,
1, 0, 0, 0, 281, 282, 5, 48, 0, 0, 282, 283, 5, 120, 0, 0, 283, 285, 1,
0, 0, 0, 284, 286, 3, 65, 32, 0, 285, 284, 1, 0, 0, 0, 286, 287, 1, 0,
0, 0, 287, 285, 1, 0, 0, 0, 287, 288, 1, 0, 0, 0, 288, 290, 1, 0, 0, 0,
289, 277, 1, 0, 0, 0, 289, 281, 1, 0, 0, 0, 290, 86, 1, 0, 0, 0, 291, 293,
3, 61, 30, 0, 292, 291, 1, 0, 0, 0, 293, 294, 1, 0, 0, 0, 294, 292, 1,
0, 0, 0, 294, 295, 1, 0, 0, 0, 295, 296, 1, 0, 0, 0, 296, 297, 7, 9, 0,
0, 297, 309, 1, 0, 0, 0, 298, 299, 5, 48, 0, 0, 299, 300, 5, 120, 0, 0,
300, 302, 1, 0, 0, 0, 301, 303, 3, 65, 32, 0, 302, 301, 1, 0, 0, 0, 303,
304, 1, 0, 0, 0, 304, 302, 1, 0, 0, 0, 304, 305, 1, 0, 0, 0, 305, 306,
1, 0, 0, 0, 306, 307, 7, 9, 0, 0, 307, 309, 1, 0, 0, 0, 308, 292, 1, 0,
0, 0, 308, 298, 1, 0, 0, 0, 309, 88, 1, 0, 0, 0, 310, 315, 5, 34, 0, 0,
311, 314, 3, 69, 34, 0, 312, 314, 8, 10, 0, 0, 313, 311, 1, 0, 0, 0, 313,
312, 1, 0, 0, 0, 314, 317, 1, 0, 0, 0, 315, 313, 1, 0, 0, 0, 315, 316,
1, 0, 0, 0, 316, 318, 1, 0, 0, 0, 317, 315, 1, 0, 0, 0, 318, 407, 5, 34,
0, 0, 319, 324, 5, 39, 0, 0, 320, 323, 3, 69, 34, 0, 321, 323, 8, 11, 0,
0, 322, 320, 1, 0, 0, 0, 322, 321, 1, 0, 0, 0, 323, 326, 1, 0, 0, 0, 324,
322, 1, 0, 0, 0, 324, 325, 1, 0, 0, 0, 325, 327, 1, 0, 0, 0, 326, 324,
1, 0, 0, 0, 327, 407, 5, 39, 0, 0, 328, 329, 5, 34, 0, 0, 329, 330, 5,
34, 0, 0, 330, 331, 5, 34, 0, 0, 331, 336, 1, 0, 0, 0, 332, 335, 3, 69,
34, 0, 333, 335, 8, 12, 0, 0, 334, 332, 1, 0, 0, 0, 334, 333, 1, 0, 0,
0, 335, 338, 1, 0, 0, 0, 336, 337, 1, 0, 0, 0, 336, 334, 1, 0, 0, 0, 337,
339, 1, 0, 0, 0, 338, 336, 1, 0, 0, 0, 339, 340, 5, 34, 0, 0, 340, 341,
5, 34, 0, 0, 341, 407, 5, 34, 0, 0, 342, 343, 5, 39, 0, 0, 343, 344, 5,
39, 0, 0, 344, 345, 5, 39, 0, 0, 345, 350, 1, 0, 0, 0, 346, 349, 3, 69,
34, 0, 347, 349, 8, 12, 0, 0, 348, 346, 1, 0, 0, 0, 348, 347, 1, 0, 0,
0, 349, 352, 1, 0, 0, 0, 350, 351, 1, 0, 0, 0, 350, 348, 1, 0, 0, 0, 351,
353, 1, 0, 0, 0, 352, 350, 1, 0, 0, 0, 353, 354, 5, 39, 0, 0, 354, 355,
5, 39, 0, 0, 355, 407, 5, 39, 0, 0, 356, 357, 3, 67, 33, 0, 357, 361, 5,
34, 0, 0, 358, 360, 8, 13, 0, 0, 359, 358, 1, 0, 0, 0, 360, 363, 1, 0,
0, 0, 361, 359, 1, 0, 0, 0, 361, 362, 1, 0, 0, 0, 362, 364, 1, 0, 0, 0,
363, 361, 1, 0, 0, 0, 364, 365, 5, 34, 0, 0, 365, 407, 1, 0, 0, 0, 366,
367, 3, 67, 33, 0, 367, 371, 5, 39, 0, 0, 368, 370, 8, 14, 0, 0, 369, 368,
1, 0, 0, 0, 370, 373, 1, 0, 0, 0, 371, 369, 1, 0, 0, 0, 371, 372, 1, 0,
0, 0, 372, 374, 1, 0, 0, 0, 373, 371, 1, 0, 0, 0, 374, 375, 5, 39, 0, 0,
375, 407, 1, 0, 0, 0, 376, 377, 3, 67, 33, 0, 377, 378, 5, 34, 0, 0, 378,
379, 5, 34, 0, 0, 379, 380, 5, 34, 0, 0, 380, 384, 1, 0, 0, 0, 381, 383,
9, 0, 0, 0, 382, 381, 1, 0, 0, 0, 383, 386, 1, 0, 0, 0, 384, 385, 1, 0,
0, 0, 384, 382, 1, 0, 0, 0, 385, 387, 1, 0, 0, 0, 386, 384, 1, 0, 0, 0,
387, 388, 5, 34, 0, 0, 388, 389, 5, 34, 0, 0, 389, 390, 5, 34, 0, 0, 390,
407, 1, 0, 0, 0, 391, 392, 3, 67, 33, 0, 392, 393, 5, 39, 0, 0, 393, 394,
5, 39, 0, 0, 394, 395, 5, 39, 0, 0, 395, 399, 1, 0, 0, 0, 396, 398, 9,
0, 0, 0, 397, 396, 1, 0, 0, 0, 398, 401, 1, 0, 0, 0, 399, 400, 1, 0, 0,
0, 399, 397, 1, 0, 0, 0, 400, 402, 1, 0, 0, 0, 401, 399, 1, 0, 0, 0, 402,
403, 5, 39, 0, 0, 403, 404, 5, 39, 0, 0, 404, 405, 5, 39, 0, 0, 405, 407,
1, 0, 0, 0, 406, 310, 1, 0, 0, 0, 406, 319, 1, 0, 0, 0, 406, 328, 1, 0,
0, 0, 406, 342, 1, 0, 0, 0, 406, 356, 1, 0, 0, 0, 406, 366, 1, 0, 0, 0,
406, 376, 1, 0, 0, 0, 406, 391, 1, 0, 0, 0, 407, 90, 1, 0, 0, 0, 408, 409,
7, 15, 0, 0, 409, 410, 3, 89, 44, 0, 410, 92, 1, 0, 0, 0, 411, 414, 3,
59, 29, 0, 412, 414, 5, 95, 0, 0, 413, 411, 1, 0, 0, 0, 413, 412, 1, 0,
0, 0, 414, 420, 1, 0, 0, 0, 415, 419, 3, 59, 29, 0, 416, 419, 3, 61, 30,
0, 417, 419, 5, 95, 0, 0, 418, 415, 1, 0, 0, 0, 418, 416, 1, 0, 0, 0, 418,
417, 1, 0, 0, 0, 419, 422, 1, 0, 0, 0, 420, 418, 1, 0, 0, 0, 420, 421,
1, 0, 0, 0, 421, 94, 1, 0, 0, 0, 422, 420, 1, 0, 0, 0, 36, 0, 176, 181,
191, 224, 229, 239, 247, 253, 256, 261, 269, 272, 274, 279, 287, 289, 294,
304, 308, 313, 315, 322, 324, 334, 336, 348, 350, 361, 371, 384, 399, 406,
413, 418, 420, 1, 0, 1, 0,
}
deserializer := antlr.NewATNDeserializer(nil)
staticData.atn = deserializer.Deserialize(staticData.serializedATN)
atn := staticData.atn
staticData.decisionToDFA = make([]*antlr.DFA, len(atn.DecisionToState))
decisionToDFA := staticData.decisionToDFA
for index, state := range atn.DecisionToState {
decisionToDFA[index] = antlr.NewDFA(state, index)
}
}
// CELLexerInit initializes any static state used to implement CELLexer. By default the
@ -282,22 +280,22 @@ func cellexerLexerInit() {
// NewCELLexer(). You can call this function if you wish to initialize the static state ahead
// of time.
func CELLexerInit() {
staticData := &cellexerLexerStaticData
staticData.once.Do(cellexerLexerInit)
staticData := &CELLexerLexerStaticData
staticData.once.Do(cellexerLexerInit)
}
// NewCELLexer produces a new lexer instance for the optional input antlr.CharStream.
func NewCELLexer(input antlr.CharStream) *CELLexer {
CELLexerInit()
CELLexerInit()
l := new(CELLexer)
l.BaseLexer = antlr.NewBaseLexer(input)
staticData := &cellexerLexerStaticData
l.Interpreter = antlr.NewLexerATNSimulator(l, staticData.atn, staticData.decisionToDFA, staticData.predictionContextCache)
l.channelNames = staticData.channelNames
l.modeNames = staticData.modeNames
l.RuleNames = staticData.ruleNames
l.LiteralNames = staticData.literalNames
l.SymbolicNames = staticData.symbolicNames
staticData := &CELLexerLexerStaticData
l.Interpreter = antlr.NewLexerATNSimulator(l, staticData.atn, staticData.decisionToDFA, staticData.PredictionContextCache)
l.channelNames = staticData.ChannelNames
l.modeNames = staticData.ModeNames
l.RuleNames = staticData.RuleNames
l.LiteralNames = staticData.LiteralNames
l.SymbolicNames = staticData.SymbolicNames
l.GrammarFileName = "CEL.g4"
// TODO: l.EOF = antlr.TokenEOF
@ -306,40 +304,41 @@ func NewCELLexer(input antlr.CharStream) *CELLexer {
// CELLexer tokens.
const (
CELLexerEQUALS = 1
CELLexerNOT_EQUALS = 2
CELLexerIN = 3
CELLexerLESS = 4
CELLexerLESS_EQUALS = 5
CELLexerEQUALS = 1
CELLexerNOT_EQUALS = 2
CELLexerIN = 3
CELLexerLESS = 4
CELLexerLESS_EQUALS = 5
CELLexerGREATER_EQUALS = 6
CELLexerGREATER = 7
CELLexerLOGICAL_AND = 8
CELLexerLOGICAL_OR = 9
CELLexerLBRACKET = 10
CELLexerRPRACKET = 11
CELLexerLBRACE = 12
CELLexerRBRACE = 13
CELLexerLPAREN = 14
CELLexerRPAREN = 15
CELLexerDOT = 16
CELLexerCOMMA = 17
CELLexerMINUS = 18
CELLexerEXCLAM = 19
CELLexerQUESTIONMARK = 20
CELLexerCOLON = 21
CELLexerPLUS = 22
CELLexerSTAR = 23
CELLexerSLASH = 24
CELLexerPERCENT = 25
CELLexerCEL_TRUE = 26
CELLexerCEL_FALSE = 27
CELLexerNUL = 28
CELLexerWHITESPACE = 29
CELLexerCOMMENT = 30
CELLexerNUM_FLOAT = 31
CELLexerNUM_INT = 32
CELLexerNUM_UINT = 33
CELLexerSTRING = 34
CELLexerBYTES = 35
CELLexerIDENTIFIER = 36
CELLexerGREATER = 7
CELLexerLOGICAL_AND = 8
CELLexerLOGICAL_OR = 9
CELLexerLBRACKET = 10
CELLexerRPRACKET = 11
CELLexerLBRACE = 12
CELLexerRBRACE = 13
CELLexerLPAREN = 14
CELLexerRPAREN = 15
CELLexerDOT = 16
CELLexerCOMMA = 17
CELLexerMINUS = 18
CELLexerEXCLAM = 19
CELLexerQUESTIONMARK = 20
CELLexerCOLON = 21
CELLexerPLUS = 22
CELLexerSTAR = 23
CELLexerSLASH = 24
CELLexerPERCENT = 25
CELLexerCEL_TRUE = 26
CELLexerCEL_FALSE = 27
CELLexerNUL = 28
CELLexerWHITESPACE = 29
CELLexerCOMMENT = 30
CELLexerNUM_FLOAT = 31
CELLexerNUM_INT = 32
CELLexerNUM_UINT = 33
CELLexerSTRING = 34
CELLexerBYTES = 35
CELLexerIDENTIFIER = 36
)

View File

@ -1,7 +1,8 @@
// Code generated from /Users/tswadell/go/src/github.com/google/cel-go/parser/gen/CEL.g4 by ANTLR 4.12.0. DO NOT EDIT.
// Code generated from /usr/local/google/home/tswadell/go/src/github.com/google/cel-go/parser/gen/CEL.g4 by ANTLR 4.13.1. DO NOT EDIT.
package gen // CEL
import "github.com/antlr/antlr4/runtime/Go/antlr/v4"
import "github.com/antlr4-go/antlr/v4"
// CELListener is a complete listener for a parse tree produced by CELParser.
type CELListener interface {

File diff suppressed because it is too large Load Diff

View File

@ -1,7 +1,8 @@
// Code generated from /Users/tswadell/go/src/github.com/google/cel-go/parser/gen/CEL.g4 by ANTLR 4.12.0. DO NOT EDIT.
// Code generated from /usr/local/google/home/tswadell/go/src/github.com/google/cel-go/parser/gen/CEL.g4 by ANTLR 4.13.1. DO NOT EDIT.
package gen // CEL
import "github.com/antlr/antlr4/runtime/Go/antlr/v4"
import "github.com/antlr4-go/antlr/v4"
// A complete Visitor for a parse tree produced by CELParser.
type CELVisitor interface {
@ -105,4 +106,5 @@ type CELVisitor interface {
// Visit a parse tree produced by CELParser#Null.
VisitNull(ctx *NullContext) interface{}
}
}

View File

@ -27,7 +27,7 @@
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
# Generate AntLR artifacts.
java -Xmx500M -cp ${DIR}/antlr-4.12.0-complete.jar org.antlr.v4.Tool \
java -Xmx500M -cp ${DIR}/antlr-4.13.1-complete.jar org.antlr.v4.Tool \
-Dlanguage=Go \
-package gen \
-o ${DIR} \

View File

@ -17,284 +17,209 @@ package parser
import (
"sync"
antlr "github.com/antlr/antlr4/runtime/Go/antlr/v4"
antlr "github.com/antlr4-go/antlr/v4"
"github.com/google/cel-go/common"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)
type parserHelper struct {
source common.Source
nextID int64
positions map[int64]int32
macroCalls map[int64]*exprpb.Expr
exprFactory ast.ExprFactory
source common.Source
sourceInfo *ast.SourceInfo
nextID int64
}
func newParserHelper(source common.Source) *parserHelper {
func newParserHelper(source common.Source, fac ast.ExprFactory) *parserHelper {
return &parserHelper{
source: source,
nextID: 1,
positions: make(map[int64]int32),
macroCalls: make(map[int64]*exprpb.Expr),
exprFactory: fac,
source: source,
sourceInfo: ast.NewSourceInfo(source),
nextID: 1,
}
}
func (p *parserHelper) getSourceInfo() *exprpb.SourceInfo {
return &exprpb.SourceInfo{
Location: p.source.Description(),
Positions: p.positions,
LineOffsets: p.source.LineOffsets(),
MacroCalls: p.macroCalls}
func (p *parserHelper) getSourceInfo() *ast.SourceInfo {
return p.sourceInfo
}
func (p *parserHelper) newLiteral(ctx any, value *exprpb.Constant) *exprpb.Expr {
exprNode := p.newExpr(ctx)
exprNode.ExprKind = &exprpb.Expr_ConstExpr{ConstExpr: value}
return exprNode
func (p *parserHelper) newLiteral(ctx any, value ref.Val) ast.Expr {
return p.exprFactory.NewLiteral(p.newID(ctx), value)
}
func (p *parserHelper) newLiteralBool(ctx any, value bool) *exprpb.Expr {
return p.newLiteral(ctx,
&exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: value}})
func (p *parserHelper) newLiteralBool(ctx any, value bool) ast.Expr {
return p.newLiteral(ctx, types.Bool(value))
}
func (p *parserHelper) newLiteralString(ctx any, value string) *exprpb.Expr {
return p.newLiteral(ctx,
&exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: value}})
func (p *parserHelper) newLiteralString(ctx any, value string) ast.Expr {
return p.newLiteral(ctx, types.String(value))
}
func (p *parserHelper) newLiteralBytes(ctx any, value []byte) *exprpb.Expr {
return p.newLiteral(ctx,
&exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: value}})
func (p *parserHelper) newLiteralBytes(ctx any, value []byte) ast.Expr {
return p.newLiteral(ctx, types.Bytes(value))
}
func (p *parserHelper) newLiteralInt(ctx any, value int64) *exprpb.Expr {
return p.newLiteral(ctx,
&exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: value}})
func (p *parserHelper) newLiteralInt(ctx any, value int64) ast.Expr {
return p.newLiteral(ctx, types.Int(value))
}
func (p *parserHelper) newLiteralUint(ctx any, value uint64) *exprpb.Expr {
return p.newLiteral(ctx, &exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: value}})
func (p *parserHelper) newLiteralUint(ctx any, value uint64) ast.Expr {
return p.newLiteral(ctx, types.Uint(value))
}
func (p *parserHelper) newLiteralDouble(ctx any, value float64) *exprpb.Expr {
return p.newLiteral(ctx,
&exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: value}})
func (p *parserHelper) newLiteralDouble(ctx any, value float64) ast.Expr {
return p.newLiteral(ctx, types.Double(value))
}
func (p *parserHelper) newIdent(ctx any, name string) *exprpb.Expr {
exprNode := p.newExpr(ctx)
exprNode.ExprKind = &exprpb.Expr_IdentExpr{IdentExpr: &exprpb.Expr_Ident{Name: name}}
return exprNode
func (p *parserHelper) newIdent(ctx any, name string) ast.Expr {
return p.exprFactory.NewIdent(p.newID(ctx), name)
}
func (p *parserHelper) newSelect(ctx any, operand *exprpb.Expr, field string) *exprpb.Expr {
exprNode := p.newExpr(ctx)
exprNode.ExprKind = &exprpb.Expr_SelectExpr{
SelectExpr: &exprpb.Expr_Select{Operand: operand, Field: field}}
return exprNode
func (p *parserHelper) newSelect(ctx any, operand ast.Expr, field string) ast.Expr {
return p.exprFactory.NewSelect(p.newID(ctx), operand, field)
}
func (p *parserHelper) newPresenceTest(ctx any, operand *exprpb.Expr, field string) *exprpb.Expr {
exprNode := p.newExpr(ctx)
exprNode.ExprKind = &exprpb.Expr_SelectExpr{
SelectExpr: &exprpb.Expr_Select{Operand: operand, Field: field, TestOnly: true}}
return exprNode
func (p *parserHelper) newPresenceTest(ctx any, operand ast.Expr, field string) ast.Expr {
return p.exprFactory.NewPresenceTest(p.newID(ctx), operand, field)
}
func (p *parserHelper) newGlobalCall(ctx any, function string, args ...*exprpb.Expr) *exprpb.Expr {
exprNode := p.newExpr(ctx)
exprNode.ExprKind = &exprpb.Expr_CallExpr{
CallExpr: &exprpb.Expr_Call{Function: function, Args: args}}
return exprNode
func (p *parserHelper) newGlobalCall(ctx any, function string, args ...ast.Expr) ast.Expr {
return p.exprFactory.NewCall(p.newID(ctx), function, args...)
}
func (p *parserHelper) newReceiverCall(ctx any, function string, target *exprpb.Expr, args ...*exprpb.Expr) *exprpb.Expr {
exprNode := p.newExpr(ctx)
exprNode.ExprKind = &exprpb.Expr_CallExpr{
CallExpr: &exprpb.Expr_Call{Function: function, Target: target, Args: args}}
return exprNode
func (p *parserHelper) newReceiverCall(ctx any, function string, target ast.Expr, args ...ast.Expr) ast.Expr {
return p.exprFactory.NewMemberCall(p.newID(ctx), function, target, args...)
}
func (p *parserHelper) newList(ctx any, elements []*exprpb.Expr, optionals ...int32) *exprpb.Expr {
exprNode := p.newExpr(ctx)
exprNode.ExprKind = &exprpb.Expr_ListExpr{
ListExpr: &exprpb.Expr_CreateList{
Elements: elements,
OptionalIndices: optionals,
}}
return exprNode
func (p *parserHelper) newList(ctx any, elements []ast.Expr, optionals ...int32) ast.Expr {
return p.exprFactory.NewList(p.newID(ctx), elements, optionals)
}
func (p *parserHelper) newMap(ctx any, entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr {
exprNode := p.newExpr(ctx)
exprNode.ExprKind = &exprpb.Expr_StructExpr{
StructExpr: &exprpb.Expr_CreateStruct{Entries: entries}}
return exprNode
func (p *parserHelper) newMap(ctx any, entries ...ast.EntryExpr) ast.Expr {
return p.exprFactory.NewMap(p.newID(ctx), entries)
}
func (p *parserHelper) newMapEntry(entryID int64, key *exprpb.Expr, value *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry {
return &exprpb.Expr_CreateStruct_Entry{
Id: entryID,
KeyKind: &exprpb.Expr_CreateStruct_Entry_MapKey{MapKey: key},
Value: value,
OptionalEntry: optional,
}
func (p *parserHelper) newMapEntry(entryID int64, key ast.Expr, value ast.Expr, optional bool) ast.EntryExpr {
return p.exprFactory.NewMapEntry(entryID, key, value, optional)
}
func (p *parserHelper) newObject(ctx any, typeName string, entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr {
exprNode := p.newExpr(ctx)
exprNode.ExprKind = &exprpb.Expr_StructExpr{
StructExpr: &exprpb.Expr_CreateStruct{
MessageName: typeName,
Entries: entries,
},
}
return exprNode
func (p *parserHelper) newObject(ctx any, typeName string, fields ...ast.EntryExpr) ast.Expr {
return p.exprFactory.NewStruct(p.newID(ctx), typeName, fields)
}
func (p *parserHelper) newObjectField(fieldID int64, field string, value *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry {
return &exprpb.Expr_CreateStruct_Entry{
Id: fieldID,
KeyKind: &exprpb.Expr_CreateStruct_Entry_FieldKey{FieldKey: field},
Value: value,
OptionalEntry: optional,
}
func (p *parserHelper) newObjectField(fieldID int64, field string, value ast.Expr, optional bool) ast.EntryExpr {
return p.exprFactory.NewStructField(fieldID, field, value, optional)
}
func (p *parserHelper) newComprehension(ctx any, iterVar string,
iterRange *exprpb.Expr,
func (p *parserHelper) newComprehension(ctx any,
iterRange ast.Expr,
iterVar string,
accuVar string,
accuInit *exprpb.Expr,
condition *exprpb.Expr,
step *exprpb.Expr,
result *exprpb.Expr) *exprpb.Expr {
exprNode := p.newExpr(ctx)
exprNode.ExprKind = &exprpb.Expr_ComprehensionExpr{
ComprehensionExpr: &exprpb.Expr_Comprehension{
AccuVar: accuVar,
AccuInit: accuInit,
IterVar: iterVar,
IterRange: iterRange,
LoopCondition: condition,
LoopStep: step,
Result: result}}
return exprNode
accuInit ast.Expr,
condition ast.Expr,
step ast.Expr,
result ast.Expr) ast.Expr {
return p.exprFactory.NewComprehension(
p.newID(ctx), iterRange, iterVar, accuVar, accuInit, condition, step, result)
}
func (p *parserHelper) newExpr(ctx any) *exprpb.Expr {
id, isID := ctx.(int64)
if isID {
return &exprpb.Expr{Id: id}
func (p *parserHelper) newID(ctx any) int64 {
if id, isID := ctx.(int64); isID {
return id
}
return &exprpb.Expr{Id: p.id(ctx)}
return p.id(ctx)
}
func (p *parserHelper) newExpr(ctx any) ast.Expr {
return p.exprFactory.NewUnspecifiedExpr(p.newID(ctx))
}
func (p *parserHelper) id(ctx any) int64 {
var location common.Location
var offset ast.OffsetRange
switch c := ctx.(type) {
case antlr.ParserRuleContext:
token := c.GetStart()
location = p.source.NewLocation(token.GetLine(), token.GetColumn())
start, stop := c.GetStart(), c.GetStop()
if stop == nil {
stop = start
}
offset.Start = p.sourceInfo.ComputeOffset(int32(start.GetLine()), int32(start.GetColumn()))
offset.Stop = p.sourceInfo.ComputeOffset(int32(stop.GetLine()), int32(stop.GetColumn()))
case antlr.Token:
token := c
location = p.source.NewLocation(token.GetLine(), token.GetColumn())
offset.Start = p.sourceInfo.ComputeOffset(int32(c.GetLine()), int32(c.GetColumn()))
offset.Stop = offset.Start
case common.Location:
location = c
offset.Start = p.sourceInfo.ComputeOffset(int32(c.Line()), int32(c.Column()))
offset.Stop = offset.Start
case ast.OffsetRange:
offset = c
default:
// This should only happen if the ctx is nil
return -1
}
id := p.nextID
p.positions[id], _ = p.source.LocationOffset(location)
p.sourceInfo.SetOffsetRange(id, offset)
p.nextID++
return id
}
func (p *parserHelper) getLocation(id int64) common.Location {
characterOffset := p.positions[id]
location, _ := p.source.OffsetLocation(characterOffset)
return location
return p.sourceInfo.GetStartLocation(id)
}
// buildMacroCallArg iterates the expression and returns a new expression
// where all macros have been replaced by their IDs in MacroCalls
func (p *parserHelper) buildMacroCallArg(expr *exprpb.Expr) *exprpb.Expr {
if _, found := p.macroCalls[expr.GetId()]; found {
return &exprpb.Expr{Id: expr.GetId()}
func (p *parserHelper) buildMacroCallArg(expr ast.Expr) ast.Expr {
if _, found := p.sourceInfo.GetMacroCall(expr.ID()); found {
return p.exprFactory.NewUnspecifiedExpr(expr.ID())
}
switch expr.GetExprKind().(type) {
case *exprpb.Expr_CallExpr:
switch expr.Kind() {
case ast.CallKind:
// Iterate the AST from `expr` recursively looking for macros. Because we are at most
// starting from the top level macro, this recursion is bounded by the size of the AST. This
// means that the depth check on the AST during parsing will catch recursion overflows
// before we get to here.
macroTarget := expr.GetCallExpr().GetTarget()
if macroTarget != nil {
macroTarget = p.buildMacroCallArg(macroTarget)
}
macroArgs := make([]*exprpb.Expr, len(expr.GetCallExpr().GetArgs()))
for index, arg := range expr.GetCallExpr().GetArgs() {
call := expr.AsCall()
macroArgs := make([]ast.Expr, len(call.Args()))
for index, arg := range call.Args() {
macroArgs[index] = p.buildMacroCallArg(arg)
}
return &exprpb.Expr{
Id: expr.GetId(),
ExprKind: &exprpb.Expr_CallExpr{
CallExpr: &exprpb.Expr_Call{
Target: macroTarget,
Function: expr.GetCallExpr().GetFunction(),
Args: macroArgs,
},
},
if !call.IsMemberFunction() {
return p.exprFactory.NewCall(expr.ID(), call.FunctionName(), macroArgs...)
}
case *exprpb.Expr_ListExpr:
listExpr := expr.GetListExpr()
macroListArgs := make([]*exprpb.Expr, len(listExpr.GetElements()))
for i, elem := range listExpr.GetElements() {
macroTarget := p.buildMacroCallArg(call.Target())
return p.exprFactory.NewMemberCall(expr.ID(), call.FunctionName(), macroTarget, macroArgs...)
case ast.ListKind:
list := expr.AsList()
macroListArgs := make([]ast.Expr, list.Size())
for i, elem := range list.Elements() {
macroListArgs[i] = p.buildMacroCallArg(elem)
}
return &exprpb.Expr{
Id: expr.GetId(),
ExprKind: &exprpb.Expr_ListExpr{
ListExpr: &exprpb.Expr_CreateList{
Elements: macroListArgs,
OptionalIndices: listExpr.GetOptionalIndices(),
},
},
}
return p.exprFactory.NewList(expr.ID(), macroListArgs, list.OptionalIndices())
}
return expr
}
// addMacroCall adds the macro the the MacroCalls map in source info. If a macro has args/subargs/target
// that are macros, their ID will be stored instead for later self-lookups.
func (p *parserHelper) addMacroCall(exprID int64, function string, target *exprpb.Expr, args ...*exprpb.Expr) {
macroTarget := target
if target != nil {
if _, found := p.macroCalls[target.GetId()]; found {
macroTarget = &exprpb.Expr{Id: target.GetId()}
} else {
macroTarget = p.buildMacroCallArg(target)
}
}
macroArgs := make([]*exprpb.Expr, len(args))
func (p *parserHelper) addMacroCall(exprID int64, function string, target ast.Expr, args ...ast.Expr) {
macroArgs := make([]ast.Expr, len(args))
for index, arg := range args {
macroArgs[index] = p.buildMacroCallArg(arg)
}
p.macroCalls[exprID] = &exprpb.Expr{
ExprKind: &exprpb.Expr_CallExpr{
CallExpr: &exprpb.Expr_Call{
Target: macroTarget,
Function: function,
Args: macroArgs,
},
},
if target == nil {
p.sourceInfo.SetMacroCall(exprID, p.exprFactory.NewCall(0, function, macroArgs...))
return
}
macroTarget := target
if _, found := p.sourceInfo.GetMacroCall(target.ID()); found {
macroTarget = p.exprFactory.NewUnspecifiedExpr(target.ID())
} else {
macroTarget = p.buildMacroCallArg(target)
}
p.sourceInfo.SetMacroCall(exprID, p.exprFactory.NewMemberCall(0, function, macroTarget, macroArgs...))
}
// logicManager compacts logical trees into a more efficient structure which is semantically
@ -309,71 +234,71 @@ func (p *parserHelper) addMacroCall(exprID int64, function string, target *exprp
// controversial choice as it alters the traditional order of execution assumptions present in most
// expressions.
type logicManager struct {
helper *parserHelper
exprFactory ast.ExprFactory
function string
terms []*exprpb.Expr
terms []ast.Expr
ops []int64
variadicASTs bool
}
// newVariadicLogicManager creates a logic manager instance bound to a specific function and its first term.
func newVariadicLogicManager(h *parserHelper, function string, term *exprpb.Expr) *logicManager {
func newVariadicLogicManager(fac ast.ExprFactory, function string, term ast.Expr) *logicManager {
return &logicManager{
helper: h,
exprFactory: fac,
function: function,
terms: []*exprpb.Expr{term},
terms: []ast.Expr{term},
ops: []int64{},
variadicASTs: true,
}
}
// newBalancingLogicManager creates a logic manager instance bound to a specific function and its first term.
func newBalancingLogicManager(h *parserHelper, function string, term *exprpb.Expr) *logicManager {
func newBalancingLogicManager(fac ast.ExprFactory, function string, term ast.Expr) *logicManager {
return &logicManager{
helper: h,
exprFactory: fac,
function: function,
terms: []*exprpb.Expr{term},
terms: []ast.Expr{term},
ops: []int64{},
variadicASTs: false,
}
}
// addTerm adds an operation identifier and term to the set of terms to be balanced.
func (l *logicManager) addTerm(op int64, term *exprpb.Expr) {
func (l *logicManager) addTerm(op int64, term ast.Expr) {
l.terms = append(l.terms, term)
l.ops = append(l.ops, op)
}
// toExpr renders the logic graph into an Expr value, either balancing a tree of logical
// operations or creating a variadic representation of the logical operator.
func (l *logicManager) toExpr() *exprpb.Expr {
func (l *logicManager) toExpr() ast.Expr {
if len(l.terms) == 1 {
return l.terms[0]
}
if l.variadicASTs {
return l.helper.newGlobalCall(l.ops[0], l.function, l.terms...)
return l.exprFactory.NewCall(l.ops[0], l.function, l.terms...)
}
return l.balancedTree(0, len(l.ops)-1)
}
// balancedTree recursively balances the terms provided to a commutative operator.
func (l *logicManager) balancedTree(lo, hi int) *exprpb.Expr {
func (l *logicManager) balancedTree(lo, hi int) ast.Expr {
mid := (lo + hi + 1) / 2
var left *exprpb.Expr
var left ast.Expr
if mid == lo {
left = l.terms[mid]
} else {
left = l.balancedTree(lo, mid-1)
}
var right *exprpb.Expr
var right ast.Expr
if mid == hi {
right = l.terms[mid+1]
} else {
right = l.balancedTree(mid+1, hi)
}
return l.helper.newGlobalCall(l.ops[mid], l.function, left, right)
return l.exprFactory.NewCall(l.ops[mid], l.function, left, right)
}
type exprHelper struct {
@ -387,202 +312,151 @@ func (e *exprHelper) nextMacroID() int64 {
// Copy implements the ExprHelper interface method by producing a copy of the input Expr value
// with a fresh set of numeric identifiers the Expr and all its descendants.
func (e *exprHelper) Copy(expr *exprpb.Expr) *exprpb.Expr {
copy := e.parserHelper.newExpr(e.parserHelper.getLocation(expr.GetId()))
switch expr.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
copy.ExprKind = &exprpb.Expr_ConstExpr{ConstExpr: expr.GetConstExpr()}
case *exprpb.Expr_IdentExpr:
copy.ExprKind = &exprpb.Expr_IdentExpr{IdentExpr: expr.GetIdentExpr()}
case *exprpb.Expr_SelectExpr:
op := expr.GetSelectExpr().GetOperand()
copy.ExprKind = &exprpb.Expr_SelectExpr{SelectExpr: &exprpb.Expr_Select{
Operand: e.Copy(op),
Field: expr.GetSelectExpr().GetField(),
TestOnly: expr.GetSelectExpr().GetTestOnly(),
}}
case *exprpb.Expr_CallExpr:
call := expr.GetCallExpr()
target := call.GetTarget()
if target != nil {
target = e.Copy(target)
func (e *exprHelper) Copy(expr ast.Expr) ast.Expr {
offsetRange, _ := e.parserHelper.sourceInfo.GetOffsetRange(expr.ID())
copyID := e.parserHelper.newID(offsetRange)
switch expr.Kind() {
case ast.LiteralKind:
return e.exprFactory.NewLiteral(copyID, expr.AsLiteral())
case ast.IdentKind:
return e.exprFactory.NewIdent(copyID, expr.AsIdent())
case ast.SelectKind:
sel := expr.AsSelect()
op := e.Copy(sel.Operand())
if sel.IsTestOnly() {
return e.exprFactory.NewPresenceTest(copyID, op, sel.FieldName())
}
args := call.GetArgs()
argsCopy := make([]*exprpb.Expr, len(args))
return e.exprFactory.NewSelect(copyID, op, sel.FieldName())
case ast.CallKind:
call := expr.AsCall()
args := call.Args()
argsCopy := make([]ast.Expr, len(args))
for i, arg := range args {
argsCopy[i] = e.Copy(arg)
}
copy.ExprKind = &exprpb.Expr_CallExpr{
CallExpr: &exprpb.Expr_Call{
Function: call.GetFunction(),
Target: target,
Args: argsCopy,
},
if !call.IsMemberFunction() {
return e.exprFactory.NewCall(copyID, call.FunctionName(), argsCopy...)
}
case *exprpb.Expr_ListExpr:
elems := expr.GetListExpr().GetElements()
elemsCopy := make([]*exprpb.Expr, len(elems))
return e.exprFactory.NewMemberCall(copyID, call.FunctionName(), e.Copy(call.Target()), argsCopy...)
case ast.ListKind:
list := expr.AsList()
elems := list.Elements()
elemsCopy := make([]ast.Expr, len(elems))
for i, elem := range elems {
elemsCopy[i] = e.Copy(elem)
}
copy.ExprKind = &exprpb.Expr_ListExpr{
ListExpr: &exprpb.Expr_CreateList{Elements: elemsCopy},
return e.exprFactory.NewList(copyID, elemsCopy, list.OptionalIndices())
case ast.MapKind:
m := expr.AsMap()
entries := m.Entries()
entriesCopy := make([]ast.EntryExpr, len(entries))
for i, en := range entries {
entry := en.AsMapEntry()
entryID := e.nextMacroID()
entriesCopy[i] = e.exprFactory.NewMapEntry(entryID,
e.Copy(entry.Key()), e.Copy(entry.Value()), entry.IsOptional())
}
case *exprpb.Expr_StructExpr:
entries := expr.GetStructExpr().GetEntries()
entriesCopy := make([]*exprpb.Expr_CreateStruct_Entry, len(entries))
for i, entry := range entries {
entryCopy := &exprpb.Expr_CreateStruct_Entry{}
entryCopy.Id = e.nextMacroID()
switch entry.GetKeyKind().(type) {
case *exprpb.Expr_CreateStruct_Entry_FieldKey:
entryCopy.KeyKind = &exprpb.Expr_CreateStruct_Entry_FieldKey{
FieldKey: entry.GetFieldKey(),
}
case *exprpb.Expr_CreateStruct_Entry_MapKey:
entryCopy.KeyKind = &exprpb.Expr_CreateStruct_Entry_MapKey{
MapKey: e.Copy(entry.GetMapKey()),
}
}
entryCopy.Value = e.Copy(entry.GetValue())
entriesCopy[i] = entryCopy
}
copy.ExprKind = &exprpb.Expr_StructExpr{
StructExpr: &exprpb.Expr_CreateStruct{
MessageName: expr.GetStructExpr().GetMessageName(),
Entries: entriesCopy,
},
}
case *exprpb.Expr_ComprehensionExpr:
iterRange := e.Copy(expr.GetComprehensionExpr().GetIterRange())
accuInit := e.Copy(expr.GetComprehensionExpr().GetAccuInit())
cond := e.Copy(expr.GetComprehensionExpr().GetLoopCondition())
step := e.Copy(expr.GetComprehensionExpr().GetLoopStep())
result := e.Copy(expr.GetComprehensionExpr().GetResult())
copy.ExprKind = &exprpb.Expr_ComprehensionExpr{
ComprehensionExpr: &exprpb.Expr_Comprehension{
IterRange: iterRange,
IterVar: expr.GetComprehensionExpr().GetIterVar(),
AccuInit: accuInit,
AccuVar: expr.GetComprehensionExpr().GetAccuVar(),
LoopCondition: cond,
LoopStep: step,
Result: result,
},
return e.exprFactory.NewMap(copyID, entriesCopy)
case ast.StructKind:
s := expr.AsStruct()
fields := s.Fields()
fieldsCopy := make([]ast.EntryExpr, len(fields))
for i, f := range fields {
field := f.AsStructField()
fieldID := e.nextMacroID()
fieldsCopy[i] = e.exprFactory.NewStructField(fieldID,
field.Name(), e.Copy(field.Value()), field.IsOptional())
}
return e.exprFactory.NewStruct(copyID, s.TypeName(), fieldsCopy)
case ast.ComprehensionKind:
compre := expr.AsComprehension()
iterRange := e.Copy(compre.IterRange())
accuInit := e.Copy(compre.AccuInit())
cond := e.Copy(compre.LoopCondition())
step := e.Copy(compre.LoopStep())
result := e.Copy(compre.Result())
return e.exprFactory.NewComprehension(copyID,
iterRange, compre.IterVar(), compre.AccuVar(), accuInit, cond, step, result)
}
return copy
return e.exprFactory.NewUnspecifiedExpr(copyID)
}
// LiteralBool implements the ExprHelper interface method.
func (e *exprHelper) LiteralBool(value bool) *exprpb.Expr {
return e.parserHelper.newLiteralBool(e.nextMacroID(), value)
}
// LiteralBytes implements the ExprHelper interface method.
func (e *exprHelper) LiteralBytes(value []byte) *exprpb.Expr {
return e.parserHelper.newLiteralBytes(e.nextMacroID(), value)
}
// LiteralDouble implements the ExprHelper interface method.
func (e *exprHelper) LiteralDouble(value float64) *exprpb.Expr {
return e.parserHelper.newLiteralDouble(e.nextMacroID(), value)
}
// LiteralInt implements the ExprHelper interface method.
func (e *exprHelper) LiteralInt(value int64) *exprpb.Expr {
return e.parserHelper.newLiteralInt(e.nextMacroID(), value)
}
// LiteralString implements the ExprHelper interface method.
func (e *exprHelper) LiteralString(value string) *exprpb.Expr {
return e.parserHelper.newLiteralString(e.nextMacroID(), value)
}
// LiteralUint implements the ExprHelper interface method.
func (e *exprHelper) LiteralUint(value uint64) *exprpb.Expr {
return e.parserHelper.newLiteralUint(e.nextMacroID(), value)
// NewLiteral implements the ExprHelper interface method.
func (e *exprHelper) NewLiteral(value ref.Val) ast.Expr {
return e.exprFactory.NewLiteral(e.nextMacroID(), value)
}
// NewList implements the ExprHelper interface method.
func (e *exprHelper) NewList(elems ...*exprpb.Expr) *exprpb.Expr {
return e.parserHelper.newList(e.nextMacroID(), elems)
func (e *exprHelper) NewList(elems ...ast.Expr) ast.Expr {
return e.exprFactory.NewList(e.nextMacroID(), elems, []int32{})
}
// NewMap implements the ExprHelper interface method.
func (e *exprHelper) NewMap(entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr {
return e.parserHelper.newMap(e.nextMacroID(), entries...)
func (e *exprHelper) NewMap(entries ...ast.EntryExpr) ast.Expr {
return e.exprFactory.NewMap(e.nextMacroID(), entries)
}
// NewMapEntry implements the ExprHelper interface method.
func (e *exprHelper) NewMapEntry(key *exprpb.Expr, val *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry {
return e.parserHelper.newMapEntry(e.nextMacroID(), key, val, optional)
func (e *exprHelper) NewMapEntry(key ast.Expr, val ast.Expr, optional bool) ast.EntryExpr {
return e.exprFactory.NewMapEntry(e.nextMacroID(), key, val, optional)
}
// NewObject implements the ExprHelper interface method.
func (e *exprHelper) NewObject(typeName string, fieldInits ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr {
return e.parserHelper.newObject(e.nextMacroID(), typeName, fieldInits...)
// NewStruct implements the ExprHelper interface method.
func (e *exprHelper) NewStruct(typeName string, fieldInits ...ast.EntryExpr) ast.Expr {
return e.exprFactory.NewStruct(e.nextMacroID(), typeName, fieldInits)
}
// NewObjectFieldInit implements the ExprHelper interface method.
func (e *exprHelper) NewObjectFieldInit(field string, init *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry {
return e.parserHelper.newObjectField(e.nextMacroID(), field, init, optional)
// NewStructField implements the ExprHelper interface method.
func (e *exprHelper) NewStructField(field string, init ast.Expr, optional bool) ast.EntryExpr {
return e.exprFactory.NewStructField(e.nextMacroID(), field, init, optional)
}
// Fold implements the ExprHelper interface method.
func (e *exprHelper) Fold(iterVar string,
iterRange *exprpb.Expr,
// NewComprehension implements the ExprHelper interface method.
func (e *exprHelper) NewComprehension(
iterRange ast.Expr,
iterVar string,
accuVar string,
accuInit *exprpb.Expr,
condition *exprpb.Expr,
step *exprpb.Expr,
result *exprpb.Expr) *exprpb.Expr {
return e.parserHelper.newComprehension(
e.nextMacroID(), iterVar, iterRange, accuVar, accuInit, condition, step, result)
accuInit ast.Expr,
condition ast.Expr,
step ast.Expr,
result ast.Expr) ast.Expr {
return e.exprFactory.NewComprehension(
e.nextMacroID(), iterRange, iterVar, accuVar, accuInit, condition, step, result)
}
// Ident implements the ExprHelper interface method.
func (e *exprHelper) Ident(name string) *exprpb.Expr {
return e.parserHelper.newIdent(e.nextMacroID(), name)
// NewIdent implements the ExprHelper interface method.
func (e *exprHelper) NewIdent(name string) ast.Expr {
return e.exprFactory.NewIdent(e.nextMacroID(), name)
}
// AccuIdent implements the ExprHelper interface method.
func (e *exprHelper) AccuIdent() *exprpb.Expr {
return e.parserHelper.newIdent(e.nextMacroID(), AccumulatorName)
// NewAccuIdent implements the ExprHelper interface method.
func (e *exprHelper) NewAccuIdent() ast.Expr {
return e.exprFactory.NewAccuIdent(e.nextMacroID())
}
// GlobalCall implements the ExprHelper interface method.
func (e *exprHelper) GlobalCall(function string, args ...*exprpb.Expr) *exprpb.Expr {
return e.parserHelper.newGlobalCall(e.nextMacroID(), function, args...)
// NewGlobalCall implements the ExprHelper interface method.
func (e *exprHelper) NewCall(function string, args ...ast.Expr) ast.Expr {
return e.exprFactory.NewCall(e.nextMacroID(), function, args...)
}
// ReceiverCall implements the ExprHelper interface method.
func (e *exprHelper) ReceiverCall(function string,
target *exprpb.Expr, args ...*exprpb.Expr) *exprpb.Expr {
return e.parserHelper.newReceiverCall(e.nextMacroID(), function, target, args...)
// NewMemberCall implements the ExprHelper interface method.
func (e *exprHelper) NewMemberCall(function string, target ast.Expr, args ...ast.Expr) ast.Expr {
return e.exprFactory.NewMemberCall(e.nextMacroID(), function, target, args...)
}
// PresenceTest implements the ExprHelper interface method.
func (e *exprHelper) PresenceTest(operand *exprpb.Expr, field string) *exprpb.Expr {
return e.parserHelper.newPresenceTest(e.nextMacroID(), operand, field)
// NewPresenceTest implements the ExprHelper interface method.
func (e *exprHelper) NewPresenceTest(operand ast.Expr, field string) ast.Expr {
return e.exprFactory.NewPresenceTest(e.nextMacroID(), operand, field)
}
// Select implements the ExprHelper interface method.
func (e *exprHelper) Select(operand *exprpb.Expr, field string) *exprpb.Expr {
return e.parserHelper.newSelect(e.nextMacroID(), operand, field)
// NewSelect implements the ExprHelper interface method.
func (e *exprHelper) NewSelect(operand ast.Expr, field string) ast.Expr {
return e.exprFactory.NewSelect(e.nextMacroID(), operand, field)
}
// OffsetLocation implements the ExprHelper interface method.
func (e *exprHelper) OffsetLocation(exprID int64) common.Location {
offset, found := e.parserHelper.positions[exprID]
if !found {
return common.NoLocation
}
location, found := e.parserHelper.source.OffsetLocation(offset)
if !found {
return common.NoLocation
}
return location
return e.parserHelper.sourceInfo.GetStartLocation(exprID)
}
// NewError associates an error message with a given expression id, populating the source offset location of the error if possible.

View File

@ -15,7 +15,7 @@
package parser
import (
antlr "github.com/antlr/antlr4/runtime/Go/antlr/v4"
antlr "github.com/antlr4-go/antlr/v4"
"github.com/google/cel-go/common/runes"
)
@ -110,7 +110,7 @@ func (c *charStream) GetTextFromTokens(start, stop antlr.Token) string {
}
// GetTextFromInterval implements (antlr.CharStream).GetTextFromInterval.
func (c *charStream) GetTextFromInterval(i *antlr.Interval) string {
func (c *charStream) GetTextFromInterval(i antlr.Interval) string {
return c.GetText(i.Start, i.Stop)
}

View File

@ -18,9 +18,10 @@ import (
"fmt"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)
// NewGlobalMacro creates a Macro for a global function with the specified arg count.
@ -142,58 +143,38 @@ func makeVarArgMacroKey(name string, receiverStyle bool) string {
// and produces as output an Expr ast node.
//
// Note: when the Macro.IsReceiverStyle() method returns true, the target argument will be nil.
type MacroExpander func(eh ExprHelper,
target *exprpb.Expr,
args []*exprpb.Expr) (*exprpb.Expr, *common.Error)
type MacroExpander func(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error)
// ExprHelper assists with the manipulation of proto-based Expr values in a manner which is
// consistent with the source position and expression id generation code leveraged by both
// the parser and type-checker.
// ExprHelper assists with the creation of Expr values in a manner which is consistent
// the internal semantics and id generation behaviors of the parser and checker libraries.
type ExprHelper interface {
// Copy the input expression with a brand new set of identifiers.
Copy(*exprpb.Expr) *exprpb.Expr
Copy(ast.Expr) ast.Expr
// LiteralBool creates an Expr value for a bool literal.
LiteralBool(value bool) *exprpb.Expr
// Literal creates an Expr value for a scalar literal value.
NewLiteral(value ref.Val) ast.Expr
// LiteralBytes creates an Expr value for a byte literal.
LiteralBytes(value []byte) *exprpb.Expr
// LiteralDouble creates an Expr value for double literal.
LiteralDouble(value float64) *exprpb.Expr
// LiteralInt creates an Expr value for an int literal.
LiteralInt(value int64) *exprpb.Expr
// LiteralString creates am Expr value for a string literal.
LiteralString(value string) *exprpb.Expr
// LiteralUint creates an Expr value for a uint literal.
LiteralUint(value uint64) *exprpb.Expr
// NewList creates a CreateList instruction where the list is comprised of the optional set
// of elements provided as arguments.
NewList(elems ...*exprpb.Expr) *exprpb.Expr
// NewList creates a list literal instruction with an optional set of elements.
NewList(elems ...ast.Expr) ast.Expr
// NewMap creates a CreateStruct instruction for a map where the map is comprised of the
// optional set of key, value entries.
NewMap(entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr
NewMap(entries ...ast.EntryExpr) ast.Expr
// NewMapEntry creates a Map Entry for the key, value pair.
NewMapEntry(key *exprpb.Expr, val *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry
NewMapEntry(key ast.Expr, val ast.Expr, optional bool) ast.EntryExpr
// NewObject creates a CreateStruct instruction for an object with a given type name and
// optional set of field initializers.
NewObject(typeName string, fieldInits ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr
// NewStruct creates a struct literal expression with an optional set of field initializers.
NewStruct(typeName string, fieldInits ...ast.EntryExpr) ast.Expr
// NewObjectFieldInit creates a new Object field initializer from the field name and value.
NewObjectFieldInit(field string, init *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry
// NewStructField creates a new struct field initializer from the field name and value.
NewStructField(field string, init ast.Expr, optional bool) ast.EntryExpr
// Fold creates a fold comprehension instruction.
// NewComprehension creates a new comprehension instruction.
//
// - iterVar is the iteration variable name.
// - iterRange represents the expression that resolves to a list or map where the elements or
// keys (respectively) will be iterated over.
// - iterVar is the iteration variable name.
// - accuVar is the accumulation variable name, typically parser.AccumulatorName.
// - accuInit is the initial expression whose value will be set for the accuVar prior to
// folding.
@ -204,31 +185,31 @@ type ExprHelper interface {
// The accuVar should not shadow variable names that you would like to reference within the
// environment in the step and condition expressions. Presently, the name __result__ is commonly
// used by built-in macros but this may change in the future.
Fold(iterVar string,
iterRange *exprpb.Expr,
NewComprehension(iterRange ast.Expr,
iterVar string,
accuVar string,
accuInit *exprpb.Expr,
condition *exprpb.Expr,
step *exprpb.Expr,
result *exprpb.Expr) *exprpb.Expr
accuInit ast.Expr,
condition ast.Expr,
step ast.Expr,
result ast.Expr) ast.Expr
// Ident creates an identifier Expr value.
Ident(name string) *exprpb.Expr
// NewIdent creates an identifier Expr value.
NewIdent(name string) ast.Expr
// AccuIdent returns an accumulator identifier for use with comprehension results.
AccuIdent() *exprpb.Expr
// NewAccuIdent returns an accumulator identifier for use with comprehension results.
NewAccuIdent() ast.Expr
// GlobalCall creates a function call Expr value for a global (free) function.
GlobalCall(function string, args ...*exprpb.Expr) *exprpb.Expr
// NewCall creates a function call Expr value for a global (free) function.
NewCall(function string, args ...ast.Expr) ast.Expr
// ReceiverCall creates a function call Expr value for a receiver-style function.
ReceiverCall(function string, target *exprpb.Expr, args ...*exprpb.Expr) *exprpb.Expr
// NewMemberCall creates a function call Expr value for a receiver-style function.
NewMemberCall(function string, target ast.Expr, args ...ast.Expr) ast.Expr
// PresenceTest creates a Select TestOnly Expr value for modelling has() semantics.
PresenceTest(operand *exprpb.Expr, field string) *exprpb.Expr
// NewPresenceTest creates a Select TestOnly Expr value for modelling has() semantics.
NewPresenceTest(operand ast.Expr, field string) ast.Expr
// Select create a field traversal Expr value.
Select(operand *exprpb.Expr, field string) *exprpb.Expr
// NewSelect create a field traversal Expr value.
NewSelect(operand ast.Expr, field string) ast.Expr
// OffsetLocation returns the Location of the expression identifier.
OffsetLocation(exprID int64) common.Location
@ -296,21 +277,21 @@ const (
// MakeAll expands the input call arguments into a comprehension that returns true if all of the
// elements in the range match the predicate expressions:
// <iterRange>.all(<iterVar>, <predicate>)
func MakeAll(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
func MakeAll(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) {
return makeQuantifier(quantifierAll, eh, target, args)
}
// MakeExists expands the input call arguments into a comprehension that returns true if any of the
// elements in the range match the predicate expressions:
// <iterRange>.exists(<iterVar>, <predicate>)
func MakeExists(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
func MakeExists(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) {
return makeQuantifier(quantifierExists, eh, target, args)
}
// MakeExistsOne expands the input call arguments into a comprehension that returns true if exactly
// one of the elements in the range match the predicate expressions:
// <iterRange>.exists_one(<iterVar>, <predicate>)
func MakeExistsOne(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
func MakeExistsOne(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) {
return makeQuantifier(quantifierExistsOne, eh, target, args)
}
@ -324,14 +305,14 @@ func MakeExistsOne(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*ex
//
// In the second form only iterVar values which return true when provided to the predicate expression
// are transformed.
func MakeMap(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
func MakeMap(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) {
v, found := extractIdent(args[0])
if !found {
return nil, eh.NewError(args[0].GetId(), "argument is not an identifier")
return nil, eh.NewError(args[0].ID(), "argument is not an identifier")
}
var fn *exprpb.Expr
var filter *exprpb.Expr
var fn ast.Expr
var filter ast.Expr
if len(args) == 3 {
filter = args[1]
@ -341,84 +322,83 @@ func MakeMap(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.E
fn = args[1]
}
accuExpr := eh.Ident(AccumulatorName)
init := eh.NewList()
condition := eh.LiteralBool(true)
step := eh.GlobalCall(operators.Add, accuExpr, eh.NewList(fn))
condition := eh.NewLiteral(types.True)
step := eh.NewCall(operators.Add, eh.NewAccuIdent(), eh.NewList(fn))
if filter != nil {
step = eh.GlobalCall(operators.Conditional, filter, step, accuExpr)
step = eh.NewCall(operators.Conditional, filter, step, eh.NewAccuIdent())
}
return eh.Fold(v, target, AccumulatorName, init, condition, step, accuExpr), nil
return eh.NewComprehension(target, v, AccumulatorName, init, condition, step, eh.NewAccuIdent()), nil
}
// MakeFilter expands the input call arguments into a comprehension which produces a list which contains
// only elements which match the provided predicate expression:
// <iterRange>.filter(<iterVar>, <predicate>)
func MakeFilter(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
func MakeFilter(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) {
v, found := extractIdent(args[0])
if !found {
return nil, eh.NewError(args[0].GetId(), "argument is not an identifier")
return nil, eh.NewError(args[0].ID(), "argument is not an identifier")
}
filter := args[1]
accuExpr := eh.Ident(AccumulatorName)
init := eh.NewList()
condition := eh.LiteralBool(true)
step := eh.GlobalCall(operators.Add, accuExpr, eh.NewList(args[0]))
step = eh.GlobalCall(operators.Conditional, filter, step, accuExpr)
return eh.Fold(v, target, AccumulatorName, init, condition, step, accuExpr), nil
condition := eh.NewLiteral(types.True)
step := eh.NewCall(operators.Add, eh.NewAccuIdent(), eh.NewList(args[0]))
step = eh.NewCall(operators.Conditional, filter, step, eh.NewAccuIdent())
return eh.NewComprehension(target, v, AccumulatorName, init, condition, step, eh.NewAccuIdent()), nil
}
// MakeHas expands the input call arguments into a presence test, e.g. has(<operand>.field)
func MakeHas(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
if s, ok := args[0].ExprKind.(*exprpb.Expr_SelectExpr); ok {
return eh.PresenceTest(s.SelectExpr.GetOperand(), s.SelectExpr.GetField()), nil
func MakeHas(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) {
if args[0].Kind() == ast.SelectKind {
s := args[0].AsSelect()
return eh.NewPresenceTest(s.Operand(), s.FieldName()), nil
}
return nil, eh.NewError(args[0].GetId(), "invalid argument to has() macro")
return nil, eh.NewError(args[0].ID(), "invalid argument to has() macro")
}
func makeQuantifier(kind quantifierKind, eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
func makeQuantifier(kind quantifierKind, eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) {
v, found := extractIdent(args[0])
if !found {
return nil, eh.NewError(args[0].GetId(), "argument must be a simple name")
return nil, eh.NewError(args[0].ID(), "argument must be a simple name")
}
var init *exprpb.Expr
var condition *exprpb.Expr
var step *exprpb.Expr
var result *exprpb.Expr
var init ast.Expr
var condition ast.Expr
var step ast.Expr
var result ast.Expr
switch kind {
case quantifierAll:
init = eh.LiteralBool(true)
condition = eh.GlobalCall(operators.NotStrictlyFalse, eh.AccuIdent())
step = eh.GlobalCall(operators.LogicalAnd, eh.AccuIdent(), args[1])
result = eh.AccuIdent()
init = eh.NewLiteral(types.True)
condition = eh.NewCall(operators.NotStrictlyFalse, eh.NewAccuIdent())
step = eh.NewCall(operators.LogicalAnd, eh.NewAccuIdent(), args[1])
result = eh.NewAccuIdent()
case quantifierExists:
init = eh.LiteralBool(false)
condition = eh.GlobalCall(
init = eh.NewLiteral(types.False)
condition = eh.NewCall(
operators.NotStrictlyFalse,
eh.GlobalCall(operators.LogicalNot, eh.AccuIdent()))
step = eh.GlobalCall(operators.LogicalOr, eh.AccuIdent(), args[1])
result = eh.AccuIdent()
eh.NewCall(operators.LogicalNot, eh.NewAccuIdent()))
step = eh.NewCall(operators.LogicalOr, eh.NewAccuIdent(), args[1])
result = eh.NewAccuIdent()
case quantifierExistsOne:
zeroExpr := eh.LiteralInt(0)
oneExpr := eh.LiteralInt(1)
zeroExpr := eh.NewLiteral(types.Int(0))
oneExpr := eh.NewLiteral(types.Int(1))
init = zeroExpr
condition = eh.LiteralBool(true)
step = eh.GlobalCall(operators.Conditional, args[1],
eh.GlobalCall(operators.Add, eh.AccuIdent(), oneExpr), eh.AccuIdent())
result = eh.GlobalCall(operators.Equals, eh.AccuIdent(), oneExpr)
condition = eh.NewLiteral(types.True)
step = eh.NewCall(operators.Conditional, args[1],
eh.NewCall(operators.Add, eh.NewAccuIdent(), oneExpr), eh.NewAccuIdent())
result = eh.NewCall(operators.Equals, eh.NewAccuIdent(), oneExpr)
default:
return nil, eh.NewError(args[0].GetId(), fmt.Sprintf("unrecognized quantifier '%v'", kind))
return nil, eh.NewError(args[0].ID(), fmt.Sprintf("unrecognized quantifier '%v'", kind))
}
return eh.Fold(v, target, AccumulatorName, init, condition, step, result), nil
return eh.NewComprehension(target, v, AccumulatorName, init, condition, step, result), nil
}
func extractIdent(e *exprpb.Expr) (string, bool) {
switch e.ExprKind.(type) {
case *exprpb.Expr_IdentExpr:
return e.GetIdentExpr().GetName(), true
func extractIdent(e ast.Expr) (string, bool) {
switch e.Kind() {
case ast.IdentKind:
return e.AsIdent(), true
}
return "", false
}

View File

@ -21,17 +21,15 @@ import (
"regexp"
"strconv"
"strings"
"sync"
antlr "github.com/antlr/antlr4/runtime/Go/antlr/v4"
antlr "github.com/antlr4-go/antlr/v4"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/runes"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/parser/gen"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
structpb "google.golang.org/protobuf/types/known/structpb"
)
// Parser encapsulates the context necessary to perform parsing for different expressions.
@ -88,11 +86,13 @@ func mustNewParser(opts ...Option) *Parser {
}
// Parse parses the expression represented by source and returns the result.
func (p *Parser) Parse(source common.Source) (*exprpb.ParsedExpr, *common.Errors) {
func (p *Parser) Parse(source common.Source) (*ast.AST, *common.Errors) {
errs := common.NewErrors(source)
fac := ast.NewExprFactory()
impl := parser{
errors: &parseErrors{errs},
helper: newParserHelper(source),
exprFactory: fac,
helper: newParserHelper(source, fac),
macros: p.macros,
maxRecursionDepth: p.maxRecursionDepth,
errorReportingLimit: p.errorReportingLimit,
@ -106,18 +106,15 @@ func (p *Parser) Parse(source common.Source) (*exprpb.ParsedExpr, *common.Errors
if !ok {
buf = runes.NewBuffer(source.Content())
}
var e *exprpb.Expr
var out ast.Expr
if buf.Len() > p.expressionSizeCodePointLimit {
e = impl.reportError(common.NoLocation,
out = impl.reportError(common.NoLocation,
"expression code point size exceeds limit: size: %d, limit %d",
buf.Len(), p.expressionSizeCodePointLimit)
} else {
e = impl.parse(buf, source.Description())
out = impl.parse(buf, source.Description())
}
return &exprpb.ParsedExpr{
Expr: e,
SourceInfo: impl.helper.getSourceInfo(),
}, errs
return ast.NewAST(out, impl.helper.getSourceInfo()), errs
}
// reservedIds are not legal to use as variables. We exclude them post-parse, as they *are* valid
@ -150,7 +147,7 @@ var reservedIds = map[string]struct{}{
// This function calls ParseWithMacros with AllMacros.
//
// Deprecated: Use NewParser().Parse() instead.
func Parse(source common.Source) (*exprpb.ParsedExpr, *common.Errors) {
func Parse(source common.Source) (*ast.AST, *common.Errors) {
return mustNewParser(Macros(AllMacros...)).Parse(source)
}
@ -287,6 +284,7 @@ var _ antlr.ErrorStrategy = &recoveryLimitErrorStrategy{}
type parser struct {
gen.BaseCELVisitor
errors *parseErrors
exprFactory ast.ExprFactory
helper *parserHelper
macros map[string]Macro
recursionDepth int
@ -300,53 +298,21 @@ type parser struct {
enableVariadicOperatorASTs bool
}
var (
_ gen.CELVisitor = (*parser)(nil)
var _ gen.CELVisitor = (*parser)(nil)
lexerPool *sync.Pool = &sync.Pool{
New: func() any {
l := gen.NewCELLexer(nil)
l.RemoveErrorListeners()
return l
},
}
func (p *parser) parse(expr runes.Buffer, desc string) ast.Expr {
lexer := gen.NewCELLexer(newCharStream(expr, desc))
lexer.RemoveErrorListeners()
lexer.AddErrorListener(p)
parserPool *sync.Pool = &sync.Pool{
New: func() any {
p := gen.NewCELParser(nil)
p.RemoveErrorListeners()
return p
},
}
)
func (p *parser) parse(expr runes.Buffer, desc string) *exprpb.Expr {
// TODO: get rid of these pools once https://github.com/antlr/antlr4/pull/3571 is in a release
lexer := lexerPool.Get().(*gen.CELLexer)
prsr := parserPool.Get().(*gen.CELParser)
prsr := gen.NewCELParser(antlr.NewCommonTokenStream(lexer, 0))
prsr.RemoveErrorListeners()
prsrListener := &recursionListener{
maxDepth: p.maxRecursionDepth,
ruleTypeDepth: map[int]*int{},
}
defer func() {
// Unfortunately ANTLR Go runtime is missing (*antlr.BaseParser).RemoveParseListeners,
// so this is good enough until that is exported.
// Reset the lexer and parser before putting them back in the pool.
lexer.RemoveErrorListeners()
prsr.RemoveParseListener(prsrListener)
prsr.RemoveErrorListeners()
lexer.SetInputStream(nil)
prsr.SetInputStream(nil)
lexerPool.Put(lexer)
parserPool.Put(prsr)
}()
lexer.SetInputStream(newCharStream(expr, desc))
prsr.SetInputStream(antlr.NewCommonTokenStream(lexer, 0))
lexer.AddErrorListener(p)
prsr.AddErrorListener(p)
prsr.AddParseListener(prsrListener)
@ -373,7 +339,7 @@ func (p *parser) parse(expr runes.Buffer, desc string) *exprpb.Expr {
}
}()
return p.Visit(prsr.Start()).(*exprpb.Expr)
return p.Visit(prsr.Start_()).(ast.Expr)
}
// Visitor implementations.
@ -470,26 +436,26 @@ func (p *parser) VisitStart(ctx *gen.StartContext) any {
// Visit a parse tree produced by CELParser#expr.
func (p *parser) VisitExpr(ctx *gen.ExprContext) any {
result := p.Visit(ctx.GetE()).(*exprpb.Expr)
result := p.Visit(ctx.GetE()).(ast.Expr)
if ctx.GetOp() == nil {
return result
}
opID := p.helper.id(ctx.GetOp())
ifTrue := p.Visit(ctx.GetE1()).(*exprpb.Expr)
ifFalse := p.Visit(ctx.GetE2()).(*exprpb.Expr)
ifTrue := p.Visit(ctx.GetE1()).(ast.Expr)
ifFalse := p.Visit(ctx.GetE2()).(ast.Expr)
return p.globalCallOrMacro(opID, operators.Conditional, result, ifTrue, ifFalse)
}
// Visit a parse tree produced by CELParser#conditionalOr.
func (p *parser) VisitConditionalOr(ctx *gen.ConditionalOrContext) any {
result := p.Visit(ctx.GetE()).(*exprpb.Expr)
result := p.Visit(ctx.GetE()).(ast.Expr)
l := p.newLogicManager(operators.LogicalOr, result)
rest := ctx.GetE1()
for i, op := range ctx.GetOps() {
if i >= len(rest) {
return p.reportError(ctx, "unexpected character, wanted '||'")
}
next := p.Visit(rest[i]).(*exprpb.Expr)
next := p.Visit(rest[i]).(ast.Expr)
opID := p.helper.id(op)
l.addTerm(opID, next)
}
@ -498,14 +464,14 @@ func (p *parser) VisitConditionalOr(ctx *gen.ConditionalOrContext) any {
// Visit a parse tree produced by CELParser#conditionalAnd.
func (p *parser) VisitConditionalAnd(ctx *gen.ConditionalAndContext) any {
result := p.Visit(ctx.GetE()).(*exprpb.Expr)
result := p.Visit(ctx.GetE()).(ast.Expr)
l := p.newLogicManager(operators.LogicalAnd, result)
rest := ctx.GetE1()
for i, op := range ctx.GetOps() {
if i >= len(rest) {
return p.reportError(ctx, "unexpected character, wanted '&&'")
}
next := p.Visit(rest[i]).(*exprpb.Expr)
next := p.Visit(rest[i]).(ast.Expr)
opID := p.helper.id(op)
l.addTerm(opID, next)
}
@ -519,9 +485,9 @@ func (p *parser) VisitRelation(ctx *gen.RelationContext) any {
opText = ctx.GetOp().GetText()
}
if op, found := operators.Find(opText); found {
lhs := p.Visit(ctx.Relation(0)).(*exprpb.Expr)
lhs := p.Visit(ctx.Relation(0)).(ast.Expr)
opID := p.helper.id(ctx.GetOp())
rhs := p.Visit(ctx.Relation(1)).(*exprpb.Expr)
rhs := p.Visit(ctx.Relation(1)).(ast.Expr)
return p.globalCallOrMacro(opID, op, lhs, rhs)
}
return p.reportError(ctx, "operator not found")
@ -534,9 +500,9 @@ func (p *parser) VisitCalc(ctx *gen.CalcContext) any {
opText = ctx.GetOp().GetText()
}
if op, found := operators.Find(opText); found {
lhs := p.Visit(ctx.Calc(0)).(*exprpb.Expr)
lhs := p.Visit(ctx.Calc(0)).(ast.Expr)
opID := p.helper.id(ctx.GetOp())
rhs := p.Visit(ctx.Calc(1)).(*exprpb.Expr)
rhs := p.Visit(ctx.Calc(1)).(ast.Expr)
return p.globalCallOrMacro(opID, op, lhs, rhs)
}
return p.reportError(ctx, "operator not found")
@ -552,7 +518,7 @@ func (p *parser) VisitLogicalNot(ctx *gen.LogicalNotContext) any {
return p.Visit(ctx.Member())
}
opID := p.helper.id(ctx.GetOps()[0])
target := p.Visit(ctx.Member()).(*exprpb.Expr)
target := p.Visit(ctx.Member()).(ast.Expr)
return p.globalCallOrMacro(opID, operators.LogicalNot, target)
}
@ -561,13 +527,13 @@ func (p *parser) VisitNegate(ctx *gen.NegateContext) any {
return p.Visit(ctx.Member())
}
opID := p.helper.id(ctx.GetOps()[0])
target := p.Visit(ctx.Member()).(*exprpb.Expr)
target := p.Visit(ctx.Member()).(ast.Expr)
return p.globalCallOrMacro(opID, operators.Negate, target)
}
// VisitSelect visits a parse tree produced by CELParser#Select.
func (p *parser) VisitSelect(ctx *gen.SelectContext) any {
operand := p.Visit(ctx.Member()).(*exprpb.Expr)
operand := p.Visit(ctx.Member()).(ast.Expr)
// Handle the error case where no valid identifier is specified.
if ctx.GetId() == nil || ctx.GetOp() == nil {
return p.helper.newExpr(ctx)
@ -588,7 +554,7 @@ func (p *parser) VisitSelect(ctx *gen.SelectContext) any {
// VisitMemberCall visits a parse tree produced by CELParser#MemberCall.
func (p *parser) VisitMemberCall(ctx *gen.MemberCallContext) any {
operand := p.Visit(ctx.Member()).(*exprpb.Expr)
operand := p.Visit(ctx.Member()).(ast.Expr)
// Handle the error case where no valid identifier is specified.
if ctx.GetId() == nil {
return p.helper.newExpr(ctx)
@ -600,13 +566,13 @@ func (p *parser) VisitMemberCall(ctx *gen.MemberCallContext) any {
// Visit a parse tree produced by CELParser#Index.
func (p *parser) VisitIndex(ctx *gen.IndexContext) any {
target := p.Visit(ctx.Member()).(*exprpb.Expr)
target := p.Visit(ctx.Member()).(ast.Expr)
// Handle the error case where no valid identifier is specified.
if ctx.GetOp() == nil {
return p.helper.newExpr(ctx)
}
opID := p.helper.id(ctx.GetOp())
index := p.Visit(ctx.GetIndex()).(*exprpb.Expr)
index := p.Visit(ctx.GetIndex()).(ast.Expr)
operator := operators.Index
if ctx.GetOpt() != nil {
if !p.enableOptionalSyntax {
@ -630,7 +596,7 @@ func (p *parser) VisitCreateMessage(ctx *gen.CreateMessageContext) any {
messageName = "." + messageName
}
objID := p.helper.id(ctx.GetOp())
entries := p.VisitIFieldInitializerList(ctx.GetEntries()).([]*exprpb.Expr_CreateStruct_Entry)
entries := p.VisitIFieldInitializerList(ctx.GetEntries()).([]ast.EntryExpr)
return p.helper.newObject(objID, messageName, entries...)
}
@ -638,16 +604,16 @@ func (p *parser) VisitCreateMessage(ctx *gen.CreateMessageContext) any {
func (p *parser) VisitIFieldInitializerList(ctx gen.IFieldInitializerListContext) any {
if ctx == nil || ctx.GetFields() == nil {
// This is the result of a syntax error handled elswhere, return empty.
return []*exprpb.Expr_CreateStruct_Entry{}
return []ast.EntryExpr{}
}
result := make([]*exprpb.Expr_CreateStruct_Entry, len(ctx.GetFields()))
result := make([]ast.EntryExpr, len(ctx.GetFields()))
cols := ctx.GetCols()
vals := ctx.GetValues()
for i, f := range ctx.GetFields() {
if i >= len(cols) || i >= len(vals) {
// This is the result of a syntax error detected elsewhere.
return []*exprpb.Expr_CreateStruct_Entry{}
return []ast.EntryExpr{}
}
initID := p.helper.id(cols[i])
optField := f.(*gen.OptFieldContext)
@ -659,10 +625,10 @@ func (p *parser) VisitIFieldInitializerList(ctx gen.IFieldInitializerListContext
// The field may be empty due to a prior error.
id := optField.IDENTIFIER()
if id == nil {
return []*exprpb.Expr_CreateStruct_Entry{}
return []ast.EntryExpr{}
}
fieldName := id.GetText()
value := p.Visit(vals[i]).(*exprpb.Expr)
value := p.Visit(vals[i]).(ast.Expr)
field := p.helper.newObjectField(initID, fieldName, value, optional)
result[i] = field
}
@ -702,9 +668,9 @@ func (p *parser) VisitCreateList(ctx *gen.CreateListContext) any {
// Visit a parse tree produced by CELParser#CreateStruct.
func (p *parser) VisitCreateStruct(ctx *gen.CreateStructContext) any {
structID := p.helper.id(ctx.GetOp())
entries := []*exprpb.Expr_CreateStruct_Entry{}
entries := []ast.EntryExpr{}
if ctx.GetEntries() != nil {
entries = p.Visit(ctx.GetEntries()).([]*exprpb.Expr_CreateStruct_Entry)
entries = p.Visit(ctx.GetEntries()).([]ast.EntryExpr)
}
return p.helper.newMap(structID, entries...)
}
@ -713,17 +679,17 @@ func (p *parser) VisitCreateStruct(ctx *gen.CreateStructContext) any {
func (p *parser) VisitMapInitializerList(ctx *gen.MapInitializerListContext) any {
if ctx == nil || ctx.GetKeys() == nil {
// This is the result of a syntax error handled elswhere, return empty.
return []*exprpb.Expr_CreateStruct_Entry{}
return []ast.EntryExpr{}
}
result := make([]*exprpb.Expr_CreateStruct_Entry, len(ctx.GetCols()))
result := make([]ast.EntryExpr, len(ctx.GetCols()))
keys := ctx.GetKeys()
vals := ctx.GetValues()
for i, col := range ctx.GetCols() {
colID := p.helper.id(col)
if i >= len(keys) || i >= len(vals) {
// This is the result of a syntax error detected elsewhere.
return []*exprpb.Expr_CreateStruct_Entry{}
return []ast.EntryExpr{}
}
optKey := keys[i]
optional := optKey.GetOpt() != nil
@ -731,8 +697,8 @@ func (p *parser) VisitMapInitializerList(ctx *gen.MapInitializerListContext) any
p.reportError(optKey, "unsupported syntax '?'")
continue
}
key := p.Visit(optKey.GetE()).(*exprpb.Expr)
value := p.Visit(vals[i]).(*exprpb.Expr)
key := p.Visit(optKey.GetE()).(ast.Expr)
value := p.Visit(vals[i]).(ast.Expr)
entry := p.helper.newMapEntry(colID, key, value, optional)
result[i] = entry
}
@ -812,30 +778,27 @@ func (p *parser) VisitBoolFalse(ctx *gen.BoolFalseContext) any {
// Visit a parse tree produced by CELParser#Null.
func (p *parser) VisitNull(ctx *gen.NullContext) any {
return p.helper.newLiteral(ctx,
&exprpb.Constant{
ConstantKind: &exprpb.Constant_NullValue{
NullValue: structpb.NullValue_NULL_VALUE}})
return p.helper.exprFactory.NewLiteral(p.helper.newID(ctx), types.NullValue)
}
func (p *parser) visitExprList(ctx gen.IExprListContext) []*exprpb.Expr {
func (p *parser) visitExprList(ctx gen.IExprListContext) []ast.Expr {
if ctx == nil {
return []*exprpb.Expr{}
return []ast.Expr{}
}
return p.visitSlice(ctx.GetE())
}
func (p *parser) visitListInit(ctx gen.IListInitContext) ([]*exprpb.Expr, []int32) {
func (p *parser) visitListInit(ctx gen.IListInitContext) ([]ast.Expr, []int32) {
if ctx == nil {
return []*exprpb.Expr{}, []int32{}
return []ast.Expr{}, []int32{}
}
elements := ctx.GetElems()
result := make([]*exprpb.Expr, len(elements))
result := make([]ast.Expr, len(elements))
optionals := []int32{}
for i, e := range elements {
ex := p.Visit(e.GetE()).(*exprpb.Expr)
ex := p.Visit(e.GetE()).(ast.Expr)
if ex == nil {
return []*exprpb.Expr{}, []int32{}
return []ast.Expr{}, []int32{}
}
result[i] = ex
if e.GetOpt() != nil {
@ -849,13 +812,13 @@ func (p *parser) visitListInit(ctx gen.IListInitContext) ([]*exprpb.Expr, []int3
return result, optionals
}
func (p *parser) visitSlice(expressions []gen.IExprContext) []*exprpb.Expr {
func (p *parser) visitSlice(expressions []gen.IExprContext) []ast.Expr {
if expressions == nil {
return []*exprpb.Expr{}
return []ast.Expr{}
}
result := make([]*exprpb.Expr, len(expressions))
result := make([]ast.Expr, len(expressions))
for i, e := range expressions {
ex := p.Visit(e).(*exprpb.Expr)
ex := p.Visit(e).(ast.Expr)
result[i] = ex
}
return result
@ -870,24 +833,24 @@ func (p *parser) unquote(ctx any, value string, isBytes bool) string {
return text
}
func (p *parser) newLogicManager(function string, term *exprpb.Expr) *logicManager {
func (p *parser) newLogicManager(function string, term ast.Expr) *logicManager {
if p.enableVariadicOperatorASTs {
return newVariadicLogicManager(p.helper, function, term)
return newVariadicLogicManager(p.exprFactory, function, term)
}
return newBalancingLogicManager(p.helper, function, term)
return newBalancingLogicManager(p.exprFactory, function, term)
}
func (p *parser) reportError(ctx any, format string, args ...any) *exprpb.Expr {
func (p *parser) reportError(ctx any, format string, args ...any) ast.Expr {
var location common.Location
err := p.helper.newExpr(ctx)
switch c := ctx.(type) {
case common.Location:
location = c
case antlr.Token, antlr.ParserRuleContext:
location = p.helper.getLocation(err.GetId())
location = p.helper.getLocation(err.ID())
}
// Provide arguments to the report error.
p.errors.reportErrorAtID(err.GetId(), location, format, args...)
p.errors.reportErrorAtID(err.ID(), location, format, args...)
return err
}
@ -912,33 +875,33 @@ func (p *parser) SyntaxError(recognizer antlr.Recognizer, offendingSymbol any, l
}
}
func (p *parser) ReportAmbiguity(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex int, exact bool, ambigAlts *antlr.BitSet, configs antlr.ATNConfigSet) {
func (p *parser) ReportAmbiguity(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex int, exact bool, ambigAlts *antlr.BitSet, configs *antlr.ATNConfigSet) {
// Intentional
}
func (p *parser) ReportAttemptingFullContext(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex int, conflictingAlts *antlr.BitSet, configs antlr.ATNConfigSet) {
func (p *parser) ReportAttemptingFullContext(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex int, conflictingAlts *antlr.BitSet, configs *antlr.ATNConfigSet) {
// Intentional
}
func (p *parser) ReportContextSensitivity(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex, prediction int, configs antlr.ATNConfigSet) {
func (p *parser) ReportContextSensitivity(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex, prediction int, configs *antlr.ATNConfigSet) {
// Intentional
}
func (p *parser) globalCallOrMacro(exprID int64, function string, args ...*exprpb.Expr) *exprpb.Expr {
func (p *parser) globalCallOrMacro(exprID int64, function string, args ...ast.Expr) ast.Expr {
if expr, found := p.expandMacro(exprID, function, nil, args...); found {
return expr
}
return p.helper.newGlobalCall(exprID, function, args...)
}
func (p *parser) receiverCallOrMacro(exprID int64, function string, target *exprpb.Expr, args ...*exprpb.Expr) *exprpb.Expr {
func (p *parser) receiverCallOrMacro(exprID int64, function string, target ast.Expr, args ...ast.Expr) ast.Expr {
if expr, found := p.expandMacro(exprID, function, target, args...); found {
return expr
}
return p.helper.newReceiverCall(exprID, function, target, args...)
}
func (p *parser) expandMacro(exprID int64, function string, target *exprpb.Expr, args ...*exprpb.Expr) (*exprpb.Expr, bool) {
func (p *parser) expandMacro(exprID int64, function string, target ast.Expr, args ...ast.Expr) (ast.Expr, bool) {
macro, found := p.macros[makeMacroKey(function, len(args), target != nil)]
if !found {
macro, found = p.macros[makeVarArgMacroKey(function, target != nil)]
@ -964,7 +927,7 @@ func (p *parser) expandMacro(exprID int64, function string, target *exprpb.Expr,
return nil, false
}
if p.populateMacroCalls {
p.helper.addMacroCall(expr.GetId(), function, target, args...)
p.helper.addMacroCall(expr.ID(), function, target, args...)
}
return expr, true
}

View File

@ -20,9 +20,9 @@ import (
"strconv"
"strings"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"github.com/google/cel-go/common/types"
)
// Unparse takes an input expression and source position information and generates a human-readable
@ -39,7 +39,7 @@ import (
//
// This function optionally takes in one or more UnparserOption to alter the unparsing behavior, such as
// performing word wrapping on expressions.
func Unparse(expr *exprpb.Expr, info *exprpb.SourceInfo, opts ...UnparserOption) (string, error) {
func Unparse(expr ast.Expr, info *ast.SourceInfo, opts ...UnparserOption) (string, error) {
unparserOpts := &unparserOption{
wrapOnColumn: defaultWrapOnColumn,
wrapAfterColumnLimit: defaultWrapAfterColumnLimit,
@ -68,12 +68,12 @@ func Unparse(expr *exprpb.Expr, info *exprpb.SourceInfo, opts ...UnparserOption)
// unparser visits an expression to reconstruct a human-readable string from an AST.
type unparser struct {
str strings.Builder
info *exprpb.SourceInfo
info *ast.SourceInfo
options *unparserOption
lastWrappedIndex int
}
func (un *unparser) visit(expr *exprpb.Expr) error {
func (un *unparser) visit(expr ast.Expr) error {
if expr == nil {
return errors.New("unsupported expression")
}
@ -81,27 +81,29 @@ func (un *unparser) visit(expr *exprpb.Expr) error {
if visited || err != nil {
return err
}
switch expr.GetExprKind().(type) {
case *exprpb.Expr_CallExpr:
switch expr.Kind() {
case ast.CallKind:
return un.visitCall(expr)
case *exprpb.Expr_ConstExpr:
case ast.LiteralKind:
return un.visitConst(expr)
case *exprpb.Expr_IdentExpr:
case ast.IdentKind:
return un.visitIdent(expr)
case *exprpb.Expr_ListExpr:
case ast.ListKind:
return un.visitList(expr)
case *exprpb.Expr_SelectExpr:
case ast.MapKind:
return un.visitStructMap(expr)
case ast.SelectKind:
return un.visitSelect(expr)
case *exprpb.Expr_StructExpr:
return un.visitStruct(expr)
case ast.StructKind:
return un.visitStructMsg(expr)
default:
return fmt.Errorf("unsupported expression: %v", expr)
}
}
func (un *unparser) visitCall(expr *exprpb.Expr) error {
c := expr.GetCallExpr()
fun := c.GetFunction()
func (un *unparser) visitCall(expr ast.Expr) error {
c := expr.AsCall()
fun := c.FunctionName()
switch fun {
// ternary operator
case operators.Conditional:
@ -141,10 +143,10 @@ func (un *unparser) visitCall(expr *exprpb.Expr) error {
}
}
func (un *unparser) visitCallBinary(expr *exprpb.Expr) error {
c := expr.GetCallExpr()
fun := c.GetFunction()
args := c.GetArgs()
func (un *unparser) visitCallBinary(expr ast.Expr) error {
c := expr.AsCall()
fun := c.FunctionName()
args := c.Args()
lhs := args[0]
// add parens if the current operator is lower precedence than the lhs expr operator.
lhsParen := isComplexOperatorWithRespectTo(fun, lhs)
@ -168,9 +170,9 @@ func (un *unparser) visitCallBinary(expr *exprpb.Expr) error {
return un.visitMaybeNested(rhs, rhsParen)
}
func (un *unparser) visitCallConditional(expr *exprpb.Expr) error {
c := expr.GetCallExpr()
args := c.GetArgs()
func (un *unparser) visitCallConditional(expr ast.Expr) error {
c := expr.AsCall()
args := c.Args()
// add parens if operand is a conditional itself.
nested := isSamePrecedence(operators.Conditional, args[0]) ||
isComplexOperator(args[0])
@ -196,13 +198,13 @@ func (un *unparser) visitCallConditional(expr *exprpb.Expr) error {
return un.visitMaybeNested(args[2], nested)
}
func (un *unparser) visitCallFunc(expr *exprpb.Expr) error {
c := expr.GetCallExpr()
fun := c.GetFunction()
args := c.GetArgs()
if c.GetTarget() != nil {
nested := isBinaryOrTernaryOperator(c.GetTarget())
err := un.visitMaybeNested(c.GetTarget(), nested)
func (un *unparser) visitCallFunc(expr ast.Expr) error {
c := expr.AsCall()
fun := c.FunctionName()
args := c.Args()
if c.IsMemberFunction() {
nested := isBinaryOrTernaryOperator(c.Target())
err := un.visitMaybeNested(c.Target(), nested)
if err != nil {
return err
}
@ -223,17 +225,17 @@ func (un *unparser) visitCallFunc(expr *exprpb.Expr) error {
return nil
}
func (un *unparser) visitCallIndex(expr *exprpb.Expr) error {
func (un *unparser) visitCallIndex(expr ast.Expr) error {
return un.visitCallIndexInternal(expr, "[")
}
func (un *unparser) visitCallOptIndex(expr *exprpb.Expr) error {
func (un *unparser) visitCallOptIndex(expr ast.Expr) error {
return un.visitCallIndexInternal(expr, "[?")
}
func (un *unparser) visitCallIndexInternal(expr *exprpb.Expr, op string) error {
c := expr.GetCallExpr()
args := c.GetArgs()
func (un *unparser) visitCallIndexInternal(expr ast.Expr, op string) error {
c := expr.AsCall()
args := c.Args()
nested := isBinaryOrTernaryOperator(args[0])
err := un.visitMaybeNested(args[0], nested)
if err != nil {
@ -248,10 +250,10 @@ func (un *unparser) visitCallIndexInternal(expr *exprpb.Expr, op string) error {
return nil
}
func (un *unparser) visitCallUnary(expr *exprpb.Expr) error {
c := expr.GetCallExpr()
fun := c.GetFunction()
args := c.GetArgs()
func (un *unparser) visitCallUnary(expr ast.Expr) error {
c := expr.AsCall()
fun := c.FunctionName()
args := c.Args()
unmangled, found := operators.FindReverse(fun)
if !found {
return fmt.Errorf("cannot unmangle operator: %s", fun)
@ -261,35 +263,34 @@ func (un *unparser) visitCallUnary(expr *exprpb.Expr) error {
return un.visitMaybeNested(args[0], nested)
}
func (un *unparser) visitConst(expr *exprpb.Expr) error {
c := expr.GetConstExpr()
switch c.GetConstantKind().(type) {
case *exprpb.Constant_BoolValue:
un.str.WriteString(strconv.FormatBool(c.GetBoolValue()))
case *exprpb.Constant_BytesValue:
func (un *unparser) visitConst(expr ast.Expr) error {
val := expr.AsLiteral()
switch val := val.(type) {
case types.Bool:
un.str.WriteString(strconv.FormatBool(bool(val)))
case types.Bytes:
// bytes constants are surrounded with b"<bytes>"
b := c.GetBytesValue()
un.str.WriteString(`b"`)
un.str.WriteString(bytesToOctets(b))
un.str.WriteString(bytesToOctets([]byte(val)))
un.str.WriteString(`"`)
case *exprpb.Constant_DoubleValue:
case types.Double:
// represent the float using the minimum required digits
d := strconv.FormatFloat(c.GetDoubleValue(), 'g', -1, 64)
d := strconv.FormatFloat(float64(val), 'g', -1, 64)
un.str.WriteString(d)
if !strings.Contains(d, ".") {
un.str.WriteString(".0")
}
case *exprpb.Constant_Int64Value:
i := strconv.FormatInt(c.GetInt64Value(), 10)
case types.Int:
i := strconv.FormatInt(int64(val), 10)
un.str.WriteString(i)
case *exprpb.Constant_NullValue:
case types.Null:
un.str.WriteString("null")
case *exprpb.Constant_StringValue:
case types.String:
// strings will be double quoted with quotes escaped.
un.str.WriteString(strconv.Quote(c.GetStringValue()))
case *exprpb.Constant_Uint64Value:
un.str.WriteString(strconv.Quote(string(val)))
case types.Uint:
// uint literals have a 'u' suffix.
ui := strconv.FormatUint(c.GetUint64Value(), 10)
ui := strconv.FormatUint(uint64(val), 10)
un.str.WriteString(ui)
un.str.WriteString("u")
default:
@ -298,16 +299,16 @@ func (un *unparser) visitConst(expr *exprpb.Expr) error {
return nil
}
func (un *unparser) visitIdent(expr *exprpb.Expr) error {
un.str.WriteString(expr.GetIdentExpr().GetName())
func (un *unparser) visitIdent(expr ast.Expr) error {
un.str.WriteString(expr.AsIdent())
return nil
}
func (un *unparser) visitList(expr *exprpb.Expr) error {
l := expr.GetListExpr()
elems := l.GetElements()
func (un *unparser) visitList(expr ast.Expr) error {
l := expr.AsList()
elems := l.Elements()
optIndices := make(map[int]bool, len(elems))
for _, idx := range l.GetOptionalIndices() {
for _, idx := range l.OptionalIndices() {
optIndices[int(idx)] = true
}
un.str.WriteString("[")
@ -327,20 +328,20 @@ func (un *unparser) visitList(expr *exprpb.Expr) error {
return nil
}
func (un *unparser) visitOptSelect(expr *exprpb.Expr) error {
c := expr.GetCallExpr()
args := c.GetArgs()
func (un *unparser) visitOptSelect(expr ast.Expr) error {
c := expr.AsCall()
args := c.Args()
operand := args[0]
field := args[1].GetConstExpr().GetStringValue()
return un.visitSelectInternal(operand, false, ".?", field)
field := args[1].AsLiteral().(types.String)
return un.visitSelectInternal(operand, false, ".?", string(field))
}
func (un *unparser) visitSelect(expr *exprpb.Expr) error {
sel := expr.GetSelectExpr()
return un.visitSelectInternal(sel.GetOperand(), sel.GetTestOnly(), ".", sel.GetField())
func (un *unparser) visitSelect(expr ast.Expr) error {
sel := expr.AsSelect()
return un.visitSelectInternal(sel.Operand(), sel.IsTestOnly(), ".", sel.FieldName())
}
func (un *unparser) visitSelectInternal(operand *exprpb.Expr, testOnly bool, op string, field string) error {
func (un *unparser) visitSelectInternal(operand ast.Expr, testOnly bool, op string, field string) error {
// handle the case when the select expression was generated by the has() macro.
if testOnly {
un.str.WriteString("has(")
@ -358,34 +359,25 @@ func (un *unparser) visitSelectInternal(operand *exprpb.Expr, testOnly bool, op
return nil
}
func (un *unparser) visitStruct(expr *exprpb.Expr) error {
s := expr.GetStructExpr()
// If the message name is non-empty, then this should be treated as message construction.
if s.GetMessageName() != "" {
return un.visitStructMsg(expr)
}
// Otherwise, build a map.
return un.visitStructMap(expr)
}
func (un *unparser) visitStructMsg(expr *exprpb.Expr) error {
m := expr.GetStructExpr()
entries := m.GetEntries()
un.str.WriteString(m.GetMessageName())
func (un *unparser) visitStructMsg(expr ast.Expr) error {
m := expr.AsStruct()
fields := m.Fields()
un.str.WriteString(m.TypeName())
un.str.WriteString("{")
for i, entry := range entries {
f := entry.GetFieldKey()
if entry.GetOptionalEntry() {
for i, f := range fields {
field := f.AsStructField()
f := field.Name()
if field.IsOptional() {
un.str.WriteString("?")
}
un.str.WriteString(f)
un.str.WriteString(": ")
v := entry.GetValue()
v := field.Value()
err := un.visit(v)
if err != nil {
return err
}
if i < len(entries)-1 {
if i < len(fields)-1 {
un.str.WriteString(", ")
}
}
@ -393,13 +385,14 @@ func (un *unparser) visitStructMsg(expr *exprpb.Expr) error {
return nil
}
func (un *unparser) visitStructMap(expr *exprpb.Expr) error {
m := expr.GetStructExpr()
entries := m.GetEntries()
func (un *unparser) visitStructMap(expr ast.Expr) error {
m := expr.AsMap()
entries := m.Entries()
un.str.WriteString("{")
for i, entry := range entries {
k := entry.GetMapKey()
if entry.GetOptionalEntry() {
for i, e := range entries {
entry := e.AsMapEntry()
k := entry.Key()
if entry.IsOptional() {
un.str.WriteString("?")
}
err := un.visit(k)
@ -407,7 +400,7 @@ func (un *unparser) visitStructMap(expr *exprpb.Expr) error {
return err
}
un.str.WriteString(": ")
v := entry.GetValue()
v := entry.Value()
err = un.visit(v)
if err != nil {
return err
@ -420,16 +413,15 @@ func (un *unparser) visitStructMap(expr *exprpb.Expr) error {
return nil
}
func (un *unparser) visitMaybeMacroCall(expr *exprpb.Expr) (bool, error) {
macroCalls := un.info.GetMacroCalls()
call, found := macroCalls[expr.GetId()]
func (un *unparser) visitMaybeMacroCall(expr ast.Expr) (bool, error) {
call, found := un.info.GetMacroCall(expr.ID())
if !found {
return false, nil
}
return true, un.visit(call)
}
func (un *unparser) visitMaybeNested(expr *exprpb.Expr, nested bool) error {
func (un *unparser) visitMaybeNested(expr ast.Expr, nested bool) error {
if nested {
un.str.WriteString("(")
}
@ -453,12 +445,12 @@ func isLeftRecursive(op string) bool {
// precedence of the (possible) operation represented in the input Expr.
//
// If the expr is not a Call, the result is false.
func isSamePrecedence(op string, expr *exprpb.Expr) bool {
if expr.GetCallExpr() == nil {
func isSamePrecedence(op string, expr ast.Expr) bool {
if expr.Kind() != ast.CallKind {
return false
}
c := expr.GetCallExpr()
other := c.GetFunction()
c := expr.AsCall()
other := c.FunctionName()
return operators.Precedence(op) == operators.Precedence(other)
}
@ -466,16 +458,16 @@ func isSamePrecedence(op string, expr *exprpb.Expr) bool {
// than the (possible) operation represented in the input Expr.
//
// If the expr is not a Call, the result is false.
func isLowerPrecedence(op string, expr *exprpb.Expr) bool {
c := expr.GetCallExpr()
other := c.GetFunction()
func isLowerPrecedence(op string, expr ast.Expr) bool {
c := expr.AsCall()
other := c.FunctionName()
return operators.Precedence(op) < operators.Precedence(other)
}
// Indicates whether the expr is a complex operator, i.e., a call expression
// with 2 or more arguments.
func isComplexOperator(expr *exprpb.Expr) bool {
if expr.GetCallExpr() != nil && len(expr.GetCallExpr().GetArgs()) >= 2 {
func isComplexOperator(expr ast.Expr) bool {
if expr.Kind() == ast.CallKind && len(expr.AsCall().Args()) >= 2 {
return true
}
return false
@ -484,19 +476,19 @@ func isComplexOperator(expr *exprpb.Expr) bool {
// Indicates whether it is a complex operation compared to another.
// expr is *not* considered complex if it is not a call expression or has
// less than two arguments, or if it has a higher precedence than op.
func isComplexOperatorWithRespectTo(op string, expr *exprpb.Expr) bool {
if expr.GetCallExpr() == nil || len(expr.GetCallExpr().GetArgs()) < 2 {
func isComplexOperatorWithRespectTo(op string, expr ast.Expr) bool {
if expr.Kind() != ast.CallKind || len(expr.AsCall().Args()) < 2 {
return false
}
return isLowerPrecedence(op, expr)
}
// Indicate whether this is a binary or ternary operator.
func isBinaryOrTernaryOperator(expr *exprpb.Expr) bool {
if expr.GetCallExpr() == nil || len(expr.GetCallExpr().GetArgs()) < 2 {
func isBinaryOrTernaryOperator(expr ast.Expr) bool {
if expr.Kind() != ast.CallKind || len(expr.AsCall().Args()) < 2 {
return false
}
_, isBinaryOp := operators.FindReverseBinaryOperator(expr.GetCallExpr().GetFunction())
_, isBinaryOp := operators.FindReverseBinaryOperator(expr.AsCall().FunctionName())
return isBinaryOp || isSamePrecedence(operators.Conditional, expr)
}

View File

@ -847,7 +847,7 @@ func (p *Profile) HasFileLines() bool {
// "[vdso]", [vsyscall]" and some others, see the code.
func (m *Mapping) Unsymbolizable() bool {
name := filepath.Base(m.File)
return strings.HasPrefix(name, "[") || strings.HasPrefix(name, "linux-vdso") || strings.HasPrefix(m.File, "/dev/dri/")
return strings.HasPrefix(name, "[") || strings.HasPrefix(name, "linux-vdso") || strings.HasPrefix(m.File, "/dev/dri/") || m.File == "//anon"
}
// Copy makes a fully independent copy of a profile.