rebase: update kubernetes to latest

updating the kubernetes release to the
latest in main go.mod

Signed-off-by: Madhu Rajanna <madhupr007@gmail.com>
This commit is contained in:
Madhu Rajanna
2024-08-19 10:01:33 +02:00
committed by mergify[bot]
parent 63c4c05b35
commit 5a66991bb3
2173 changed files with 98906 additions and 61334 deletions

View File

@ -14,7 +14,6 @@ go_library(
"decorators.go",
"dispatcher.go",
"evalstate.go",
"formatting.go",
"interpretable.go",
"interpreter.go",
"optimizations.go",

View File

@ -287,6 +287,9 @@ func (a *absoluteAttribute) Resolve(vars Activation) (any, error) {
// determine whether the type is unknown before returning.
obj, found := vars.ResolveName(nm)
if found {
if celErr, ok := obj.(*types.Err); ok {
return nil, celErr.Unwrap()
}
obj, isOpt, err := applyQualifiers(vars, obj, a.qualifiers)
if err != nil {
return nil, err

View File

@ -1,383 +0,0 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package interpreter
import (
"errors"
"fmt"
"strconv"
"strings"
"unicode"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)
type typeVerifier func(int64, ...ref.Type) (bool, error)
// InterpolateFormattedString checks the syntax and cardinality of any string.format calls present in the expression and reports
// any errors at compile time.
func InterpolateFormattedString(verifier typeVerifier) InterpretableDecorator {
return func(inter Interpretable) (Interpretable, error) {
call, ok := inter.(InterpretableCall)
if !ok {
return inter, nil
}
if call.OverloadID() != "string_format" {
return inter, nil
}
args := call.Args()
if len(args) != 2 {
return nil, fmt.Errorf("wrong number of arguments to string.format (expected 2, got %d)", len(args))
}
fmtStrInter, ok := args[0].(InterpretableConst)
if !ok {
return inter, nil
}
var fmtArgsInter InterpretableConstructor
fmtArgsInter, ok = args[1].(InterpretableConstructor)
if !ok {
return inter, nil
}
if fmtArgsInter.Type() != types.ListType {
// don't necessarily return an error since the list may be DynType
return inter, nil
}
formatStr := fmtStrInter.Value().Value().(string)
initVals := fmtArgsInter.InitVals()
formatCheck := &formatCheck{
args: initVals,
verifier: verifier,
}
// use a placeholder locale, since locale doesn't affect syntax
_, err := ParseFormatString(formatStr, formatCheck, formatCheck, "en_US")
if err != nil {
return nil, err
}
seenArgs := formatCheck.argsRequested
if len(initVals) > seenArgs {
return nil, fmt.Errorf("too many arguments supplied to string.format (expected %d, got %d)", seenArgs, len(initVals))
}
return inter, nil
}
}
type formatCheck struct {
args []Interpretable
argsRequested int
curArgIndex int64
enableCheckArgTypes bool
verifier typeVerifier
}
func (c *formatCheck) String(arg ref.Val, locale string) (string, error) {
valid, err := verifyString(c.args[c.curArgIndex], c.verifier)
if err != nil {
return "", err
}
if !valid {
return "", errors.New("string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps")
}
return "", nil
}
func (c *formatCheck) Decimal(arg ref.Val, locale string) (string, error) {
id := c.args[c.curArgIndex].ID()
valid, err := c.verifier(id, types.IntType, types.UintType)
if err != nil {
return "", err
}
if !valid {
return "", errors.New("integer clause can only be used on integers")
}
return "", nil
}
func (c *formatCheck) Fixed(precision *int) func(ref.Val, string) (string, error) {
return func(arg ref.Val, locale string) (string, error) {
id := c.args[c.curArgIndex].ID()
// we allow StringType since "NaN", "Infinity", and "-Infinity" are also valid values
valid, err := c.verifier(id, types.DoubleType, types.StringType)
if err != nil {
return "", err
}
if !valid {
return "", errors.New("fixed-point clause can only be used on doubles")
}
return "", nil
}
}
func (c *formatCheck) Scientific(precision *int) func(ref.Val, string) (string, error) {
return func(arg ref.Val, locale string) (string, error) {
id := c.args[c.curArgIndex].ID()
valid, err := c.verifier(id, types.DoubleType, types.StringType)
if err != nil {
return "", err
}
if !valid {
return "", errors.New("scientific clause can only be used on doubles")
}
return "", nil
}
}
func (c *formatCheck) Binary(arg ref.Val, locale string) (string, error) {
id := c.args[c.curArgIndex].ID()
valid, err := c.verifier(id, types.IntType, types.UintType, types.BoolType)
if err != nil {
return "", err
}
if !valid {
return "", errors.New("only integers and bools can be formatted as binary")
}
return "", nil
}
func (c *formatCheck) Hex(useUpper bool) func(ref.Val, string) (string, error) {
return func(arg ref.Val, locale string) (string, error) {
id := c.args[c.curArgIndex].ID()
valid, err := c.verifier(id, types.IntType, types.UintType, types.StringType, types.BytesType)
if err != nil {
return "", err
}
if !valid {
return "", errors.New("only integers, byte buffers, and strings can be formatted as hex")
}
return "", nil
}
}
func (c *formatCheck) Octal(arg ref.Val, locale string) (string, error) {
id := c.args[c.curArgIndex].ID()
valid, err := c.verifier(id, types.IntType, types.UintType)
if err != nil {
return "", err
}
if !valid {
return "", errors.New("octal clause can only be used on integers")
}
return "", nil
}
func (c *formatCheck) Arg(index int64) (ref.Val, error) {
c.argsRequested++
c.curArgIndex = index
// return a dummy value - this is immediately passed to back to us
// through one of the FormatCallback functions, so anything will do
return types.Int(0), nil
}
func (c *formatCheck) ArgSize() int64 {
return int64(len(c.args))
}
func verifyString(sub Interpretable, verifier typeVerifier) (bool, error) {
subVerified, err := verifier(sub.ID(),
types.ListType, types.MapType, types.IntType, types.UintType, types.DoubleType,
types.BoolType, types.StringType, types.TimestampType, types.BytesType, types.DurationType, types.TypeType, types.NullType)
if err != nil {
return false, err
}
if !subVerified {
return false, nil
}
con, ok := sub.(InterpretableConstructor)
if ok {
members := con.InitVals()
for _, m := range members {
// recursively verify if we're dealing with a list/map
verified, err := verifyString(m, verifier)
if err != nil {
return false, err
}
if !verified {
return false, nil
}
}
}
return true, nil
}
// FormatStringInterpolator is an interface that allows user-defined behavior
// for formatting clause implementations, as well as argument retrieval.
// Each function is expected to support the appropriate types as laid out in
// the string.format documentation, and to return an error if given an inappropriate type.
type FormatStringInterpolator interface {
// String takes a ref.Val and a string representing the current locale identifier
// and returns the Val formatted as a string, or an error if one occurred.
String(ref.Val, string) (string, error)
// Decimal takes a ref.Val and a string representing the current locale identifier
// and returns the Val formatted as a decimal integer, or an error if one occurred.
Decimal(ref.Val, string) (string, error)
// Fixed takes an int pointer representing precision (or nil if none was given) and
// returns a function operating in a similar manner to String and Decimal, taking a
// ref.Val and locale and returning the appropriate string. A closure is returned
// so precision can be set without needing an additional function call/configuration.
Fixed(*int) func(ref.Val, string) (string, error)
// Scientific functions identically to Fixed, except the string returned from the closure
// is expected to be in scientific notation.
Scientific(*int) func(ref.Val, string) (string, error)
// Binary takes a ref.Val and a string representing the current locale identifier
// and returns the Val formatted as a binary integer, or an error if one occurred.
Binary(ref.Val, string) (string, error)
// Hex takes a boolean that, if true, indicates the hex string output by the returned
// closure should use uppercase letters for A-F.
Hex(bool) func(ref.Val, string) (string, error)
// Octal takes a ref.Val and a string representing the current locale identifier and
// returns the Val formatted in octal, or an error if one occurred.
Octal(ref.Val, string) (string, error)
}
// FormatList is an interface that allows user-defined list-like datatypes to be used
// for formatting clause implementations.
type FormatList interface {
// Arg returns the ref.Val at the given index, or an error if one occurred.
Arg(int64) (ref.Val, error)
// ArgSize returns the length of the argument list.
ArgSize() int64
}
type clauseImpl func(ref.Val, string) (string, error)
// ParseFormatString formats a string according to the string.format syntax, taking the clause implementations
// from the provided FormatCallback and the args from the given FormatList.
func ParseFormatString(formatStr string, callback FormatStringInterpolator, list FormatList, locale string) (string, error) {
i := 0
argIndex := 0
var builtStr strings.Builder
for i < len(formatStr) {
if formatStr[i] == '%' {
if i+1 < len(formatStr) && formatStr[i+1] == '%' {
err := builtStr.WriteByte('%')
if err != nil {
return "", fmt.Errorf("error writing format string: %w", err)
}
i += 2
continue
} else {
argAny, err := list.Arg(int64(argIndex))
if err != nil {
return "", err
}
if i+1 >= len(formatStr) {
return "", errors.New("unexpected end of string")
}
if int64(argIndex) >= list.ArgSize() {
return "", fmt.Errorf("index %d out of range", argIndex)
}
numRead, val, refErr := parseAndFormatClause(formatStr[i:], argAny, callback, list, locale)
if refErr != nil {
return "", refErr
}
_, err = builtStr.WriteString(val)
if err != nil {
return "", fmt.Errorf("error writing format string: %w", err)
}
i += numRead
argIndex++
}
} else {
err := builtStr.WriteByte(formatStr[i])
if err != nil {
return "", fmt.Errorf("error writing format string: %w", err)
}
i++
}
}
return builtStr.String(), nil
}
// parseAndFormatClause parses the format clause at the start of the given string with val, and returns
// how many characters were consumed and the substituted string form of val, or an error if one occurred.
func parseAndFormatClause(formatStr string, val ref.Val, callback FormatStringInterpolator, list FormatList, locale string) (int, string, error) {
i := 1
read, formatter, err := parseFormattingClause(formatStr[i:], callback)
i += read
if err != nil {
return -1, "", fmt.Errorf("could not parse formatting clause: %s", err)
}
valStr, err := formatter(val, locale)
if err != nil {
return -1, "", fmt.Errorf("error during formatting: %s", err)
}
return i, valStr, nil
}
func parseFormattingClause(formatStr string, callback FormatStringInterpolator) (int, clauseImpl, error) {
i := 0
read, precision, err := parsePrecision(formatStr[i:])
i += read
if err != nil {
return -1, nil, fmt.Errorf("error while parsing precision: %w", err)
}
r := rune(formatStr[i])
i++
switch r {
case 's':
return i, callback.String, nil
case 'd':
return i, callback.Decimal, nil
case 'f':
return i, callback.Fixed(precision), nil
case 'e':
return i, callback.Scientific(precision), nil
case 'b':
return i, callback.Binary, nil
case 'x', 'X':
return i, callback.Hex(unicode.IsUpper(r)), nil
case 'o':
return i, callback.Octal, nil
default:
return -1, nil, fmt.Errorf("unrecognized formatting clause \"%c\"", r)
}
}
func parsePrecision(formatStr string) (int, *int, error) {
i := 0
if formatStr[i] != '.' {
return i, nil, nil
}
i++
var buffer strings.Builder
for {
if i >= len(formatStr) {
return -1, nil, errors.New("could not find end of precision specifier")
}
if !isASCIIDigit(rune(formatStr[i])) {
break
}
buffer.WriteByte(formatStr[i])
i++
}
precision, err := strconv.Atoi(buffer.String())
if err != nil {
return -1, nil, fmt.Errorf("error while converting precision to integer: %w", err)
}
return i, &precision, nil
}
func isASCIIDigit(r rune) bool {
return r <= unicode.MaxASCII && unicode.IsDigit(r)
}

View File

@ -125,7 +125,7 @@ func (test *evalTestOnly) Eval(ctx Activation) ref.Val {
val, err := test.Resolve(ctx)
// Return an error if the resolve step fails
if err != nil {
return types.WrapErr(err)
return types.LabelErrNode(test.id, types.WrapErr(err))
}
if optVal, isOpt := val.(*types.Optional); isOpt {
return types.Bool(optVal.HasValue())
@ -231,6 +231,7 @@ func (or *evalOr) Eval(ctx Activation) ref.Val {
} else {
err = types.MaybeNoSuchOverloadErr(val)
}
err = types.LabelErrNode(or.id, err)
}
}
}
@ -273,6 +274,7 @@ func (and *evalAnd) Eval(ctx Activation) ref.Val {
} else {
err = types.MaybeNoSuchOverloadErr(val)
}
err = types.LabelErrNode(and.id, err)
}
}
}
@ -377,7 +379,7 @@ func (zero *evalZeroArity) ID() int64 {
// Eval implements the Interpretable interface method.
func (zero *evalZeroArity) Eval(ctx Activation) ref.Val {
return zero.impl()
return types.LabelErrNode(zero.id, zero.impl())
}
// Function implements the InterpretableCall interface method.
@ -421,14 +423,14 @@ func (un *evalUnary) Eval(ctx Activation) ref.Val {
// If the implementation is bound and the argument value has the right traits required to
// invoke it, then call the implementation.
if un.impl != nil && (un.trait == 0 || (!strict && types.IsUnknownOrError(argVal)) || argVal.Type().HasTrait(un.trait)) {
return un.impl(argVal)
return types.LabelErrNode(un.id, un.impl(argVal))
}
// Otherwise, if the argument is a ReceiverType attempt to invoke the receiver method on the
// operand (arg0).
if argVal.Type().HasTrait(traits.ReceiverType) {
return argVal.(traits.Receiver).Receive(un.function, un.overload, []ref.Val{})
return types.LabelErrNode(un.id, argVal.(traits.Receiver).Receive(un.function, un.overload, []ref.Val{}))
}
return types.NewErr("no such overload: %s", un.function)
return types.NewErrWithNodeID(un.id, "no such overload: %s", un.function)
}
// Function implements the InterpretableCall interface method.
@ -479,14 +481,14 @@ func (bin *evalBinary) Eval(ctx Activation) ref.Val {
// If the implementation is bound and the argument value has the right traits required to
// invoke it, then call the implementation.
if bin.impl != nil && (bin.trait == 0 || (!strict && types.IsUnknownOrError(lVal)) || lVal.Type().HasTrait(bin.trait)) {
return bin.impl(lVal, rVal)
return types.LabelErrNode(bin.id, bin.impl(lVal, rVal))
}
// Otherwise, if the argument is a ReceiverType attempt to invoke the receiver method on the
// operand (arg0).
if lVal.Type().HasTrait(traits.ReceiverType) {
return lVal.(traits.Receiver).Receive(bin.function, bin.overload, []ref.Val{rVal})
return types.LabelErrNode(bin.id, lVal.(traits.Receiver).Receive(bin.function, bin.overload, []ref.Val{rVal}))
}
return types.NewErr("no such overload: %s", bin.function)
return types.NewErrWithNodeID(bin.id, "no such overload: %s", bin.function)
}
// Function implements the InterpretableCall interface method.
@ -545,14 +547,14 @@ func (fn *evalVarArgs) Eval(ctx Activation) ref.Val {
// invoke it, then call the implementation.
arg0 := argVals[0]
if fn.impl != nil && (fn.trait == 0 || (!strict && types.IsUnknownOrError(arg0)) || arg0.Type().HasTrait(fn.trait)) {
return fn.impl(argVals...)
return types.LabelErrNode(fn.id, fn.impl(argVals...))
}
// Otherwise, if the argument is a ReceiverType attempt to invoke the receiver method on the
// operand (arg0).
if arg0.Type().HasTrait(traits.ReceiverType) {
return arg0.(traits.Receiver).Receive(fn.function, fn.overload, argVals[1:])
return types.LabelErrNode(fn.id, arg0.(traits.Receiver).Receive(fn.function, fn.overload, argVals[1:]))
}
return types.NewErr("no such overload: %s", fn.function)
return types.NewErrWithNodeID(fn.id, "no such overload: %s %d", fn.function, fn.id)
}
// Function implements the InterpretableCall interface method.
@ -595,7 +597,7 @@ func (l *evalList) Eval(ctx Activation) ref.Val {
if l.hasOptionals && l.optionals[i] {
optVal, ok := elemVal.(*types.Optional)
if !ok {
return invalidOptionalElementInit(elemVal)
return types.LabelErrNode(l.id, invalidOptionalElementInit(elemVal))
}
if !optVal.HasValue() {
continue
@ -645,7 +647,7 @@ func (m *evalMap) Eval(ctx Activation) ref.Val {
if m.hasOptionals && m.optionals[i] {
optVal, ok := valVal.(*types.Optional)
if !ok {
return invalidOptionalEntryInit(keyVal, valVal)
return types.LabelErrNode(m.id, invalidOptionalEntryInit(keyVal, valVal))
}
if !optVal.HasValue() {
delete(entries, keyVal)
@ -705,7 +707,7 @@ func (o *evalObj) Eval(ctx Activation) ref.Val {
if o.hasOptionals && o.optionals[i] {
optVal, ok := val.(*types.Optional)
if !ok {
return invalidOptionalEntryInit(field, val)
return types.LabelErrNode(o.id, invalidOptionalEntryInit(field, val))
}
if !optVal.HasValue() {
delete(fieldVals, field)
@ -715,7 +717,7 @@ func (o *evalObj) Eval(ctx Activation) ref.Val {
}
fieldVals[field] = val
}
return o.provider.NewValue(o.typeName, fieldVals)
return types.LabelErrNode(o.id, o.provider.NewValue(o.typeName, fieldVals))
}
func (o *evalObj) InitVals() []Interpretable {
@ -921,7 +923,7 @@ func (e *evalWatchConstQual) Qualify(vars Activation, obj any) (any, error) {
out, err := e.ConstantQualifier.Qualify(vars, obj)
var val ref.Val
if err != nil {
val = types.WrapErr(err)
val = types.LabelErrNode(e.ID(), types.WrapErr(err))
} else {
val = e.adapter.NativeToValue(out)
}
@ -934,7 +936,7 @@ func (e *evalWatchConstQual) QualifyIfPresent(vars Activation, obj any, presence
out, present, err := e.ConstantQualifier.QualifyIfPresent(vars, obj, presenceOnly)
var val ref.Val
if err != nil {
val = types.WrapErr(err)
val = types.LabelErrNode(e.ID(), types.WrapErr(err))
} else if out != nil {
val = e.adapter.NativeToValue(out)
} else if presenceOnly {
@ -964,7 +966,7 @@ func (e *evalWatchAttrQual) Qualify(vars Activation, obj any) (any, error) {
out, err := e.Attribute.Qualify(vars, obj)
var val ref.Val
if err != nil {
val = types.WrapErr(err)
val = types.LabelErrNode(e.ID(), types.WrapErr(err))
} else {
val = e.adapter.NativeToValue(out)
}
@ -977,7 +979,7 @@ func (e *evalWatchAttrQual) QualifyIfPresent(vars Activation, obj any, presenceO
out, present, err := e.Attribute.QualifyIfPresent(vars, obj, presenceOnly)
var val ref.Val
if err != nil {
val = types.WrapErr(err)
val = types.LabelErrNode(e.ID(), types.WrapErr(err))
} else if out != nil {
val = e.adapter.NativeToValue(out)
} else if presenceOnly {
@ -1001,7 +1003,7 @@ func (e *evalWatchQual) Qualify(vars Activation, obj any) (any, error) {
out, err := e.Qualifier.Qualify(vars, obj)
var val ref.Val
if err != nil {
val = types.WrapErr(err)
val = types.LabelErrNode(e.ID(), types.WrapErr(err))
} else {
val = e.adapter.NativeToValue(out)
}
@ -1014,7 +1016,7 @@ func (e *evalWatchQual) QualifyIfPresent(vars Activation, obj any, presenceOnly
out, present, err := e.Qualifier.QualifyIfPresent(vars, obj, presenceOnly)
var val ref.Val
if err != nil {
val = types.WrapErr(err)
val = types.LabelErrNode(e.ID(), types.WrapErr(err))
} else if out != nil {
val = e.adapter.NativeToValue(out)
} else if presenceOnly {
@ -1157,12 +1159,12 @@ func (cond *evalExhaustiveConditional) Eval(ctx Activation) ref.Val {
}
if cBool {
if tErr != nil {
return types.WrapErr(tErr)
return types.LabelErrNode(cond.id, types.WrapErr(tErr))
}
return cond.adapter.NativeToValue(tVal)
}
if fErr != nil {
return types.WrapErr(fErr)
return types.LabelErrNode(cond.id, types.WrapErr(fErr))
}
return cond.adapter.NativeToValue(fVal)
}
@ -1202,7 +1204,7 @@ func (a *evalAttr) Adapter() types.Adapter {
func (a *evalAttr) Eval(ctx Activation) ref.Val {
v, err := a.attr.Resolve(ctx)
if err != nil {
return types.WrapErr(err)
return types.LabelErrNode(a.ID(), types.WrapErr(err))
}
return a.adapter.NativeToValue(v)
}

View File

@ -22,19 +22,13 @@ import (
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// Interpreter generates a new Interpretable from a checked or unchecked expression.
type Interpreter interface {
// NewInterpretable creates an Interpretable from a checked expression and an
// optional list of InterpretableDecorator values.
NewInterpretable(checked *ast.CheckedAST, decorators ...InterpretableDecorator) (Interpretable, error)
// NewUncheckedInterpretable returns an Interpretable from a parsed expression
// and an optional list of InterpretableDecorator values.
NewUncheckedInterpretable(expr *exprpb.Expr, decorators ...InterpretableDecorator) (Interpretable, error)
NewInterpretable(exprAST *ast.AST, decorators ...InterpretableDecorator) (Interpretable, error)
}
// EvalObserver is a functional interface that accepts an expression id and an observed value.
@ -177,7 +171,7 @@ func NewInterpreter(dispatcher Dispatcher,
// NewIntepretable implements the Interpreter interface method.
func (i *exprInterpreter) NewInterpretable(
checked *ast.CheckedAST,
checked *ast.AST,
decorators ...InterpretableDecorator) (Interpretable, error) {
p := newPlanner(
i.dispatcher,
@ -187,19 +181,5 @@ func (i *exprInterpreter) NewInterpretable(
i.container,
checked,
decorators...)
return p.Plan(checked.Expr)
}
// NewUncheckedIntepretable implements the Interpreter interface method.
func (i *exprInterpreter) NewUncheckedInterpretable(
expr *exprpb.Expr,
decorators ...InterpretableDecorator) (Interpretable, error) {
p := newUncheckedPlanner(
i.dispatcher,
i.provider,
i.adapter,
i.attrFactory,
i.container,
decorators...)
return p.Plan(expr)
return p.Plan(checked.Expr())
}

View File

@ -23,15 +23,12 @@ import (
"github.com/google/cel-go/common/functions"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// interpretablePlanner creates an Interpretable evaluation plan from a proto Expr value.
type interpretablePlanner interface {
// Plan generates an Interpretable value (or error) from the input proto Expr.
Plan(expr *exprpb.Expr) (Interpretable, error)
Plan(expr ast.Expr) (Interpretable, error)
}
// newPlanner creates an interpretablePlanner which references a Dispatcher, TypeProvider,
@ -43,7 +40,7 @@ func newPlanner(disp Dispatcher,
adapter types.Adapter,
attrFactory AttributeFactory,
cont *containers.Container,
checked *ast.CheckedAST,
exprAST *ast.AST,
decorators ...InterpretableDecorator) interpretablePlanner {
return &planner{
disp: disp,
@ -51,29 +48,8 @@ func newPlanner(disp Dispatcher,
adapter: adapter,
attrFactory: attrFactory,
container: cont,
refMap: checked.ReferenceMap,
typeMap: checked.TypeMap,
decorators: decorators,
}
}
// newUncheckedPlanner creates an interpretablePlanner which references a Dispatcher, TypeProvider,
// TypeAdapter, and Container to resolve functions and types at plan time. Namespaces present in
// Select expressions are resolved lazily at evaluation time.
func newUncheckedPlanner(disp Dispatcher,
provider types.Provider,
adapter types.Adapter,
attrFactory AttributeFactory,
cont *containers.Container,
decorators ...InterpretableDecorator) interpretablePlanner {
return &planner{
disp: disp,
provider: provider,
adapter: adapter,
attrFactory: attrFactory,
container: cont,
refMap: make(map[int64]*ast.ReferenceInfo),
typeMap: make(map[int64]*types.Type),
refMap: exprAST.ReferenceMap(),
typeMap: exprAST.TypeMap(),
decorators: decorators,
}
}
@ -95,22 +71,24 @@ type planner struct {
// useful for layering functionality into the evaluation that is not natively understood by CEL,
// such as state-tracking, expression re-write, and possibly efficient thread-safe memoization of
// repeated expressions.
func (p *planner) Plan(expr *exprpb.Expr) (Interpretable, error) {
switch expr.GetExprKind().(type) {
case *exprpb.Expr_CallExpr:
func (p *planner) Plan(expr ast.Expr) (Interpretable, error) {
switch expr.Kind() {
case ast.CallKind:
return p.decorate(p.planCall(expr))
case *exprpb.Expr_IdentExpr:
case ast.IdentKind:
return p.decorate(p.planIdent(expr))
case *exprpb.Expr_SelectExpr:
return p.decorate(p.planSelect(expr))
case *exprpb.Expr_ListExpr:
return p.decorate(p.planCreateList(expr))
case *exprpb.Expr_StructExpr:
return p.decorate(p.planCreateStruct(expr))
case *exprpb.Expr_ComprehensionExpr:
return p.decorate(p.planComprehension(expr))
case *exprpb.Expr_ConstExpr:
case ast.LiteralKind:
return p.decorate(p.planConst(expr))
case ast.SelectKind:
return p.decorate(p.planSelect(expr))
case ast.ListKind:
return p.decorate(p.planCreateList(expr))
case ast.MapKind:
return p.decorate(p.planCreateMap(expr))
case ast.StructKind:
return p.decorate(p.planCreateStruct(expr))
case ast.ComprehensionKind:
return p.decorate(p.planComprehension(expr))
}
return nil, fmt.Errorf("unsupported expr: %v", expr)
}
@ -132,16 +110,16 @@ func (p *planner) decorate(i Interpretable, err error) (Interpretable, error) {
}
// planIdent creates an Interpretable that resolves an identifier from an Activation.
func (p *planner) planIdent(expr *exprpb.Expr) (Interpretable, error) {
func (p *planner) planIdent(expr ast.Expr) (Interpretable, error) {
// Establish whether the identifier is in the reference map.
if identRef, found := p.refMap[expr.GetId()]; found {
return p.planCheckedIdent(expr.GetId(), identRef)
if identRef, found := p.refMap[expr.ID()]; found {
return p.planCheckedIdent(expr.ID(), identRef)
}
// Create the possible attribute list for the unresolved reference.
ident := expr.GetIdentExpr()
ident := expr.AsIdent()
return &evalAttr{
adapter: p.adapter,
attr: p.attrFactory.MaybeAttribute(expr.GetId(), ident.Name),
attr: p.attrFactory.MaybeAttribute(expr.ID(), ident),
}, nil
}
@ -174,20 +152,20 @@ func (p *planner) planCheckedIdent(id int64, identRef *ast.ReferenceInfo) (Inter
// a) selects a field from a map or proto.
// b) creates a field presence test for a select within a has() macro.
// c) resolves the select expression to a namespaced identifier.
func (p *planner) planSelect(expr *exprpb.Expr) (Interpretable, error) {
func (p *planner) planSelect(expr ast.Expr) (Interpretable, error) {
// If the Select id appears in the reference map from the CheckedExpr proto then it is either
// a namespaced identifier or enum value.
if identRef, found := p.refMap[expr.GetId()]; found {
return p.planCheckedIdent(expr.GetId(), identRef)
if identRef, found := p.refMap[expr.ID()]; found {
return p.planCheckedIdent(expr.ID(), identRef)
}
sel := expr.GetSelectExpr()
sel := expr.AsSelect()
// Plan the operand evaluation.
op, err := p.Plan(sel.GetOperand())
op, err := p.Plan(sel.Operand())
if err != nil {
return nil, err
}
opType := p.typeMap[sel.GetOperand().GetId()]
opType := p.typeMap[sel.Operand().ID()]
// If the Select was marked TestOnly, this is a presence test.
//
@ -211,14 +189,14 @@ func (p *planner) planSelect(expr *exprpb.Expr) (Interpretable, error) {
}
// Build a qualifier for the attribute.
qual, err := p.attrFactory.NewQualifier(opType, expr.GetId(), sel.GetField(), false)
qual, err := p.attrFactory.NewQualifier(opType, expr.ID(), sel.FieldName(), false)
if err != nil {
return nil, err
}
// Modify the attribute to be test-only.
if sel.GetTestOnly() {
if sel.IsTestOnly() {
attr = &evalTestOnly{
id: expr.GetId(),
id: expr.ID(),
InterpretableAttribute: attr,
}
}
@ -230,10 +208,10 @@ func (p *planner) planSelect(expr *exprpb.Expr) (Interpretable, error) {
// planCall creates a callable Interpretable while specializing for common functions and invocation
// patterns. Specifically, conditional operators &&, ||, ?:, and (in)equality functions result in
// optimized Interpretable values.
func (p *planner) planCall(expr *exprpb.Expr) (Interpretable, error) {
call := expr.GetCallExpr()
func (p *planner) planCall(expr ast.Expr) (Interpretable, error) {
call := expr.AsCall()
target, fnName, oName := p.resolveFunction(expr)
argCount := len(call.GetArgs())
argCount := len(call.Args())
var offset int
if target != nil {
argCount++
@ -248,7 +226,7 @@ func (p *planner) planCall(expr *exprpb.Expr) (Interpretable, error) {
}
args[0] = arg
}
for i, argExpr := range call.GetArgs() {
for i, argExpr := range call.Args() {
arg, err := p.Plan(argExpr)
if err != nil {
return nil, err
@ -307,7 +285,7 @@ func (p *planner) planCall(expr *exprpb.Expr) (Interpretable, error) {
}
// planCallZero generates a zero-arity callable Interpretable.
func (p *planner) planCallZero(expr *exprpb.Expr,
func (p *planner) planCallZero(expr ast.Expr,
function string,
overload string,
impl *functions.Overload) (Interpretable, error) {
@ -315,7 +293,7 @@ func (p *planner) planCallZero(expr *exprpb.Expr,
return nil, fmt.Errorf("no such overload: %s()", function)
}
return &evalZeroArity{
id: expr.GetId(),
id: expr.ID(),
function: function,
overload: overload,
impl: impl.Function,
@ -323,7 +301,7 @@ func (p *planner) planCallZero(expr *exprpb.Expr,
}
// planCallUnary generates a unary callable Interpretable.
func (p *planner) planCallUnary(expr *exprpb.Expr,
func (p *planner) planCallUnary(expr ast.Expr,
function string,
overload string,
impl *functions.Overload,
@ -340,7 +318,7 @@ func (p *planner) planCallUnary(expr *exprpb.Expr,
nonStrict = impl.NonStrict
}
return &evalUnary{
id: expr.GetId(),
id: expr.ID(),
function: function,
overload: overload,
arg: args[0],
@ -351,7 +329,7 @@ func (p *planner) planCallUnary(expr *exprpb.Expr,
}
// planCallBinary generates a binary callable Interpretable.
func (p *planner) planCallBinary(expr *exprpb.Expr,
func (p *planner) planCallBinary(expr ast.Expr,
function string,
overload string,
impl *functions.Overload,
@ -368,7 +346,7 @@ func (p *planner) planCallBinary(expr *exprpb.Expr,
nonStrict = impl.NonStrict
}
return &evalBinary{
id: expr.GetId(),
id: expr.ID(),
function: function,
overload: overload,
lhs: args[0],
@ -380,7 +358,7 @@ func (p *planner) planCallBinary(expr *exprpb.Expr,
}
// planCallVarArgs generates a variable argument callable Interpretable.
func (p *planner) planCallVarArgs(expr *exprpb.Expr,
func (p *planner) planCallVarArgs(expr ast.Expr,
function string,
overload string,
impl *functions.Overload,
@ -397,7 +375,7 @@ func (p *planner) planCallVarArgs(expr *exprpb.Expr,
nonStrict = impl.NonStrict
}
return &evalVarArgs{
id: expr.GetId(),
id: expr.ID(),
function: function,
overload: overload,
args: args,
@ -408,41 +386,41 @@ func (p *planner) planCallVarArgs(expr *exprpb.Expr,
}
// planCallEqual generates an equals (==) Interpretable.
func (p *planner) planCallEqual(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) {
func (p *planner) planCallEqual(expr ast.Expr, args []Interpretable) (Interpretable, error) {
return &evalEq{
id: expr.GetId(),
id: expr.ID(),
lhs: args[0],
rhs: args[1],
}, nil
}
// planCallNotEqual generates a not equals (!=) Interpretable.
func (p *planner) planCallNotEqual(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) {
func (p *planner) planCallNotEqual(expr ast.Expr, args []Interpretable) (Interpretable, error) {
return &evalNe{
id: expr.GetId(),
id: expr.ID(),
lhs: args[0],
rhs: args[1],
}, nil
}
// planCallLogicalAnd generates a logical and (&&) Interpretable.
func (p *planner) planCallLogicalAnd(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) {
func (p *planner) planCallLogicalAnd(expr ast.Expr, args []Interpretable) (Interpretable, error) {
return &evalAnd{
id: expr.GetId(),
id: expr.ID(),
terms: args,
}, nil
}
// planCallLogicalOr generates a logical or (||) Interpretable.
func (p *planner) planCallLogicalOr(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) {
func (p *planner) planCallLogicalOr(expr ast.Expr, args []Interpretable) (Interpretable, error) {
return &evalOr{
id: expr.GetId(),
id: expr.ID(),
terms: args,
}, nil
}
// planCallConditional generates a conditional / ternary (c ? t : f) Interpretable.
func (p *planner) planCallConditional(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) {
func (p *planner) planCallConditional(expr ast.Expr, args []Interpretable) (Interpretable, error) {
cond := args[0]
t := args[1]
var tAttr Attribute
@ -464,13 +442,13 @@ func (p *planner) planCallConditional(expr *exprpb.Expr, args []Interpretable) (
return &evalAttr{
adapter: p.adapter,
attr: p.attrFactory.ConditionalAttribute(expr.GetId(), cond, tAttr, fAttr),
attr: p.attrFactory.ConditionalAttribute(expr.ID(), cond, tAttr, fAttr),
}, nil
}
// planCallIndex either extends an attribute with the argument to the index operation, or creates
// a relative attribute based on the return of a function call or operation.
func (p *planner) planCallIndex(expr *exprpb.Expr, args []Interpretable, optional bool) (Interpretable, error) {
func (p *planner) planCallIndex(expr ast.Expr, args []Interpretable, optional bool) (Interpretable, error) {
op := args[0]
ind := args[1]
opType := p.typeMap[op.ID()]
@ -489,11 +467,11 @@ func (p *planner) planCallIndex(expr *exprpb.Expr, args []Interpretable, optiona
var qual Qualifier
switch ind := ind.(type) {
case InterpretableConst:
qual, err = p.attrFactory.NewQualifier(opType, expr.GetId(), ind.Value(), optional)
qual, err = p.attrFactory.NewQualifier(opType, expr.ID(), ind.Value(), optional)
case InterpretableAttribute:
qual, err = p.attrFactory.NewQualifier(opType, expr.GetId(), ind, optional)
qual, err = p.attrFactory.NewQualifier(opType, expr.ID(), ind, optional)
default:
qual, err = p.relativeAttr(expr.GetId(), ind, optional)
qual, err = p.relativeAttr(expr.ID(), ind, optional)
}
if err != nil {
return nil, err
@ -505,10 +483,10 @@ func (p *planner) planCallIndex(expr *exprpb.Expr, args []Interpretable, optiona
}
// planCreateList generates a list construction Interpretable.
func (p *planner) planCreateList(expr *exprpb.Expr) (Interpretable, error) {
list := expr.GetListExpr()
optionalIndices := list.GetOptionalIndices()
elements := list.GetElements()
func (p *planner) planCreateList(expr ast.Expr) (Interpretable, error) {
list := expr.AsList()
optionalIndices := list.OptionalIndices()
elements := list.Elements()
optionals := make([]bool, len(elements))
for _, index := range optionalIndices {
if index < 0 || index >= int32(len(elements)) {
@ -525,7 +503,7 @@ func (p *planner) planCreateList(expr *exprpb.Expr) (Interpretable, error) {
elems[i] = elemVal
}
return &evalList{
id: expr.GetId(),
id: expr.ID(),
elems: elems,
optionals: optionals,
hasOptionals: len(optionals) != 0,
@ -534,31 +512,29 @@ func (p *planner) planCreateList(expr *exprpb.Expr) (Interpretable, error) {
}
// planCreateStruct generates a map or object construction Interpretable.
func (p *planner) planCreateStruct(expr *exprpb.Expr) (Interpretable, error) {
str := expr.GetStructExpr()
if len(str.MessageName) != 0 {
return p.planCreateObj(expr)
}
entries := str.GetEntries()
func (p *planner) planCreateMap(expr ast.Expr) (Interpretable, error) {
m := expr.AsMap()
entries := m.Entries()
optionals := make([]bool, len(entries))
keys := make([]Interpretable, len(entries))
vals := make([]Interpretable, len(entries))
for i, entry := range entries {
keyVal, err := p.Plan(entry.GetMapKey())
for i, e := range entries {
entry := e.AsMapEntry()
keyVal, err := p.Plan(entry.Key())
if err != nil {
return nil, err
}
keys[i] = keyVal
valVal, err := p.Plan(entry.GetValue())
valVal, err := p.Plan(entry.Value())
if err != nil {
return nil, err
}
vals[i] = valVal
optionals[i] = entry.GetOptionalEntry()
optionals[i] = entry.IsOptional()
}
return &evalMap{
id: expr.GetId(),
id: expr.ID(),
keys: keys,
vals: vals,
optionals: optionals,
@ -568,27 +544,28 @@ func (p *planner) planCreateStruct(expr *exprpb.Expr) (Interpretable, error) {
}
// planCreateObj generates an object construction Interpretable.
func (p *planner) planCreateObj(expr *exprpb.Expr) (Interpretable, error) {
obj := expr.GetStructExpr()
typeName, defined := p.resolveTypeName(obj.GetMessageName())
func (p *planner) planCreateStruct(expr ast.Expr) (Interpretable, error) {
obj := expr.AsStruct()
typeName, defined := p.resolveTypeName(obj.TypeName())
if !defined {
return nil, fmt.Errorf("unknown type: %s", obj.GetMessageName())
return nil, fmt.Errorf("unknown type: %s", obj.TypeName())
}
entries := obj.GetEntries()
optionals := make([]bool, len(entries))
fields := make([]string, len(entries))
vals := make([]Interpretable, len(entries))
for i, entry := range entries {
fields[i] = entry.GetFieldKey()
val, err := p.Plan(entry.GetValue())
objFields := obj.Fields()
optionals := make([]bool, len(objFields))
fields := make([]string, len(objFields))
vals := make([]Interpretable, len(objFields))
for i, f := range objFields {
field := f.AsStructField()
fields[i] = field.Name()
val, err := p.Plan(field.Value())
if err != nil {
return nil, err
}
vals[i] = val
optionals[i] = entry.GetOptionalEntry()
optionals[i] = field.IsOptional()
}
return &evalObj{
id: expr.GetId(),
id: expr.ID(),
typeName: typeName,
fields: fields,
vals: vals,
@ -599,33 +576,33 @@ func (p *planner) planCreateObj(expr *exprpb.Expr) (Interpretable, error) {
}
// planComprehension generates an Interpretable fold operation.
func (p *planner) planComprehension(expr *exprpb.Expr) (Interpretable, error) {
fold := expr.GetComprehensionExpr()
accu, err := p.Plan(fold.GetAccuInit())
func (p *planner) planComprehension(expr ast.Expr) (Interpretable, error) {
fold := expr.AsComprehension()
accu, err := p.Plan(fold.AccuInit())
if err != nil {
return nil, err
}
iterRange, err := p.Plan(fold.GetIterRange())
iterRange, err := p.Plan(fold.IterRange())
if err != nil {
return nil, err
}
cond, err := p.Plan(fold.GetLoopCondition())
cond, err := p.Plan(fold.LoopCondition())
if err != nil {
return nil, err
}
step, err := p.Plan(fold.GetLoopStep())
step, err := p.Plan(fold.LoopStep())
if err != nil {
return nil, err
}
result, err := p.Plan(fold.GetResult())
result, err := p.Plan(fold.Result())
if err != nil {
return nil, err
}
return &evalFold{
id: expr.GetId(),
accuVar: fold.AccuVar,
id: expr.ID(),
accuVar: fold.AccuVar(),
accu: accu,
iterVar: fold.IterVar,
iterVar: fold.IterVar(),
iterRange: iterRange,
cond: cond,
step: step,
@ -635,37 +612,8 @@ func (p *planner) planComprehension(expr *exprpb.Expr) (Interpretable, error) {
}
// planConst generates a constant valued Interpretable.
func (p *planner) planConst(expr *exprpb.Expr) (Interpretable, error) {
val, err := p.constValue(expr.GetConstExpr())
if err != nil {
return nil, err
}
return NewConstValue(expr.GetId(), val), nil
}
// constValue converts a proto Constant value to a ref.Val.
func (p *planner) constValue(c *exprpb.Constant) (ref.Val, error) {
switch c.GetConstantKind().(type) {
case *exprpb.Constant_BoolValue:
return p.adapter.NativeToValue(c.GetBoolValue()), nil
case *exprpb.Constant_BytesValue:
return p.adapter.NativeToValue(c.GetBytesValue()), nil
case *exprpb.Constant_DoubleValue:
return p.adapter.NativeToValue(c.GetDoubleValue()), nil
case *exprpb.Constant_DurationValue:
return p.adapter.NativeToValue(c.GetDurationValue().AsDuration()), nil
case *exprpb.Constant_Int64Value:
return p.adapter.NativeToValue(c.GetInt64Value()), nil
case *exprpb.Constant_NullValue:
return p.adapter.NativeToValue(c.GetNullValue()), nil
case *exprpb.Constant_StringValue:
return p.adapter.NativeToValue(c.GetStringValue()), nil
case *exprpb.Constant_TimestampValue:
return p.adapter.NativeToValue(c.GetTimestampValue().AsTime()), nil
case *exprpb.Constant_Uint64Value:
return p.adapter.NativeToValue(c.GetUint64Value()), nil
}
return nil, fmt.Errorf("unknown constant type: %v", c)
func (p *planner) planConst(expr ast.Expr) (Interpretable, error) {
return NewConstValue(expr.ID(), expr.AsLiteral()), nil
}
// resolveTypeName takes a qualified string constructed at parse time, applies the proto
@ -687,17 +635,20 @@ func (p *planner) resolveTypeName(typeName string) (string, bool) {
// - The target expression may only consist of ident and select expressions.
// - The function is declared in the environment using its fully-qualified name.
// - The fully-qualified function name matches the string serialized target value.
func (p *planner) resolveFunction(expr *exprpb.Expr) (*exprpb.Expr, string, string) {
func (p *planner) resolveFunction(expr ast.Expr) (ast.Expr, string, string) {
// Note: similar logic exists within the `checker/checker.go`. If making changes here
// please consider the impact on checker.go and consolidate implementations or mirror code
// as appropriate.
call := expr.GetCallExpr()
target := call.GetTarget()
fnName := call.GetFunction()
call := expr.AsCall()
var target ast.Expr = nil
if call.IsMemberFunction() {
target = call.Target()
}
fnName := call.FunctionName()
// Checked expressions always have a reference map entry, and _should_ have the fully qualified
// function name as the fnName value.
oRef, hasOverload := p.refMap[expr.GetId()]
oRef, hasOverload := p.refMap[expr.ID()]
if hasOverload {
if len(oRef.OverloadIDs) == 1 {
return target, fnName, oRef.OverloadIDs[0]
@ -771,16 +722,30 @@ func (p *planner) relativeAttr(id int64, eval Interpretable, opt bool) (Interpre
// toQualifiedName converts an expression AST into a qualified name if possible, with a boolean
// 'found' value that indicates if the conversion is successful.
func (p *planner) toQualifiedName(operand *exprpb.Expr) (string, bool) {
func (p *planner) toQualifiedName(operand ast.Expr) (string, bool) {
// If the checker identified the expression as an attribute by the type-checker, then it can't
// possibly be part of qualified name in a namespace.
_, isAttr := p.refMap[operand.GetId()]
_, isAttr := p.refMap[operand.ID()]
if isAttr {
return "", false
}
// Since functions cannot be both namespaced and receiver functions, if the operand is not an
// qualified variable name, return the (possibly) qualified name given the expressions.
return containers.ToQualifiedName(operand)
switch operand.Kind() {
case ast.IdentKind:
id := operand.AsIdent()
return id, true
case ast.SelectKind:
sel := operand.AsSelect()
// Test only expressions are not valid as qualified names.
if sel.IsTestOnly() {
return "", false
}
if qual, found := p.toQualifiedName(sel.Operand()); found {
return qual + "." + sel.FieldName(), true
}
}
return "", false
}
func stripLeadingDot(name string) string {

View File

@ -15,19 +15,18 @@
package interpreter
import (
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
structpb "google.golang.org/protobuf/types/known/structpb"
)
type astPruner struct {
expr *exprpb.Expr
macroCalls map[int64]*exprpb.Expr
ast.ExprFactory
expr ast.Expr
macroCalls map[int64]ast.Expr
state EvalState
nextExprID int64
}
@ -67,84 +66,44 @@ type astPruner struct {
// compiled and constant folded expressions, but is not willing to constant
// fold(and thus cache results of) some external calls, then they can prepare
// the overloads accordingly.
func PruneAst(expr *exprpb.Expr, macroCalls map[int64]*exprpb.Expr, state EvalState) *exprpb.ParsedExpr {
func PruneAst(expr ast.Expr, macroCalls map[int64]ast.Expr, state EvalState) *ast.AST {
pruneState := NewEvalState()
for _, id := range state.IDs() {
v, _ := state.Value(id)
pruneState.SetValue(id, v)
}
pruner := &astPruner{
expr: expr,
macroCalls: macroCalls,
state: pruneState,
nextExprID: getMaxID(expr)}
ExprFactory: ast.NewExprFactory(),
expr: expr,
macroCalls: macroCalls,
state: pruneState,
nextExprID: getMaxID(expr)}
newExpr, _ := pruner.maybePrune(expr)
return &exprpb.ParsedExpr{
Expr: newExpr,
SourceInfo: &exprpb.SourceInfo{MacroCalls: pruner.macroCalls},
newInfo := ast.NewSourceInfo(nil)
for id, call := range pruner.macroCalls {
newInfo.SetMacroCall(id, call)
}
return ast.NewAST(newExpr, newInfo)
}
func (p *astPruner) createLiteral(id int64, val *exprpb.Constant) *exprpb.Expr {
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_ConstExpr{
ConstExpr: val,
},
}
}
func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, bool) {
func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (ast.Expr, bool) {
switch v := val.(type) {
case types.Bool:
case types.Bool, types.Bytes, types.Double, types.Int, types.Null, types.String, types.Uint:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: bool(v)}}), true
case types.Bytes:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: []byte(v)}}), true
case types.Double:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: float64(v)}}), true
return p.NewLiteral(id, val), true
case types.Duration:
p.state.SetValue(id, val)
durationString := string(v.ConvertToType(types.StringType).(types.String))
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_CallExpr{
CallExpr: &exprpb.Expr_Call{
Function: overloads.TypeConvertDuration,
Args: []*exprpb.Expr{
p.createLiteral(p.nextID(),
&exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: durationString}}),
},
},
},
}, true
case types.Int:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: int64(v)}}), true
case types.Uint:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: uint64(v)}}), true
case types.String:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: string(v)}}), true
case types.Null:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: v.Value().(structpb.NullValue)}}), true
durationString := v.ConvertToType(types.StringType).(types.String)
return p.NewCall(id, overloads.TypeConvertDuration, p.NewLiteral(p.nextID(), durationString)), true
case types.Timestamp:
timestampString := v.ConvertToType(types.StringType).(types.String)
return p.NewCall(id, overloads.TypeConvertTimestamp, p.NewLiteral(p.nextID(), timestampString)), true
}
// Attempt to build a list literal.
if list, isList := val.(traits.Lister); isList {
sz := list.Size().(types.Int)
elemExprs := make([]*exprpb.Expr, sz)
elemExprs := make([]ast.Expr, sz)
for i := types.Int(0); i < sz; i++ {
elem := list.Get(i)
if types.IsUnknownOrError(elem) {
@ -157,20 +116,13 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo
elemExprs[i] = elemExpr
}
p.state.SetValue(id, val)
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_ListExpr{
ListExpr: &exprpb.Expr_CreateList{
Elements: elemExprs,
},
},
}, true
return p.NewList(id, elemExprs, []int32{}), true
}
// Create a map literal if possible.
if mp, isMap := val.(traits.Mapper); isMap {
it := mp.Iterator()
entries := make([]*exprpb.Expr_CreateStruct_Entry, mp.Size().(types.Int))
entries := make([]ast.EntryExpr, mp.Size().(types.Int))
i := 0
for it.HasNext() != types.False {
key := it.Next()
@ -186,25 +138,12 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo
if !ok {
return nil, false
}
entry := &exprpb.Expr_CreateStruct_Entry{
Id: p.nextID(),
KeyKind: &exprpb.Expr_CreateStruct_Entry_MapKey{
MapKey: keyExpr,
},
Value: valExpr,
}
entry := p.NewMapEntry(p.nextID(), keyExpr, valExpr, false)
entries[i] = entry
i++
}
p.state.SetValue(id, val)
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_StructExpr{
StructExpr: &exprpb.Expr_CreateStruct{
Entries: entries,
},
},
}, true
return p.NewMap(id, entries), true
}
// TODO(issues/377) To construct message literals, the type provider will need to support
@ -212,215 +151,206 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo
return nil, false
}
func (p *astPruner) maybePruneOptional(elem *exprpb.Expr) (*exprpb.Expr, bool) {
elemVal, found := p.value(elem.GetId())
func (p *astPruner) maybePruneOptional(elem ast.Expr) (ast.Expr, bool) {
elemVal, found := p.value(elem.ID())
if found && elemVal.Type() == types.OptionalType {
opt := elemVal.(*types.Optional)
if !opt.HasValue() {
return nil, true
}
if newElem, pruned := p.maybeCreateLiteral(elem.GetId(), opt.GetValue()); pruned {
if newElem, pruned := p.maybeCreateLiteral(elem.ID(), opt.GetValue()); pruned {
return newElem, true
}
}
return elem, false
}
func (p *astPruner) maybePruneIn(node *exprpb.Expr) (*exprpb.Expr, bool) {
func (p *astPruner) maybePruneIn(node ast.Expr) (ast.Expr, bool) {
// elem in list
call := node.GetCallExpr()
val, exists := p.maybeValue(call.GetArgs()[1].GetId())
call := node.AsCall()
val, exists := p.maybeValue(call.Args()[1].ID())
if !exists {
return nil, false
}
if sz, ok := val.(traits.Sizer); ok && sz.Size() == types.IntZero {
return p.maybeCreateLiteral(node.GetId(), types.False)
return p.maybeCreateLiteral(node.ID(), types.False)
}
return nil, false
}
func (p *astPruner) maybePruneLogicalNot(node *exprpb.Expr) (*exprpb.Expr, bool) {
call := node.GetCallExpr()
arg := call.GetArgs()[0]
val, exists := p.maybeValue(arg.GetId())
func (p *astPruner) maybePruneLogicalNot(node ast.Expr) (ast.Expr, bool) {
call := node.AsCall()
arg := call.Args()[0]
val, exists := p.maybeValue(arg.ID())
if !exists {
return nil, false
}
if b, ok := val.(types.Bool); ok {
return p.maybeCreateLiteral(node.GetId(), !b)
return p.maybeCreateLiteral(node.ID(), !b)
}
return nil, false
}
func (p *astPruner) maybePruneOr(node *exprpb.Expr) (*exprpb.Expr, bool) {
call := node.GetCallExpr()
func (p *astPruner) maybePruneOr(node ast.Expr) (ast.Expr, bool) {
call := node.AsCall()
// We know result is unknown, so we have at least one unknown arg
// and if one side is a known value, we know we can ignore it.
if v, exists := p.maybeValue(call.GetArgs()[0].GetId()); exists {
if v, exists := p.maybeValue(call.Args()[0].ID()); exists {
if v == types.True {
return p.maybeCreateLiteral(node.GetId(), types.True)
return p.maybeCreateLiteral(node.ID(), types.True)
}
return call.GetArgs()[1], true
return call.Args()[1], true
}
if v, exists := p.maybeValue(call.GetArgs()[1].GetId()); exists {
if v, exists := p.maybeValue(call.Args()[1].ID()); exists {
if v == types.True {
return p.maybeCreateLiteral(node.GetId(), types.True)
return p.maybeCreateLiteral(node.ID(), types.True)
}
return call.GetArgs()[0], true
return call.Args()[0], true
}
return nil, false
}
func (p *astPruner) maybePruneAnd(node *exprpb.Expr) (*exprpb.Expr, bool) {
call := node.GetCallExpr()
func (p *astPruner) maybePruneAnd(node ast.Expr) (ast.Expr, bool) {
call := node.AsCall()
// We know result is unknown, so we have at least one unknown arg
// and if one side is a known value, we know we can ignore it.
if v, exists := p.maybeValue(call.GetArgs()[0].GetId()); exists {
if v, exists := p.maybeValue(call.Args()[0].ID()); exists {
if v == types.False {
return p.maybeCreateLiteral(node.GetId(), types.False)
return p.maybeCreateLiteral(node.ID(), types.False)
}
return call.GetArgs()[1], true
return call.Args()[1], true
}
if v, exists := p.maybeValue(call.GetArgs()[1].GetId()); exists {
if v, exists := p.maybeValue(call.Args()[1].ID()); exists {
if v == types.False {
return p.maybeCreateLiteral(node.GetId(), types.False)
return p.maybeCreateLiteral(node.ID(), types.False)
}
return call.GetArgs()[0], true
return call.Args()[0], true
}
return nil, false
}
func (p *astPruner) maybePruneConditional(node *exprpb.Expr) (*exprpb.Expr, bool) {
call := node.GetCallExpr()
cond, exists := p.maybeValue(call.GetArgs()[0].GetId())
func (p *astPruner) maybePruneConditional(node ast.Expr) (ast.Expr, bool) {
call := node.AsCall()
cond, exists := p.maybeValue(call.Args()[0].ID())
if !exists {
return nil, false
}
if cond.Value().(bool) {
return call.GetArgs()[1], true
return call.Args()[1], true
}
return call.GetArgs()[2], true
return call.Args()[2], true
}
func (p *astPruner) maybePruneFunction(node *exprpb.Expr) (*exprpb.Expr, bool) {
if _, exists := p.value(node.GetId()); !exists {
func (p *astPruner) maybePruneFunction(node ast.Expr) (ast.Expr, bool) {
if _, exists := p.value(node.ID()); !exists {
return nil, false
}
call := node.GetCallExpr()
if call.Function == operators.LogicalOr {
call := node.AsCall()
if call.FunctionName() == operators.LogicalOr {
return p.maybePruneOr(node)
}
if call.Function == operators.LogicalAnd {
if call.FunctionName() == operators.LogicalAnd {
return p.maybePruneAnd(node)
}
if call.Function == operators.Conditional {
if call.FunctionName() == operators.Conditional {
return p.maybePruneConditional(node)
}
if call.Function == operators.In {
if call.FunctionName() == operators.In {
return p.maybePruneIn(node)
}
if call.Function == operators.LogicalNot {
if call.FunctionName() == operators.LogicalNot {
return p.maybePruneLogicalNot(node)
}
return nil, false
}
func (p *astPruner) maybePrune(node *exprpb.Expr) (*exprpb.Expr, bool) {
func (p *astPruner) maybePrune(node ast.Expr) (ast.Expr, bool) {
return p.prune(node)
}
func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
func (p *astPruner) prune(node ast.Expr) (ast.Expr, bool) {
if node == nil {
return node, false
}
val, valueExists := p.maybeValue(node.GetId())
val, valueExists := p.maybeValue(node.ID())
if valueExists {
if newNode, ok := p.maybeCreateLiteral(node.GetId(), val); ok {
delete(p.macroCalls, node.GetId())
if newNode, ok := p.maybeCreateLiteral(node.ID(), val); ok {
delete(p.macroCalls, node.ID())
return newNode, true
}
}
if macro, found := p.macroCalls[node.GetId()]; found {
if macro, found := p.macroCalls[node.ID()]; found {
// Ensure that intermediate values for the comprehension are cleared during pruning
compre := node.GetComprehensionExpr()
if compre != nil {
visit(macro, clearIterVarVisitor(compre.IterVar, p.state))
if node.Kind() == ast.ComprehensionKind {
compre := node.AsComprehension()
visit(macro, clearIterVarVisitor(compre.IterVar(), p.state))
}
// prune the expression in terms of the macro call instead of the expanded form.
if newMacro, pruned := p.prune(macro); pruned {
p.macroCalls[node.GetId()] = newMacro
p.macroCalls[node.ID()] = newMacro
}
}
// We have either an unknown/error value, or something we don't want to
// transform, or expression was not evaluated. If possible, drill down
// more.
switch node.GetExprKind().(type) {
case *exprpb.Expr_SelectExpr:
if operand, pruned := p.maybePrune(node.GetSelectExpr().GetOperand()); pruned {
return &exprpb.Expr{
Id: node.GetId(),
ExprKind: &exprpb.Expr_SelectExpr{
SelectExpr: &exprpb.Expr_Select{
Operand: operand,
Field: node.GetSelectExpr().GetField(),
TestOnly: node.GetSelectExpr().GetTestOnly(),
},
},
}, true
switch node.Kind() {
case ast.SelectKind:
sel := node.AsSelect()
if operand, isPruned := p.maybePrune(sel.Operand()); isPruned {
if sel.IsTestOnly() {
return p.NewPresenceTest(node.ID(), operand, sel.FieldName()), true
}
return p.NewSelect(node.ID(), operand, sel.FieldName()), true
}
case *exprpb.Expr_CallExpr:
var prunedCall bool
call := node.GetCallExpr()
args := call.GetArgs()
newArgs := make([]*exprpb.Expr, len(args))
newCall := &exprpb.Expr_Call{
Function: call.GetFunction(),
Target: call.GetTarget(),
Args: newArgs,
}
for i, arg := range args {
newArgs[i] = arg
if newArg, prunedArg := p.maybePrune(arg); prunedArg {
prunedCall = true
newArgs[i] = newArg
case ast.CallKind:
argsPruned := false
call := node.AsCall()
args := call.Args()
newArgs := make([]ast.Expr, len(args))
for i, a := range args {
newArgs[i] = a
if arg, isPruned := p.maybePrune(a); isPruned {
argsPruned = true
newArgs[i] = arg
}
}
if newTarget, prunedTarget := p.maybePrune(call.GetTarget()); prunedTarget {
prunedCall = true
newCall.Target = newTarget
if !call.IsMemberFunction() {
newCall := p.NewCall(node.ID(), call.FunctionName(), newArgs...)
if prunedCall, isPruned := p.maybePruneFunction(newCall); isPruned {
return prunedCall, true
}
return newCall, argsPruned
}
newNode := &exprpb.Expr{
Id: node.GetId(),
ExprKind: &exprpb.Expr_CallExpr{
CallExpr: newCall,
},
newTarget := call.Target()
targetPruned := false
if prunedTarget, isPruned := p.maybePrune(call.Target()); isPruned {
targetPruned = true
newTarget = prunedTarget
}
if newExpr, pruned := p.maybePruneFunction(newNode); pruned {
newExpr, _ = p.maybePrune(newExpr)
return newExpr, true
newCall := p.NewMemberCall(node.ID(), call.FunctionName(), newTarget, newArgs...)
if prunedCall, isPruned := p.maybePruneFunction(newCall); isPruned {
return prunedCall, true
}
if prunedCall {
return newNode, true
}
case *exprpb.Expr_ListExpr:
elems := node.GetListExpr().GetElements()
optIndices := node.GetListExpr().GetOptionalIndices()
return newCall, targetPruned || argsPruned
case ast.ListKind:
l := node.AsList()
elems := l.Elements()
optIndices := l.OptionalIndices()
optIndexMap := map[int32]bool{}
for _, i := range optIndices {
optIndexMap[i] = true
}
newOptIndexMap := make(map[int32]bool, len(optIndexMap))
newElems := make([]*exprpb.Expr, 0, len(elems))
var prunedList bool
newElems := make([]ast.Expr, 0, len(elems))
var listPruned bool
prunedIdx := 0
for i, elem := range elems {
_, isOpt := optIndexMap[int32(i)]
if isOpt {
newElem, pruned := p.maybePruneOptional(elem)
if pruned {
prunedList = true
listPruned = true
if newElem != nil {
newElems = append(newElems, newElem)
prunedIdx++
@ -431,7 +361,7 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
}
if newElem, prunedElem := p.maybePrune(elem); prunedElem {
newElems = append(newElems, newElem)
prunedList = true
listPruned = true
} else {
newElems = append(newElems, elem)
}
@ -443,76 +373,64 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
optIndices[idx] = i
idx++
}
if prunedList {
return &exprpb.Expr{
Id: node.GetId(),
ExprKind: &exprpb.Expr_ListExpr{
ListExpr: &exprpb.Expr_CreateList{
Elements: newElems,
OptionalIndices: optIndices,
},
},
}, true
if listPruned {
return p.NewList(node.ID(), newElems, optIndices), true
}
case *exprpb.Expr_StructExpr:
var prunedStruct bool
entries := node.GetStructExpr().GetEntries()
messageType := node.GetStructExpr().GetMessageName()
newEntries := make([]*exprpb.Expr_CreateStruct_Entry, len(entries))
case ast.MapKind:
var mapPruned bool
m := node.AsMap()
entries := m.Entries()
newEntries := make([]ast.EntryExpr, len(entries))
for i, entry := range entries {
newEntries[i] = entry
newKey, prunedKey := p.maybePrune(entry.GetMapKey())
newValue, prunedValue := p.maybePrune(entry.GetValue())
if !prunedKey && !prunedValue {
e := entry.AsMapEntry()
newKey, keyPruned := p.maybePrune(e.Key())
newValue, valuePruned := p.maybePrune(e.Value())
if !keyPruned && !valuePruned {
continue
}
prunedStruct = true
newEntry := &exprpb.Expr_CreateStruct_Entry{
Value: newValue,
}
if messageType != "" {
newEntry.KeyKind = &exprpb.Expr_CreateStruct_Entry_FieldKey{
FieldKey: entry.GetFieldKey(),
}
} else {
newEntry.KeyKind = &exprpb.Expr_CreateStruct_Entry_MapKey{
MapKey: newKey,
}
}
newEntry.OptionalEntry = entry.GetOptionalEntry()
mapPruned = true
newEntry := p.NewMapEntry(entry.ID(), newKey, newValue, e.IsOptional())
newEntries[i] = newEntry
}
if prunedStruct {
return &exprpb.Expr{
Id: node.GetId(),
ExprKind: &exprpb.Expr_StructExpr{
StructExpr: &exprpb.Expr_CreateStruct{
MessageName: messageType,
Entries: newEntries,
},
},
}, true
if mapPruned {
return p.NewMap(node.ID(), newEntries), true
}
case *exprpb.Expr_ComprehensionExpr:
compre := node.GetComprehensionExpr()
case ast.StructKind:
var structPruned bool
obj := node.AsStruct()
fields := obj.Fields()
newFields := make([]ast.EntryExpr, len(fields))
for i, field := range fields {
newFields[i] = field
f := field.AsStructField()
newValue, prunedValue := p.maybePrune(f.Value())
if !prunedValue {
continue
}
structPruned = true
newEntry := p.NewStructField(field.ID(), f.Name(), newValue, f.IsOptional())
newFields[i] = newEntry
}
if structPruned {
return p.NewStruct(node.ID(), obj.TypeName(), newFields), true
}
case ast.ComprehensionKind:
compre := node.AsComprehension()
// Only the range of the comprehension is pruned since the state tracking only records
// the last iteration of the comprehension and not each step in the evaluation which
// means that the any residuals computed in between might be inaccurate.
if newRange, pruned := p.maybePrune(compre.GetIterRange()); pruned {
return &exprpb.Expr{
Id: node.GetId(),
ExprKind: &exprpb.Expr_ComprehensionExpr{
ComprehensionExpr: &exprpb.Expr_Comprehension{
IterVar: compre.GetIterVar(),
IterRange: newRange,
AccuVar: compre.GetAccuVar(),
AccuInit: compre.GetAccuInit(),
LoopCondition: compre.GetLoopCondition(),
LoopStep: compre.GetLoopStep(),
Result: compre.GetResult(),
},
},
}, true
if newRange, pruned := p.maybePrune(compre.IterRange()); pruned {
return p.NewComprehension(
node.ID(),
newRange,
compre.IterVar(),
compre.AccuVar(),
compre.AccuInit(),
compre.LoopCondition(),
compre.LoopStep(),
compre.Result(),
), true
}
}
return node, false
@ -539,12 +457,12 @@ func (p *astPruner) nextID() int64 {
type astVisitor struct {
// visitEntry is called on every expr node, including those within a map/struct entry.
visitExpr func(expr *exprpb.Expr)
visitExpr func(expr ast.Expr)
// visitEntry is called before entering the key, value of a map/struct entry.
visitEntry func(entry *exprpb.Expr_CreateStruct_Entry)
visitEntry func(entry ast.EntryExpr)
}
func getMaxID(expr *exprpb.Expr) int64 {
func getMaxID(expr ast.Expr) int64 {
maxID := int64(1)
visit(expr, maxIDVisitor(&maxID))
return maxID
@ -552,10 +470,9 @@ func getMaxID(expr *exprpb.Expr) int64 {
func clearIterVarVisitor(varName string, state EvalState) astVisitor {
return astVisitor{
visitExpr: func(e *exprpb.Expr) {
ident := e.GetIdentExpr()
if ident != nil && ident.GetName() == varName {
state.SetValue(e.GetId(), nil)
visitExpr: func(e ast.Expr) {
if e.Kind() == ast.IdentKind && e.AsIdent() == varName {
state.SetValue(e.ID(), nil)
}
},
}
@ -563,56 +480,63 @@ func clearIterVarVisitor(varName string, state EvalState) astVisitor {
func maxIDVisitor(maxID *int64) astVisitor {
return astVisitor{
visitExpr: func(e *exprpb.Expr) {
if e.GetId() >= *maxID {
*maxID = e.GetId() + 1
visitExpr: func(e ast.Expr) {
if e.ID() >= *maxID {
*maxID = e.ID() + 1
}
},
visitEntry: func(e *exprpb.Expr_CreateStruct_Entry) {
if e.GetId() >= *maxID {
*maxID = e.GetId() + 1
visitEntry: func(e ast.EntryExpr) {
if e.ID() >= *maxID {
*maxID = e.ID() + 1
}
},
}
}
func visit(expr *exprpb.Expr, visitor astVisitor) {
exprs := []*exprpb.Expr{expr}
func visit(expr ast.Expr, visitor astVisitor) {
exprs := []ast.Expr{expr}
for len(exprs) != 0 {
e := exprs[0]
if visitor.visitExpr != nil {
visitor.visitExpr(e)
}
exprs = exprs[1:]
switch e.GetExprKind().(type) {
case *exprpb.Expr_SelectExpr:
exprs = append(exprs, e.GetSelectExpr().GetOperand())
case *exprpb.Expr_CallExpr:
call := e.GetCallExpr()
if call.GetTarget() != nil {
exprs = append(exprs, call.GetTarget())
switch e.Kind() {
case ast.SelectKind:
exprs = append(exprs, e.AsSelect().Operand())
case ast.CallKind:
call := e.AsCall()
if call.Target() != nil {
exprs = append(exprs, call.Target())
}
exprs = append(exprs, call.GetArgs()...)
case *exprpb.Expr_ComprehensionExpr:
compre := e.GetComprehensionExpr()
exprs = append(exprs, call.Args()...)
case ast.ComprehensionKind:
compre := e.AsComprehension()
exprs = append(exprs,
compre.GetIterRange(),
compre.GetAccuInit(),
compre.GetLoopCondition(),
compre.GetLoopStep(),
compre.GetResult())
case *exprpb.Expr_ListExpr:
list := e.GetListExpr()
exprs = append(exprs, list.GetElements()...)
case *exprpb.Expr_StructExpr:
for _, entry := range e.GetStructExpr().GetEntries() {
compre.IterRange(),
compre.AccuInit(),
compre.LoopCondition(),
compre.LoopStep(),
compre.Result())
case ast.ListKind:
list := e.AsList()
exprs = append(exprs, list.Elements()...)
case ast.MapKind:
for _, entry := range e.AsMap().Entries() {
e := entry.AsMapEntry()
if visitor.visitEntry != nil {
visitor.visitEntry(entry)
}
if entry.GetMapKey() != nil {
exprs = append(exprs, entry.GetMapKey())
exprs = append(exprs, e.Key())
exprs = append(exprs, e.Value())
}
case ast.StructKind:
for _, entry := range e.AsStruct().Fields() {
f := entry.AsStructField()
if visitor.visitEntry != nil {
visitor.visitEntry(entry)
}
exprs = append(exprs, entry.GetValue())
exprs = append(exprs, f.Value())
}
}
}