rebase: bump the golang-dependencies group with 1 update

Bumps the golang-dependencies group with 1 update: [golang.org/x/crypto](https://github.com/golang/crypto).


Updates `golang.org/x/crypto` from 0.16.0 to 0.17.0
- [Commits](https://github.com/golang/crypto/compare/v0.16.0...v0.17.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-dependencies
...

Signed-off-by: dependabot[bot] <support@github.com>
This commit is contained in:
dependabot[bot]
2023-12-18 20:31:00 +00:00
committed by mergify[bot]
parent 1ad79314f9
commit e5d9b68d36
398 changed files with 33924 additions and 10753 deletions

View File

@ -15,6 +15,7 @@ go_library(
"macro.go",
"options.go",
"program.go",
"validator.go",
],
importpath = "github.com/google/cel-go/cel",
visibility = ["//visibility:public"],
@ -22,15 +23,18 @@ go_library(
"//checker:go_default_library",
"//checker/decls:go_default_library",
"//common:go_default_library",
"//common/ast:go_default_library",
"//common/containers:go_default_library",
"//common/decls:go_default_library",
"//common/functions:go_default_library",
"//common/operators:go_default_library",
"//common/overloads:go_default_library",
"//common/stdlib:go_default_library",
"//common/types:go_default_library",
"//common/types/pb:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"//interpreter:go_default_library",
"//interpreter/functions:go_default_library",
"//parser:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
@ -72,6 +76,8 @@ go_test(
"@io_bazel_rules_go//proto/wkt:descriptor_go_proto",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//encoding/prototext:go_default_library",
"@org_golang_google_protobuf//types/known/structpb:go_default_library",
"@org_golang_google_protobuf//types/known/wrapperspb:go_default_library",
],
)

File diff suppressed because it is too large Load Diff

View File

@ -16,13 +16,14 @@ package cel
import (
"errors"
"fmt"
"sync"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/checker/decls"
chkdecls "github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common"
celast "github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/decls"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter"
@ -40,8 +41,8 @@ type Ast struct {
expr *exprpb.Expr
info *exprpb.SourceInfo
source Source
refMap map[int64]*exprpb.Reference
typeMap map[int64]*exprpb.Type
refMap map[int64]*celast.ReferenceInfo
typeMap map[int64]*types.Type
}
// Expr returns the proto serializable instance of the parsed/checked expression.
@ -60,21 +61,26 @@ func (ast *Ast) SourceInfo() *exprpb.SourceInfo {
}
// ResultType returns the output type of the expression if the Ast has been type-checked, else
// returns decls.Dyn as the parse step cannot infer the type.
// returns chkdecls.Dyn as the parse step cannot infer the type.
//
// Deprecated: use OutputType
func (ast *Ast) ResultType() *exprpb.Type {
if !ast.IsChecked() {
return decls.Dyn
return chkdecls.Dyn
}
return ast.typeMap[ast.expr.GetId()]
out := ast.OutputType()
t, err := TypeToExprType(out)
if err != nil {
return chkdecls.Dyn
}
return t
}
// OutputType returns the output type of the expression if the Ast has been type-checked, else
// returns cel.DynType as the parse step cannot infer types.
func (ast *Ast) OutputType() *Type {
t, err := ExprTypeToType(ast.ResultType())
if err != nil {
t, found := ast.typeMap[ast.expr.GetId()]
if !found {
return DynType
}
return t
@ -87,22 +93,33 @@ func (ast *Ast) Source() Source {
}
// FormatType converts a type message into a string representation.
//
// Deprecated: prefer FormatCELType
func FormatType(t *exprpb.Type) string {
return checker.FormatCheckedType(t)
}
// FormatCELType formats a cel.Type value to a string representation.
//
// The type formatting is identical to FormatType.
func FormatCELType(t *Type) string {
return checker.FormatCELType(t)
}
// Env encapsulates the context necessary to perform parsing, type checking, or generation of
// evaluable programs for different expressions.
type Env struct {
Container *containers.Container
functions map[string]*functionDecl
declarations []*exprpb.Decl
variables []*decls.VariableDecl
functions map[string]*decls.FunctionDecl
macros []parser.Macro
adapter ref.TypeAdapter
provider ref.TypeProvider
adapter types.Adapter
provider types.Provider
features map[int]bool
appliedFeatures map[int]bool
libraries map[string]bool
validators []ASTValidator
costOptions []checker.CostOption
// Internal parser representation
prsr *parser.Parser
@ -154,8 +171,8 @@ func NewCustomEnv(opts ...EnvOption) (*Env, error) {
return nil, err
}
return (&Env{
declarations: []*exprpb.Decl{},
functions: map[string]*functionDecl{},
variables: []*decls.VariableDecl{},
functions: map[string]*decls.FunctionDecl{},
macros: []parser.Macro{},
Container: containers.DefaultContainer,
adapter: registry,
@ -163,14 +180,20 @@ func NewCustomEnv(opts ...EnvOption) (*Env, error) {
features: map[int]bool{},
appliedFeatures: map[int]bool{},
libraries: map[string]bool{},
validators: []ASTValidator{},
progOpts: []ProgramOption{},
costOptions: []checker.CostOption{},
}).configure(opts)
}
// Check performs type-checking on the input Ast and yields a checked Ast and/or set of Issues.
// If any `ASTValidators` are configured on the environment, they will be applied after a valid
// type-check result. If any issues are detected, the validators will provide them on the
// output Issues object.
//
// Checking has failed if the returned Issues value and its Issues.Err() value are non-nil.
// Issues should be inspected if they are non-nil, but may not represent a fatal error.
// Either checking or validation has failed if the returned Issues value and its Issues.Err()
// value are non-nil. Issues should be inspected if they are non-nil, but may not represent a
// fatal error.
//
// It is possible to have both non-nil Ast and Issues values returned from this call: however,
// the mere presence of an Ast does not imply that it is valid for use.
@ -183,21 +206,38 @@ func (e *Env) Check(ast *Ast) (*Ast, *Issues) {
if err != nil {
errs := common.NewErrors(ast.Source())
errs.ReportError(common.NoLocation, err.Error())
return nil, NewIssues(errs)
return nil, NewIssuesWithSourceInfo(errs, ast.SourceInfo())
}
res, errs := checker.Check(pe, ast.Source(), chk)
if len(errs.GetErrors()) > 0 {
return nil, NewIssues(errs)
return nil, NewIssuesWithSourceInfo(errs, ast.SourceInfo())
}
// Manually create the Ast to ensure that the Ast source information (which may be more
// detailed than the information provided by Check), is returned to the caller.
return &Ast{
ast = &Ast{
source: ast.Source(),
expr: res.GetExpr(),
info: res.GetSourceInfo(),
refMap: res.GetReferenceMap(),
typeMap: res.GetTypeMap()}, nil
expr: res.Expr,
info: res.SourceInfo,
refMap: res.ReferenceMap,
typeMap: res.TypeMap}
// Generate a validator configuration from the set of configured validators.
vConfig := newValidatorConfig()
for _, v := range e.validators {
if cv, ok := v.(ASTValidatorConfigurer); ok {
cv.Configure(vConfig)
}
}
// Apply additional validators on the type-checked result.
iss := NewIssuesWithSourceInfo(errs, ast.SourceInfo())
for _, v := range e.validators {
v.Validate(e, vConfig, res, iss)
}
if iss.Err() != nil {
return nil, iss
}
return ast, nil
}
// Compile combines the Parse and Check phases CEL program compilation to produce an Ast and
@ -255,7 +295,7 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
copy(chkOptsCopy, e.chkOpts)
// Copy the declarations if needed.
decsCopy := []*exprpb.Decl{}
varsCopy := []*decls.VariableDecl{}
if chk != nil {
// If the type-checker has already been instantiated, then the e.declarations have been
// validated within the chk instance.
@ -263,8 +303,8 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
} else {
// If the type-checker has not been instantiated, ensure the unvalidated declarations are
// provided to the extended Env instance.
decsCopy = make([]*exprpb.Decl, len(e.declarations))
copy(decsCopy, e.declarations)
varsCopy = make([]*decls.VariableDecl, len(e.variables))
copy(varsCopy, e.variables)
}
// Copy macros and program options
@ -276,8 +316,8 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
// Copy the adapter / provider if they appear to be mutable.
adapter := e.adapter
provider := e.provider
adapterReg, isAdapterReg := e.adapter.(ref.TypeRegistry)
providerReg, isProviderReg := e.provider.(ref.TypeRegistry)
adapterReg, isAdapterReg := e.adapter.(*types.Registry)
providerReg, isProviderReg := e.provider.(*types.Registry)
// In most cases the provider and adapter will be a ref.TypeRegistry;
// however, in the rare cases where they are not, they are assumed to
// be immutable. Since it is possible to set the TypeProvider separately
@ -308,7 +348,7 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
for k, v := range e.appliedFeatures {
appliedFeaturesCopy[k] = v
}
funcsCopy := make(map[string]*functionDecl, len(e.functions))
funcsCopy := make(map[string]*decls.FunctionDecl, len(e.functions))
for k, v := range e.functions {
funcsCopy[k] = v
}
@ -316,10 +356,14 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
for k, v := range e.libraries {
libsCopy[k] = v
}
validatorsCopy := make([]ASTValidator, len(e.validators))
copy(validatorsCopy, e.validators)
costOptsCopy := make([]checker.CostOption, len(e.costOptions))
copy(costOptsCopy, e.costOptions)
ext := &Env{
Container: e.Container,
declarations: decsCopy,
variables: varsCopy,
functions: funcsCopy,
macros: macsCopy,
progOpts: progOptsCopy,
@ -327,9 +371,11 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
features: featuresCopy,
appliedFeatures: appliedFeaturesCopy,
libraries: libsCopy,
validators: validatorsCopy,
provider: provider,
chkOpts: chkOptsCopy,
prsrOpts: prsrOptsCopy,
costOptions: costOptsCopy,
}
return ext.configure(opts)
}
@ -347,6 +393,25 @@ func (e *Env) HasLibrary(libName string) bool {
return exists && configured
}
// Libraries returns a list of SingletonLibrary that have been configured in the environment.
func (e *Env) Libraries() []string {
libraries := make([]string, 0, len(e.libraries))
for libName := range e.libraries {
libraries = append(libraries, libName)
}
return libraries
}
// HasValidator returns whether a specific ASTValidator has been configured in the environment.
func (e *Env) HasValidator(name string) bool {
for _, v := range e.validators {
if v.Name() == name {
return true
}
}
return false
}
// Parse parses the input expression value `txt` to a Ast and/or a set of Issues.
//
// This form of Parse creates a Source value for the input `txt` and forwards to the
@ -388,36 +453,64 @@ func (e *Env) Program(ast *Ast, opts ...ProgramOption) (Program, error) {
return newProgram(e, ast, optSet)
}
// CELTypeAdapter returns the `types.Adapter` configured for the environment.
func (e *Env) CELTypeAdapter() types.Adapter {
return e.adapter
}
// CELTypeProvider returns the `types.Provider` configured for the environment.
func (e *Env) CELTypeProvider() types.Provider {
return e.provider
}
// TypeAdapter returns the `ref.TypeAdapter` configured for the environment.
//
// Deprecated: use CELTypeAdapter()
func (e *Env) TypeAdapter() ref.TypeAdapter {
return e.adapter
}
// TypeProvider returns the `ref.TypeProvider` configured for the environment.
//
// Deprecated: use CELTypeProvider()
func (e *Env) TypeProvider() ref.TypeProvider {
return e.provider
if legacyProvider, ok := e.provider.(ref.TypeProvider); ok {
return legacyProvider
}
return &interopLegacyTypeProvider{Provider: e.provider}
}
// UnknownVars returns an interpreter.PartialActivation which marks all variables
// declared in the Env as unknown AttributePattern values.
// UnknownVars returns an interpreter.PartialActivation which marks all variables declared in the
// Env as unknown AttributePattern values.
//
// Note, the UnknownVars will behave the same as an interpreter.EmptyActivation
// unless the PartialAttributes option is provided as a ProgramOption.
// Note, the UnknownVars will behave the same as an interpreter.EmptyActivation unless the
// PartialAttributes option is provided as a ProgramOption.
func (e *Env) UnknownVars() interpreter.PartialActivation {
var unknownPatterns []*interpreter.AttributePattern
for _, d := range e.declarations {
switch d.GetDeclKind().(type) {
case *exprpb.Decl_Ident:
unknownPatterns = append(unknownPatterns,
interpreter.NewAttributePattern(d.GetName()))
}
}
part, _ := PartialVars(
interpreter.EmptyActivation(),
unknownPatterns...)
act := interpreter.EmptyActivation()
part, _ := PartialVars(act, e.computeUnknownVars(act)...)
return part
}
// PartialVars returns an interpreter.PartialActivation where all variables not in the input variable
// set, but which have been configured in the environment, are marked as unknown.
//
// The `vars` value may either be an interpreter.Activation or any valid input to the
// interpreter.NewActivation call.
//
// Note, this is equivalent to calling cel.PartialVars and manually configuring the set of unknown
// variables. For more advanced use cases of partial state where portions of an object graph, rather
// than top-level variables, are missing the PartialVars() method may be a more suitable choice.
//
// Note, the PartialVars will behave the same as an interpreter.EmptyActivation unless the
// PartialAttributes option is provided as a ProgramOption.
func (e *Env) PartialVars(vars any) (interpreter.PartialActivation, error) {
act, err := interpreter.NewActivation(vars)
if err != nil {
return nil, err
}
return PartialVars(act, e.computeUnknownVars(act)...)
}
// ResidualAst takes an Ast and its EvalDetails to produce a new Ast which only contains the
// attribute references which are unknown.
//
@ -463,11 +556,16 @@ func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) {
// EstimateCost estimates the cost of a type checked CEL expression using the length estimates of input data and
// extension functions provided by estimator.
func (e *Env) EstimateCost(ast *Ast, estimator checker.CostEstimator, opts ...checker.CostOption) (checker.CostEstimate, error) {
checked, err := AstToCheckedExpr(ast)
if err != nil {
return checker.CostEstimate{}, fmt.Errorf("EsimateCost could not inspect Ast: %v", err)
checked := &celast.CheckedAST{
Expr: ast.Expr(),
SourceInfo: ast.SourceInfo(),
TypeMap: ast.typeMap,
ReferenceMap: ast.refMap,
}
return checker.Cost(checked, estimator, opts...)
extendedOpts := make([]checker.CostOption, 0, len(e.costOptions))
extendedOpts = append(extendedOpts, opts...)
extendedOpts = append(extendedOpts, e.costOptions...)
return checker.Cost(checked, estimator, extendedOpts...)
}
// configure applies a series of EnvOptions to the current environment.
@ -488,14 +586,6 @@ func (e *Env) configure(opts []EnvOption) (*Env, error) {
return nil, err
}
// Initialize all of the functions configured within the environment.
for _, fn := range e.functions {
err = fn.init()
if err != nil {
return nil, err
}
}
// Configure the parser.
prsrOpts := []parser.Option{}
prsrOpts = append(prsrOpts, e.prsrOpts...)
@ -504,6 +594,9 @@ func (e *Env) configure(opts []EnvOption) (*Env, error) {
if e.HasFeature(featureEnableMacroCallTracking) {
prsrOpts = append(prsrOpts, parser.PopulateMacroCalls(true))
}
if e.HasFeature(featureVariadicLogicalASTs) {
prsrOpts = append(prsrOpts, parser.EnableVariadicOperatorASTs(true))
}
e.prsr, err = parser.NewParser(prsrOpts...)
if err != nil {
return nil, err
@ -525,8 +618,6 @@ func (e *Env) initChecker() (*checker.Env, error) {
chkOpts := []checker.Option{}
chkOpts = append(chkOpts, e.chkOpts...)
chkOpts = append(chkOpts,
checker.HomogeneousAggregateLiterals(
e.HasFeature(featureDisableDynamicAggregateLiterals)),
checker.CrossTypeNumericComparisons(
e.HasFeature(featureCrossTypeNumericComparisons)))
@ -536,19 +627,17 @@ func (e *Env) initChecker() (*checker.Env, error) {
return
}
// Add the statically configured declarations.
err = ce.Add(e.declarations...)
err = ce.AddIdents(e.variables...)
if err != nil {
e.setCheckerOrError(nil, err)
return
}
// Add the function declarations which are derived from the FunctionDecl instances.
for _, fn := range e.functions {
fnDecl, err := functionDeclToExprDecl(fn)
if err != nil {
e.setCheckerOrError(nil, err)
return
if fn.IsDeclarationDisabled() {
continue
}
err = ce.Add(fnDecl)
err = ce.AddFunctions(fn)
if err != nil {
e.setCheckerOrError(nil, err)
return
@ -596,17 +685,43 @@ func (e *Env) maybeApplyFeature(feature int, option EnvOption) (*Env, error) {
return e, nil
}
// computeUnknownVars determines a set of missing variables based on the input activation and the
// environment's configured declaration set.
func (e *Env) computeUnknownVars(vars interpreter.Activation) []*interpreter.AttributePattern {
var unknownPatterns []*interpreter.AttributePattern
for _, v := range e.variables {
varName := v.Name()
if _, found := vars.ResolveName(varName); found {
continue
}
unknownPatterns = append(unknownPatterns, interpreter.NewAttributePattern(varName))
}
return unknownPatterns
}
// Error type which references an expression id, a location within source, and a message.
type Error = common.Error
// Issues defines methods for inspecting the error details of parse and check calls.
//
// Note: in the future, non-fatal warnings and notices may be inspectable via the Issues struct.
type Issues struct {
errs *common.Errors
info *exprpb.SourceInfo
}
// NewIssues returns an Issues struct from a common.Errors object.
func NewIssues(errs *common.Errors) *Issues {
return NewIssuesWithSourceInfo(errs, nil)
}
// NewIssuesWithSourceInfo returns an Issues struct from a common.Errors object with SourceInfo metatata
// which can be used with the `ReportErrorAtID` method for additional error reports within the context
// information that's inferred from an expression id.
func NewIssuesWithSourceInfo(errs *common.Errors, info *exprpb.SourceInfo) *Issues {
return &Issues{
errs: errs,
info: info,
}
}
@ -622,9 +737,9 @@ func (i *Issues) Err() error {
}
// Errors returns the collection of errors encountered in more granular detail.
func (i *Issues) Errors() []common.Error {
func (i *Issues) Errors() []*Error {
if i == nil {
return []common.Error{}
return []*Error{}
}
return i.errs.GetErrors()
}
@ -648,6 +763,37 @@ func (i *Issues) String() string {
return i.errs.ToDisplayString()
}
// ReportErrorAtID reports an error message with an optional set of formatting arguments.
//
// The source metadata for the expression at `id`, if present, is attached to the error report.
// To ensure that source metadata is attached to error reports, use NewIssuesWithSourceInfo.
func (i *Issues) ReportErrorAtID(id int64, message string, args ...any) {
i.errs.ReportErrorAtID(id, locationByID(id, i.info), message, args...)
}
// locationByID returns a common.Location given an expression id.
//
// TODO: move this functionality into the native SourceInfo and an overhaul of the common.Source
// as this implementation relies on the abstractions present in the protobuf SourceInfo object,
// and is replicated in the checker.
func locationByID(id int64, sourceInfo *exprpb.SourceInfo) common.Location {
positions := sourceInfo.GetPositions()
var line = 1
if offset, found := positions[id]; found {
col := int(offset)
for _, lineOffset := range sourceInfo.GetLineOffsets() {
if lineOffset < offset {
line++
col = int(offset - lineOffset)
} else {
break
}
}
return common.NewLocation(line, col)
}
return common.NoLocation
}
// getStdEnv lazy initializes the CEL standard environment.
func getStdEnv() (*Env, error) {
stdEnvInit.Do(func() {
@ -656,6 +802,90 @@ func getStdEnv() (*Env, error) {
return stdEnv, stdEnvErr
}
// interopCELTypeProvider layers support for the types.Provider interface on top of a ref.TypeProvider.
type interopCELTypeProvider struct {
ref.TypeProvider
}
// FindStructType returns a types.Type instance for the given fully-qualified typeName if one exists.
//
// This method proxies to the underyling ref.TypeProvider's FindType method and converts protobuf type
// into a native type representation. If the conversion fails, the type is listed as not found.
func (p *interopCELTypeProvider) FindStructType(typeName string) (*types.Type, bool) {
if et, found := p.FindType(typeName); found {
t, err := types.ExprTypeToType(et)
if err != nil {
return nil, false
}
return t, true
}
return nil, false
}
// FindStructFieldType returns a types.FieldType instance for the given fully-qualified typeName and field
// name, if one exists.
//
// This method proxies to the underyling ref.TypeProvider's FindFieldType method and converts protobuf type
// into a native type representation. If the conversion fails, the type is listed as not found.
func (p *interopCELTypeProvider) FindStructFieldType(structType, fieldName string) (*types.FieldType, bool) {
if ft, found := p.FindFieldType(structType, fieldName); found {
t, err := types.ExprTypeToType(ft.Type)
if err != nil {
return nil, false
}
return &types.FieldType{
Type: t,
IsSet: ft.IsSet,
GetFrom: ft.GetFrom,
}, true
}
return nil, false
}
// interopLegacyTypeProvider layers support for the ref.TypeProvider interface on top of a types.Provider.
type interopLegacyTypeProvider struct {
types.Provider
}
// FindType retruns the protobuf Type representation for the input type name if one exists.
//
// This method proxies to the underlying types.Provider FindStructType method and converts the types.Type
// value to a protobuf Type representation.
//
// Failure to convert the type will result in the type not being found.
func (p *interopLegacyTypeProvider) FindType(typeName string) (*exprpb.Type, bool) {
if t, found := p.FindStructType(typeName); found {
et, err := types.TypeToExprType(t)
if err != nil {
return nil, false
}
return et, true
}
return nil, false
}
// FindFieldType returns the protobuf-based FieldType representation for the input type name and field,
// if one exists.
//
// This call proxies to the types.Provider FindStructFieldType method and converts the types.FIeldType
// value to a protobuf-based ref.FieldType representation if found.
//
// Failure to convert the FieldType will result in the field not being found.
func (p *interopLegacyTypeProvider) FindFieldType(structType, fieldName string) (*ref.FieldType, bool) {
if cft, found := p.FindStructFieldType(structType, fieldName); found {
et, err := types.TypeToExprType(cft.Type)
if err != nil {
return nil, false
}
return &ref.FieldType{
Type: et,
IsSet: cft.IsSet,
GetFrom: cft.GetFrom,
}, true
}
return nil, false
}
var (
stdEnvInit sync.Once
stdEnv *Env

View File

@ -22,6 +22,7 @@ import (
"google.golang.org/protobuf/proto"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
@ -33,7 +34,8 @@ import (
// CheckedExprToAst converts a checked expression proto message to an Ast.
func CheckedExprToAst(checkedExpr *exprpb.CheckedExpr) *Ast {
return CheckedExprToAstWithSource(checkedExpr, nil)
checked, _ := CheckedExprToAstWithSource(checkedExpr, nil)
return checked
}
// CheckedExprToAstWithSource converts a checked expression proto message to an Ast,
@ -44,29 +46,18 @@ func CheckedExprToAst(checkedExpr *exprpb.CheckedExpr) *Ast {
// through future calls.
//
// Prefer CheckedExprToAst if loading expressions from storage.
func CheckedExprToAstWithSource(checkedExpr *exprpb.CheckedExpr, src Source) *Ast {
refMap := checkedExpr.GetReferenceMap()
if refMap == nil {
refMap = map[int64]*exprpb.Reference{}
}
typeMap := checkedExpr.GetTypeMap()
if typeMap == nil {
typeMap = map[int64]*exprpb.Type{}
}
si := checkedExpr.GetSourceInfo()
if si == nil {
si = &exprpb.SourceInfo{}
}
if src == nil {
src = common.NewInfoSource(si)
func CheckedExprToAstWithSource(checkedExpr *exprpb.CheckedExpr, src Source) (*Ast, error) {
checkedAST, err := ast.CheckedExprToCheckedAST(checkedExpr)
if err != nil {
return nil, err
}
return &Ast{
expr: checkedExpr.GetExpr(),
info: si,
expr: checkedAST.Expr,
info: checkedAST.SourceInfo,
source: src,
refMap: refMap,
typeMap: typeMap,
}
refMap: checkedAST.ReferenceMap,
typeMap: checkedAST.TypeMap,
}, nil
}
// AstToCheckedExpr converts an Ast to an protobuf CheckedExpr value.
@ -76,12 +67,13 @@ func AstToCheckedExpr(a *Ast) (*exprpb.CheckedExpr, error) {
if !a.IsChecked() {
return nil, fmt.Errorf("cannot convert unchecked ast")
}
return &exprpb.CheckedExpr{
Expr: a.Expr(),
SourceInfo: a.SourceInfo(),
cAst := &ast.CheckedAST{
Expr: a.expr,
SourceInfo: a.info,
ReferenceMap: a.refMap,
TypeMap: a.typeMap,
}, nil
}
return ast.CheckedASTToCheckedExpr(cAst)
}
// ParsedExprToAst converts a parsed expression proto message to an Ast.
@ -202,7 +194,7 @@ func RefValueToValue(res ref.Val) (*exprpb.Value, error) {
}
var (
typeNameToTypeValue = map[string]*types.TypeValue{
typeNameToTypeValue = map[string]ref.Val{
"bool": types.BoolType,
"bytes": types.BytesType,
"double": types.DoubleType,
@ -219,7 +211,7 @@ var (
)
// ValueToRefValue converts between exprpb.Value and ref.Val.
func ValueToRefValue(adapter ref.TypeAdapter, v *exprpb.Value) (ref.Val, error) {
func ValueToRefValue(adapter types.Adapter, v *exprpb.Value) (ref.Val, error) {
switch v.Kind.(type) {
case *exprpb.Value_NullValue:
return types.NullValue, nil

View File

@ -15,19 +15,18 @@
package cel
import (
"math"
"strconv"
"strings"
"time"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/stdlib"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/interpreter"
"github.com/google/cel-go/interpreter/functions"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
@ -35,6 +34,7 @@ import (
const (
optMapMacro = "optMap"
optFlatMapMacro = "optFlatMap"
hasValueFunc = "hasValue"
optionalNoneFunc = "optional.none"
optionalOfFunc = "optional.of"
@ -106,44 +106,213 @@ func (stdLibrary) LibraryName() string {
return "cel.lib.std"
}
// EnvOptions returns options for the standard CEL function declarations and macros.
// CompileOptions returns options for the standard CEL function declarations and macros.
func (stdLibrary) CompileOptions() []EnvOption {
return []EnvOption{
Declarations(checker.StandardDeclarations()...),
func(e *Env) (*Env, error) {
var err error
for _, fn := range stdlib.Functions() {
existing, found := e.functions[fn.Name()]
if found {
fn, err = existing.Merge(fn)
if err != nil {
return nil, err
}
}
e.functions[fn.Name()] = fn
}
return e, nil
},
func(e *Env) (*Env, error) {
e.variables = append(e.variables, stdlib.Types()...)
return e, nil
},
Macros(StandardMacros...),
}
}
// ProgramOptions returns function implementations for the standard CEL functions.
func (stdLibrary) ProgramOptions() []ProgramOption {
return []ProgramOption{
Functions(functions.StandardOverloads()...),
return []ProgramOption{}
}
// OptionalTypes enable support for optional syntax and types in CEL.
//
// The optional value type makes it possible to express whether variables have
// been provided, whether a result has been computed, and in the future whether
// an object field path, map key value, or list index has a value.
//
// # Syntax Changes
//
// OptionalTypes are unlike other CEL extensions because they modify the CEL
// syntax itself, notably through the use of a `?` preceding a field name or
// index value.
//
// ## Field Selection
//
// The optional syntax in field selection is denoted as `obj.?field`. In other
// words, if a field is set, return `optional.of(obj.field)“, else
// `optional.none()`. The optional field selection is viral in the sense that
// after the first optional selection all subsequent selections or indices
// are treated as optional, i.e. the following expressions are equivalent:
//
// obj.?field.subfield
// obj.?field.?subfield
//
// ## Indexing
//
// Similar to field selection, the optional syntax can be used in index
// expressions on maps and lists:
//
// list[?0]
// map[?key]
//
// ## Optional Field Setting
//
// When creating map or message literals, if a field may be optionally set
// based on its presence, then placing a `?` before the field name or key
// will ensure the type on the right-hand side must be optional(T) where T
// is the type of the field or key-value.
//
// The following returns a map with the key expression set only if the
// subfield is present, otherwise an empty map is created:
//
// {?key: obj.?field.subfield}
//
// ## Optional Element Setting
//
// When creating list literals, an element in the list may be optionally added
// when the element expression is preceded by a `?`:
//
// [a, ?b, ?c] // return a list with either [a], [a, b], [a, b, c], or [a, c]
//
// # Optional.Of
//
// Create an optional(T) value of a given value with type T.
//
// optional.of(10)
//
// # Optional.OfNonZeroValue
//
// Create an optional(T) value of a given value with type T if it is not a
// zero-value. A zero-value the default empty value for any given CEL type,
// including empty protobuf message types. If the value is empty, the result
// of this call will be optional.none().
//
// optional.ofNonZeroValue([1, 2, 3]) // optional(list(int))
// optional.ofNonZeroValue([]) // optional.none()
// optional.ofNonZeroValue(0) // optional.none()
// optional.ofNonZeroValue("") // optional.none()
//
// # Optional.None
//
// Create an empty optional value.
//
// # HasValue
//
// Determine whether the optional contains a value.
//
// optional.of(b'hello').hasValue() // true
// optional.ofNonZeroValue({}).hasValue() // false
//
// # Value
//
// Get the value contained by the optional. If the optional does not have a
// value, the result will be a CEL error.
//
// optional.of(b'hello').value() // b'hello'
// optional.ofNonZeroValue({}).value() // error
//
// # Or
//
// If the value on the left-hand side is optional.none(), the optional value
// on the right hand side is returned. If the value on the left-hand set is
// valued, then it is returned. This operation is short-circuiting and will
// only evaluate as many links in the `or` chain as are needed to return a
// non-empty optional value.
//
// obj.?field.or(m[?key])
// l[?index].or(obj.?field.subfield).or(obj.?other)
//
// # OrValue
//
// Either return the value contained within the optional on the left-hand side
// or return the alternative value on the right hand side.
//
// m[?key].orValue("none")
//
// # OptMap
//
// Apply a transformation to the optional's underlying value if it is not empty
// and return an optional typed result based on the transformation. The
// transformation expression type must return a type T which is wrapped into
// an optional.
//
// msg.?elements.optMap(e, e.size()).orValue(0)
//
// # OptFlatMap
//
// Introduced in version: 1
//
// Apply a transformation to the optional's underlying value if it is not empty
// and return the result. The transform expression must return an optional(T)
// rather than type T. This can be useful when dealing with zero values and
// conditionally generating an empty or non-empty result in ways which cannot
// be expressed with `optMap`.
//
// msg.?elements.optFlatMap(e, e[?0]) // return the first element if present.
func OptionalTypes(opts ...OptionalTypesOption) EnvOption {
lib := &optionalLib{version: math.MaxUint32}
for _, opt := range opts {
lib = opt(lib)
}
return Lib(lib)
}
type optionalLib struct {
version uint32
}
// OptionalTypesOption is a functional interface for configuring the strings library.
type OptionalTypesOption func(*optionalLib) *optionalLib
// OptionalTypesVersion configures the version of the optional type library.
//
// The version limits which functions are available. Only functions introduced
// below or equal to the given version included in the library. If this option
// is not set, all functions are available.
//
// See the library documentation to determine which version a function was introduced.
// If the documentation does not state which version a function was introduced, it can
// be assumed to be introduced at version 0, when the library was first created.
func OptionalTypesVersion(version uint32) OptionalTypesOption {
return func(lib *optionalLib) *optionalLib {
lib.version = version
return lib
}
}
type optionalLibrary struct{}
// LibraryName implements the SingletonLibrary interface method.
func (optionalLibrary) LibraryName() string {
func (lib *optionalLib) LibraryName() string {
return "cel.lib.optional"
}
// CompileOptions implements the Library interface method.
func (optionalLibrary) CompileOptions() []EnvOption {
func (lib *optionalLib) CompileOptions() []EnvOption {
paramTypeK := TypeParamType("K")
paramTypeV := TypeParamType("V")
optionalTypeV := OptionalType(paramTypeV)
listTypeV := ListType(paramTypeV)
mapTypeKV := MapType(paramTypeK, paramTypeV)
return []EnvOption{
opts := []EnvOption{
// Enable the optional syntax in the parser.
enableOptionalSyntax(),
// Introduce the optional type.
Types(types.OptionalType),
// Configure the optMap macro.
// Configure the optMap and optFlatMap macros.
Macros(NewReceiverMacro(optMapMacro, 2, optMap)),
// Global and member functions for working with optional values.
@ -202,21 +371,29 @@ func (optionalLibrary) CompileOptions() []EnvOption {
// Index overloads to accommodate using an optional value as the operand.
Function(operators.Index,
Overload("optional_list_index_int", []*Type{OptionalType(listTypeV), IntType}, optionalTypeV),
Overload("optional_map_index_optional_value", []*Type{OptionalType(mapTypeKV), paramTypeK}, optionalTypeV)),
Overload("optional_map_index_value", []*Type{OptionalType(mapTypeKV), paramTypeK}, optionalTypeV)),
}
if lib.version >= 1 {
opts = append(opts, Macros(NewReceiverMacro(optFlatMapMacro, 2, optFlatMap)))
}
return opts
}
// ProgramOptions implements the Library interface method.
func (lib *optionalLib) ProgramOptions() []ProgramOption {
return []ProgramOption{
CustomDecorator(decorateOptionalOr),
}
}
func optMap(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
func optMap(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
varIdent := args[0]
varName := ""
switch varIdent.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
varName = varIdent.GetIdentExpr().GetName()
default:
return nil, &common.Error{
Message: "optMap() variable name must be a simple identifier",
Location: meh.OffsetLocation(varIdent.GetId()),
}
return nil, meh.NewError(varIdent.GetId(), "optMap() variable name must be a simple identifier")
}
mapExpr := args[1]
return meh.GlobalCall(
@ -237,11 +414,30 @@ func optMap(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exp
), nil
}
// ProgramOptions implements the Library interface method.
func (optionalLibrary) ProgramOptions() []ProgramOption {
return []ProgramOption{
CustomDecorator(decorateOptionalOr),
func optFlatMap(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
varIdent := args[0]
varName := ""
switch varIdent.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
varName = varIdent.GetIdentExpr().GetName()
default:
return nil, meh.NewError(varIdent.GetId(), "optFlatMap() variable name must be a simple identifier")
}
mapExpr := args[1]
return meh.GlobalCall(
operators.Conditional,
meh.ReceiverCall(hasValueFunc, target),
meh.Fold(
unusedIterVar,
meh.NewList(),
varName,
meh.ReceiverCall(valueFunc, target),
meh.LiteralBool(false),
meh.Ident(varName),
mapExpr,
),
meh.GlobalCall(optionalNoneFunc),
), nil
}
func enableOptionalSyntax() EnvOption {
@ -358,28 +554,16 @@ var (
timeOverloadDeclarations = []EnvOption{
Function(overloads.TimeGetHours,
MemberOverload(overloads.DurationToHours, []*Type{DurationType}, IntType,
UnaryBinding(func(dur ref.Val) ref.Val {
d := dur.(types.Duration)
return types.Int(d.Hours())
}))),
UnaryBinding(types.DurationGetHours))),
Function(overloads.TimeGetMinutes,
MemberOverload(overloads.DurationToMinutes, []*Type{DurationType}, IntType,
UnaryBinding(func(dur ref.Val) ref.Val {
d := dur.(types.Duration)
return types.Int(d.Minutes())
}))),
UnaryBinding(types.DurationGetMinutes))),
Function(overloads.TimeGetSeconds,
MemberOverload(overloads.DurationToSeconds, []*Type{DurationType}, IntType,
UnaryBinding(func(dur ref.Val) ref.Val {
d := dur.(types.Duration)
return types.Int(d.Seconds())
}))),
UnaryBinding(types.DurationGetSeconds))),
Function(overloads.TimeGetMilliseconds,
MemberOverload(overloads.DurationToMilliseconds, []*Type{DurationType}, IntType,
UnaryBinding(func(dur ref.Val) ref.Val {
d := dur.(types.Duration)
return types.Int(d.Milliseconds())
}))),
UnaryBinding(types.DurationGetMilliseconds))),
Function(overloads.TimeGetFullYear,
MemberOverload(overloads.TimestampToYear, []*Type{TimestampType}, IntType,
UnaryBinding(func(ts ref.Val) ref.Val {

View File

@ -15,7 +15,6 @@
package cel
import (
"github.com/google/cel-go/common"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
@ -63,21 +62,21 @@ func NewReceiverVarArgMacro(function string, expander MacroExpander) Macro {
}
// HasMacroExpander expands the input call arguments into a presence test, e.g. has(<operand>.field)
func HasMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
func HasMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return parser.MakeHas(meh, target, args)
}
// ExistsMacroExpander expands the input call arguments into a comprehension that returns true if any of the
// elements in the range match the predicate expressions:
// <iterRange>.exists(<iterVar>, <predicate>)
func ExistsMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
func ExistsMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return parser.MakeExists(meh, target, args)
}
// ExistsOneMacroExpander expands the input call arguments into a comprehension that returns true if exactly
// one of the elements in the range match the predicate expressions:
// <iterRange>.exists_one(<iterVar>, <predicate>)
func ExistsOneMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
func ExistsOneMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return parser.MakeExistsOne(meh, target, args)
}
@ -91,14 +90,14 @@ func ExistsOneMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*ex
//
// In the second form only iterVar values which return true when provided to the predicate expression
// are transformed.
func MapMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
func MapMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return parser.MakeMap(meh, target, args)
}
// FilterMacroExpander expands the input call arguments into a comprehension which produces a list which contains
// only elements which match the provided predicate expression:
// <iterRange>.filter(<iterVar>, <predicate>)
func FilterMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
func FilterMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return parser.MakeFilter(meh, target, args)
}

View File

@ -23,12 +23,13 @@ import (
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/dynamicpb"
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/functions"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/pb"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter"
"github.com/google/cel-go/interpreter/functions"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
@ -41,13 +42,6 @@ import (
const (
_ = iota
// Disallow heterogeneous aggregate (list, map) literals.
// Note, it is still possible to have heterogeneous aggregates when
// provided as variables to the expression, as well as via conversion
// of well-known dynamic types, or with unchecked expressions.
// Affects checking. Provides a subset of standard behavior.
featureDisableDynamicAggregateLiterals
// Enable the tracking of function call expressions replaced by macros.
featureEnableMacroCallTracking
@ -63,9 +57,10 @@ const (
// is not already in UTC.
featureDefaultUTCTimeZone
// Enable the use of optional types in the syntax, type-system, type-checking,
// and runtime.
featureOptionalTypes
// Enable the serialization of logical operator ASTs as variadic calls, thus
// compressing the logic graph to a single call when multiple like-operator
// expressions occur: e.g. a && b && c && d -> call(_&&_, [a, b, c, d])
featureVariadicLogicalASTs
)
// EnvOption is a functional interface for configuring the environment.
@ -82,23 +77,26 @@ func ClearMacros() EnvOption {
}
}
// CustomTypeAdapter swaps the default ref.TypeAdapter implementation with a custom one.
// CustomTypeAdapter swaps the default types.Adapter implementation with a custom one.
//
// Note: This option must be specified before the Types and TypeDescs options when used together.
func CustomTypeAdapter(adapter ref.TypeAdapter) EnvOption {
func CustomTypeAdapter(adapter types.Adapter) EnvOption {
return func(e *Env) (*Env, error) {
e.adapter = adapter
return e, nil
}
}
// CustomTypeProvider swaps the default ref.TypeProvider implementation with a custom one.
// CustomTypeProvider replaces the types.Provider implementation with a custom one.
//
// The `provider` variable type may either be types.Provider or ref.TypeProvider (deprecated)
//
// Note: This option must be specified before the Types and TypeDescs options when used together.
func CustomTypeProvider(provider ref.TypeProvider) EnvOption {
func CustomTypeProvider(provider any) EnvOption {
return func(e *Env) (*Env, error) {
e.provider = provider
return e, nil
var err error
e.provider, err = maybeInteropProvider(provider)
return e, err
}
}
@ -108,8 +106,28 @@ func CustomTypeProvider(provider ref.TypeProvider) EnvOption {
// for the environment. The NewEnv call builds on top of the standard CEL declarations. For a
// purely custom set of declarations use NewCustomEnv.
func Declarations(decls ...*exprpb.Decl) EnvOption {
declOpts := []EnvOption{}
var err error
var opt EnvOption
// Convert the declarations to `EnvOption` values ahead of time.
// Surface any errors in conversion when the options are applied.
for _, d := range decls {
opt, err = ExprDeclToDeclaration(d)
if err != nil {
break
}
declOpts = append(declOpts, opt)
}
return func(e *Env) (*Env, error) {
e.declarations = append(e.declarations, decls...)
if err != nil {
return nil, err
}
for _, o := range declOpts {
e, err = o(e)
if err != nil {
return nil, err
}
}
return e, nil
}
}
@ -126,14 +144,25 @@ func EagerlyValidateDeclarations(enabled bool) EnvOption {
return features(featureEagerlyValidateDeclarations, enabled)
}
// HomogeneousAggregateLiterals option ensures that list and map literal entry types must agree
// during type-checking.
// HomogeneousAggregateLiterals disables mixed type list and map literal values.
//
// Note, it is still possible to have heterogeneous aggregates when provided as variables to the
// expression, as well as via conversion of well-known dynamic types, or with unchecked
// expressions.
func HomogeneousAggregateLiterals() EnvOption {
return features(featureDisableDynamicAggregateLiterals, true)
return ASTValidators(ValidateHomogeneousAggregateLiterals())
}
// variadicLogicalOperatorASTs flatten like-operator chained logical expressions into a single
// variadic call with N-terms. This behavior is useful when serializing to a protocol buffer as
// it will reduce the number of recursive calls needed to deserialize the AST later.
//
// For example, given the following expression the call graph will be rendered accordingly:
//
// expression: a && b && c && (d || e)
// ast: call(_&&_, [a, b, c, call(_||_, [d, e])])
func variadicLogicalOperatorASTs() EnvOption {
return features(featureVariadicLogicalASTs, true)
}
// Macros option extends the macro set configured in the environment.
@ -226,7 +255,12 @@ func Abbrevs(qualifiedNames ...string) EnvOption {
// Note: This option must be specified after the CustomTypeProvider option when used together.
func Types(addTypes ...any) EnvOption {
return func(e *Env) (*Env, error) {
reg, isReg := e.provider.(ref.TypeRegistry)
var reg ref.TypeRegistry
var isReg bool
reg, isReg = e.provider.(*types.Registry)
if !isReg {
reg, isReg = e.provider.(ref.TypeRegistry)
}
if !isReg {
return nil, fmt.Errorf("custom types not supported by provider: %T", e.provider)
}
@ -436,6 +470,24 @@ func InterruptCheckFrequency(checkFrequency uint) ProgramOption {
}
}
// CostEstimatorOptions configure type-check time options for estimating expression cost.
func CostEstimatorOptions(costOpts ...checker.CostOption) EnvOption {
return func(e *Env) (*Env, error) {
e.costOptions = append(e.costOptions, costOpts...)
return e, nil
}
}
// CostTrackerOptions configures a set of options for cost-tracking.
//
// Note, CostTrackerOptions is a no-op unless CostTracking is also enabled.
func CostTrackerOptions(costOpts ...interpreter.CostTrackerOption) ProgramOption {
return func(p *prog) (*prog, error) {
p.costOptions = append(p.costOptions, costOpts...)
return p, nil
}
}
// CostTracking enables cost tracking and registers a ActualCostEstimator that can optionally provide a runtime cost estimate for any function calls.
func CostTracking(costEstimator interpreter.ActualCostEstimator) ProgramOption {
return func(p *prog) (*prog, error) {
@ -457,25 +509,21 @@ func CostLimit(costLimit uint64) ProgramOption {
}
}
func fieldToCELType(field protoreflect.FieldDescriptor) (*exprpb.Type, error) {
func fieldToCELType(field protoreflect.FieldDescriptor) (*Type, error) {
if field.Kind() == protoreflect.MessageKind || field.Kind() == protoreflect.GroupKind {
msgName := (string)(field.Message().FullName())
wellKnownType, found := pb.CheckedWellKnowns[msgName]
if found {
return wellKnownType, nil
}
return decls.NewObjectType(msgName), nil
return ObjectType(msgName), nil
}
if primitiveType, found := pb.CheckedPrimitives[field.Kind()]; found {
if primitiveType, found := types.ProtoCELPrimitives[field.Kind()]; found {
return primitiveType, nil
}
if field.Kind() == protoreflect.EnumKind {
return decls.Int, nil
return IntType, nil
}
return nil, fmt.Errorf("field %s type %s not implemented", field.FullName(), field.Kind().String())
}
func fieldToDecl(field protoreflect.FieldDescriptor) (*exprpb.Decl, error) {
func fieldToVariable(field protoreflect.FieldDescriptor) (EnvOption, error) {
name := string(field.Name())
if field.IsMap() {
mapKey := field.MapKey()
@ -488,20 +536,20 @@ func fieldToDecl(field protoreflect.FieldDescriptor) (*exprpb.Decl, error) {
if err != nil {
return nil, err
}
return decls.NewVar(name, decls.NewMapType(keyType, valueType)), nil
return Variable(name, MapType(keyType, valueType)), nil
}
if field.IsList() {
elemType, err := fieldToCELType(field)
if err != nil {
return nil, err
}
return decls.NewVar(name, decls.NewListType(elemType)), nil
return Variable(name, ListType(elemType)), nil
}
celType, err := fieldToCELType(field)
if err != nil {
return nil, err
}
return decls.NewVar(name, celType), nil
return Variable(name, celType), nil
}
// DeclareContextProto returns an option to extend CEL environment with declarations from the given context proto.
@ -509,25 +557,53 @@ func fieldToDecl(field protoreflect.FieldDescriptor) (*exprpb.Decl, error) {
// https://github.com/google/cel-spec/blob/master/doc/langdef.md#evaluation-environment
func DeclareContextProto(descriptor protoreflect.MessageDescriptor) EnvOption {
return func(e *Env) (*Env, error) {
var decls []*exprpb.Decl
fields := descriptor.Fields()
for i := 0; i < fields.Len(); i++ {
field := fields.Get(i)
decl, err := fieldToDecl(field)
variable, err := fieldToVariable(field)
if err != nil {
return nil, err
}
e, err = variable(e)
if err != nil {
return nil, err
}
decls = append(decls, decl)
}
var err error
e, err = Declarations(decls...)(e)
if err != nil {
return nil, err
}
return Types(dynamicpb.NewMessage(descriptor))(e)
}
}
// ContextProtoVars uses the fields of the input proto.Messages as top-level variables within an Activation.
//
// Consider using with `DeclareContextProto` to simplify variable type declarations and publishing when using
// protocol buffers.
func ContextProtoVars(ctx proto.Message) (interpreter.Activation, error) {
if ctx == nil || !ctx.ProtoReflect().IsValid() {
return interpreter.EmptyActivation(), nil
}
reg, err := types.NewRegistry(ctx)
if err != nil {
return nil, err
}
pbRef := ctx.ProtoReflect()
typeName := string(pbRef.Descriptor().FullName())
fields := pbRef.Descriptor().Fields()
vars := make(map[string]any, fields.Len())
for i := 0; i < fields.Len(); i++ {
field := fields.Get(i)
sft, found := reg.FindStructFieldType(typeName, field.TextName())
if !found {
return nil, fmt.Errorf("no such field: %s", field.TextName())
}
fieldVal, err := sft.GetFrom(ctx)
if err != nil {
return nil, err
}
vars[field.TextName()] = fieldVal
}
return interpreter.NewActivation(vars)
}
// EnableMacroCallTracking ensures that call expressions which are replaced by macros
// are tracked in the `SourceInfo` of parsed and checked expressions.
func EnableMacroCallTracking() EnvOption {
@ -545,13 +621,6 @@ func DefaultUTCTimeZone(enabled bool) EnvOption {
return features(featureDefaultUTCTimeZone, enabled)
}
// OptionalTypes enable support for optional syntax and types in CEL. The optional value type makes
// it possible to express whether variables have been provided, whether a result has been computed,
// and in the future whether an object field path, map key value, or list index has a value.
func OptionalTypes() EnvOption {
return Lib(optionalLibrary{})
}
// features sets the given feature flags. See list of Feature constants above.
func features(flag int, enabled bool) EnvOption {
return func(e *Env) (*Env, error) {
@ -577,3 +646,14 @@ func ParserExpressionSizeLimit(limit int) EnvOption {
return e, nil
}
}
func maybeInteropProvider(provider any) (types.Provider, error) {
switch p := provider.(type) {
case types.Provider:
return p, nil
case ref.TypeProvider:
return &interopCELTypeProvider{TypeProvider: p}, nil
default:
return nil, fmt.Errorf("unsupported type provider: %T", provider)
}
}

View File

@ -19,11 +19,10 @@ import (
"fmt"
"sync"
celast "github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// Program is an evaluable view of an Ast.
@ -62,6 +61,9 @@ func NoVars() interpreter.Activation {
// PartialVars returns a PartialActivation which contains variables and a set of AttributePattern
// values that indicate variables or parts of variables whose value are not yet known.
//
// This method relies on manually configured sets of missing attribute patterns. For a method which
// infers the missing variables from the input and the configured environment, use Env.PartialVars().
//
// The `vars` value may either be an interpreter.Activation or any valid input to the
// interpreter.NewActivation call.
func PartialVars(vars any,
@ -104,7 +106,7 @@ func (ed *EvalDetails) State() interpreter.EvalState {
// ActualCost returns the tracked cost through the course of execution when `CostTracking` is enabled.
// Otherwise, returns nil if the cost was not enabled.
func (ed *EvalDetails) ActualCost() *uint64 {
if ed.costTracker == nil {
if ed == nil || ed.costTracker == nil {
return nil
}
cost := ed.costTracker.ActualCost()
@ -128,10 +130,14 @@ type prog struct {
// Interpretable configured from an Ast and aggregate decorator set based on program options.
interpretable interpreter.Interpretable
callCostEstimator interpreter.ActualCostEstimator
costOptions []interpreter.CostTrackerOption
costLimit *uint64
}
func (p *prog) clone() *prog {
costOptsCopy := make([]interpreter.CostTrackerOption, len(p.costOptions))
copy(costOptsCopy, p.costOptions)
return &prog{
Env: p.Env,
evalOpts: p.evalOpts,
@ -153,9 +159,10 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
// Ensure the default attribute factory is set after the adapter and provider are
// configured.
p := &prog{
Env: e,
decorators: []interpreter.InterpretableDecorator{},
dispatcher: disp,
Env: e,
decorators: []interpreter.InterpretableDecorator{},
dispatcher: disp,
costOptions: []interpreter.CostTrackerOption{},
}
// Configure the program via the ProgramOption values.
@ -169,7 +176,7 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
// Add the function bindings created via Function() options.
for _, fn := range e.functions {
bindings, err := fn.bindings()
bindings, err := fn.Bindings()
if err != nil {
return nil, err
}
@ -208,14 +215,11 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
}
// Enable compile-time checking of syntax/cardinality for string.format calls.
if p.evalOpts&OptCheckStringFormat == OptCheckStringFormat {
var isValidType func(id int64, validTypes ...*types.TypeValue) (bool, error)
var isValidType func(id int64, validTypes ...ref.Type) (bool, error)
if ast.IsChecked() {
isValidType = func(id int64, validTypes ...*types.TypeValue) (bool, error) {
t, err := ExprTypeToType(ast.typeMap[id])
if err != nil {
return false, err
}
if t.kind == DynKind {
isValidType = func(id int64, validTypes ...ref.Type) (bool, error) {
t := ast.typeMap[id]
if t.Kind() == DynKind {
return true, nil
}
for _, vt := range validTypes {
@ -223,7 +227,7 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
if err != nil {
return false, err
}
if k == t.kind {
if t.Kind() == k {
return true, nil
}
}
@ -231,7 +235,7 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
}
} else {
// if the AST isn't type-checked, short-circuit validation
isValidType = func(id int64, validTypes ...*types.TypeValue) (bool, error) {
isValidType = func(id int64, validTypes ...ref.Type) (bool, error) {
return true, nil
}
}
@ -243,6 +247,12 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
factory := func(state interpreter.EvalState, costTracker *interpreter.CostTracker) (Program, error) {
costTracker.Estimator = p.callCostEstimator
costTracker.Limit = p.costLimit
for _, costOpt := range p.costOptions {
err := costOpt(costTracker)
if err != nil {
return nil, err
}
}
// Limit capacity to guarantee a reallocation when calling 'append(decs, ...)' below. This
// prevents the underlying memory from being shared between factory function calls causing
// undesired mutations.
@ -284,10 +294,11 @@ func (p *prog) initInterpretable(ast *Ast, decs []interpreter.InterpretableDecor
}
// When the AST has been checked it contains metadata that can be used to speed up program execution.
var checked *exprpb.CheckedExpr
checked, err := AstToCheckedExpr(ast)
if err != nil {
return nil, err
checked := &celast.CheckedAST{
Expr: ast.Expr(),
SourceInfo: ast.SourceInfo(),
TypeMap: ast.typeMap,
ReferenceMap: ast.refMap,
}
interpretable, err := p.interpreter.NewInterpretable(checked, decs...)
if err != nil {
@ -371,7 +382,11 @@ type progGen struct {
// the test is successful.
func newProgGen(factory progFactory) (Program, error) {
// Test the factory to make sure that configuration errors are spotted at config
_, err := factory(interpreter.NewEvalState(), &interpreter.CostTracker{})
tracker, err := interpreter.NewCostTracker(nil)
if err != nil {
return nil, err
}
_, err = factory(interpreter.NewEvalState(), tracker)
if err != nil {
return nil, err
}
@ -384,7 +399,10 @@ func (gen *progGen) Eval(input any) (ref.Val, *EvalDetails, error) {
// new EvalState instance for each call to ensure that unique evaluations yield unique stateful
// results.
state := interpreter.NewEvalState()
costTracker := &interpreter.CostTracker{}
costTracker, err := interpreter.NewCostTracker(nil)
if err != nil {
return nil, nil, err
}
det := &EvalDetails{state: state, costTracker: costTracker}
// Generate a new instance of the interpretable using the factory configured during the call to
@ -412,7 +430,10 @@ func (gen *progGen) ContextEval(ctx context.Context, input any) (ref.Val, *EvalD
// new EvalState instance for each call to ensure that unique evaluations yield unique stateful
// results.
state := interpreter.NewEvalState()
costTracker := &interpreter.CostTracker{}
costTracker, err := interpreter.NewCostTracker(nil)
if err != nil {
return nil, nil, err
}
det := &EvalDetails{state: state, costTracker: costTracker}
// Generate a new instance of the interpretable using the factory configured during the call to
@ -498,7 +519,7 @@ type evalActivation struct {
// The lazy binding will only be invoked once per evaluation.
//
// Values which are not represented as ref.Val types on input may be adapted to a ref.Val using
// the ref.TypeAdapter configured in the environment.
// the types.Adapter configured in the environment.
func (a *evalActivation) ResolveName(name string) (any, bool) {
v, found := a.vars[name]
if !found {

388
vendor/github.com/google/cel-go/cel/validator.go generated vendored Normal file
View File

@ -0,0 +1,388 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cel
import (
"fmt"
"reflect"
"regexp"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/overloads"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
const (
homogeneousValidatorName = "cel.lib.std.validate.types.homogeneous"
// HomogeneousAggregateLiteralExemptFunctions is the ValidatorConfig key used to configure
// the set of function names which are exempt from homogeneous type checks. The expected type
// is a string list of function names.
//
// As an example, the `<string>.format([args])` call expects the input arguments list to be
// comprised of a variety of types which correspond to the types expected by the format control
// clauses; however, all other uses of a mixed element type list, would be unexpected.
HomogeneousAggregateLiteralExemptFunctions = homogeneousValidatorName + ".exempt"
)
// ASTValidators configures a set of ASTValidator instances into the target environment.
//
// Validators are applied in the order in which the are specified and are treated as singletons.
// The same ASTValidator with a given name will not be applied more than once.
func ASTValidators(validators ...ASTValidator) EnvOption {
return func(e *Env) (*Env, error) {
for _, v := range validators {
if !e.HasValidator(v.Name()) {
e.validators = append(e.validators, v)
}
}
return e, nil
}
}
// ASTValidator defines a singleton interface for validating a type-checked Ast against an environment.
//
// Note: the Issues argument is mutable in the sense that it is intended to collect errors which will be
// reported to the caller.
type ASTValidator interface {
// Name returns the name of the validator. Names must be unique.
Name() string
// Validate validates a given Ast within an Environment and collects a set of potential issues.
//
// The ValidatorConfig is generated from the set of ASTValidatorConfigurer instances prior to
// the invocation of the Validate call. The expectation is that the validator configuration
// is created in sequence and immutable once provided to the Validate call.
//
// See individual validators for more information on their configuration keys and configuration
// properties.
Validate(*Env, ValidatorConfig, *ast.CheckedAST, *Issues)
}
// ValidatorConfig provides an accessor method for querying validator configuration state.
type ValidatorConfig interface {
GetOrDefault(name string, value any) any
}
// MutableValidatorConfig provides mutation methods for querying and updating validator configuration
// settings.
type MutableValidatorConfig interface {
ValidatorConfig
Set(name string, value any) error
}
// ASTValidatorConfigurer indicates that this object, currently expected to be an ASTValidator,
// participates in validator configuration settings.
//
// This interface may be split from the expectation of being an ASTValidator instance in the future.
type ASTValidatorConfigurer interface {
Configure(MutableValidatorConfig) error
}
// validatorConfig implements the ValidatorConfig and MutableValidatorConfig interfaces.
type validatorConfig struct {
data map[string]any
}
// newValidatorConfig initializes the validator config with default values for core CEL validators.
func newValidatorConfig() *validatorConfig {
return &validatorConfig{
data: map[string]any{
HomogeneousAggregateLiteralExemptFunctions: []string{},
},
}
}
// GetOrDefault returns the configured value for the name, if present, else the input default value.
//
// Note, the type-agreement between the input default and configured value is not checked on read.
func (config *validatorConfig) GetOrDefault(name string, value any) any {
v, found := config.data[name]
if !found {
return value
}
return v
}
// Set configures a validator option with the given name and value.
//
// If the value had previously been set, the new value must have the same reflection type as the old one,
// or the call will error.
func (config *validatorConfig) Set(name string, value any) error {
v, found := config.data[name]
if found && reflect.TypeOf(v) != reflect.TypeOf(value) {
return fmt.Errorf("incompatible configuration type for %s, got %T, wanted %T", name, value, v)
}
config.data[name] = value
return nil
}
// ExtendedValidations collects a set of common AST validations which reduce the likelihood of runtime errors.
//
// - Validate duration and timestamp literals
// - Ensure regex strings are valid
// - Disable mixed type list and map literals
func ExtendedValidations() EnvOption {
return ASTValidators(
ValidateDurationLiterals(),
ValidateTimestampLiterals(),
ValidateRegexLiterals(),
ValidateHomogeneousAggregateLiterals(),
)
}
// ValidateDurationLiterals ensures that duration literal arguments are valid immediately after type-check.
func ValidateDurationLiterals() ASTValidator {
return newFormatValidator(overloads.TypeConvertDuration, 0, evalCall)
}
// ValidateTimestampLiterals ensures that timestamp literal arguments are valid immediately after type-check.
func ValidateTimestampLiterals() ASTValidator {
return newFormatValidator(overloads.TypeConvertTimestamp, 0, evalCall)
}
// ValidateRegexLiterals ensures that regex patterns are validated after type-check.
func ValidateRegexLiterals() ASTValidator {
return newFormatValidator(overloads.Matches, 0, compileRegex)
}
// ValidateHomogeneousAggregateLiterals checks that all list and map literals entries have the same types, i.e.
// no mixed list element types or mixed map key or map value types.
//
// Note: the string format call relies on a mixed element type list for ease of use, so this check skips all
// literals which occur within string format calls.
func ValidateHomogeneousAggregateLiterals() ASTValidator {
return homogeneousAggregateLiteralValidator{}
}
// ValidateComprehensionNestingLimit ensures that comprehension nesting does not exceed the specified limit.
//
// This validator can be useful for preventing arbitrarily nested comprehensions which can take high polynomial
// time to complete.
//
// Note, this limit does not apply to comprehensions with an empty iteration range, as these comprehensions have
// no actual looping cost. The cel.bind() utilizes the comprehension structure to perform local variable
// assignments and supplies an empty iteration range, so they won't count against the nesting limit either.
func ValidateComprehensionNestingLimit(limit int) ASTValidator {
return nestingLimitValidator{limit: limit}
}
type argChecker func(env *Env, call, arg ast.NavigableExpr) error
func newFormatValidator(funcName string, argNum int, check argChecker) formatValidator {
return formatValidator{
funcName: funcName,
check: check,
argNum: argNum,
}
}
type formatValidator struct {
funcName string
argNum int
check argChecker
}
// Name returns the unique name of this function format validator.
func (v formatValidator) Name() string {
return fmt.Sprintf("cel.lib.std.validate.functions.%s", v.funcName)
}
// Validate searches the AST for uses of a given function name with a constant argument and performs a check
// on whether the argument is a valid literal value.
func (v formatValidator) Validate(e *Env, _ ValidatorConfig, a *ast.CheckedAST, iss *Issues) {
root := ast.NavigateCheckedAST(a)
funcCalls := ast.MatchDescendants(root, ast.FunctionMatcher(v.funcName))
for _, call := range funcCalls {
callArgs := call.AsCall().Args()
if len(callArgs) <= v.argNum {
continue
}
litArg := callArgs[v.argNum]
if litArg.Kind() != ast.LiteralKind {
continue
}
if err := v.check(e, call, litArg); err != nil {
iss.ReportErrorAtID(litArg.ID(), "invalid %s argument", v.funcName)
}
}
}
func evalCall(env *Env, call, arg ast.NavigableExpr) error {
ast := ParsedExprToAst(&exprpb.ParsedExpr{Expr: call.ToExpr()})
prg, err := env.Program(ast)
if err != nil {
return err
}
_, _, err = prg.Eval(NoVars())
return err
}
func compileRegex(_ *Env, _, arg ast.NavigableExpr) error {
pattern := arg.AsLiteral().Value().(string)
_, err := regexp.Compile(pattern)
return err
}
type homogeneousAggregateLiteralValidator struct{}
// Name returns the unique name of the homogeneous type validator.
func (homogeneousAggregateLiteralValidator) Name() string {
return homogeneousValidatorName
}
// Configure implements the ASTValidatorConfigurer interface and currently sets the list of standard
// and exempt functions from homogeneous aggregate literal checks.
//
// TODO: Move this call into the string.format() ASTValidator once ported.
func (homogeneousAggregateLiteralValidator) Configure(c MutableValidatorConfig) error {
emptyList := []string{}
exemptFunctions := c.GetOrDefault(HomogeneousAggregateLiteralExemptFunctions, emptyList).([]string)
exemptFunctions = append(exemptFunctions, "format")
return c.Set(HomogeneousAggregateLiteralExemptFunctions, exemptFunctions)
}
// Validate validates that all lists and map literals have homogeneous types, i.e. don't contain dyn types.
//
// This validator makes an exception for list and map literals which occur at any level of nesting within
// string format calls.
func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig, a *ast.CheckedAST, iss *Issues) {
var exemptedFunctions []string
exemptedFunctions = c.GetOrDefault(HomogeneousAggregateLiteralExemptFunctions, exemptedFunctions).([]string)
root := ast.NavigateCheckedAST(a)
listExprs := ast.MatchDescendants(root, ast.KindMatcher(ast.ListKind))
for _, listExpr := range listExprs {
if inExemptFunction(listExpr, exemptedFunctions) {
continue
}
l := listExpr.AsList()
elements := l.Elements()
optIndices := l.OptionalIndices()
var elemType *Type
for i, e := range elements {
et := e.Type()
if isOptionalIndex(i, optIndices) {
et = et.Parameters()[0]
}
if elemType == nil {
elemType = et
continue
}
if !elemType.IsEquivalentType(et) {
v.typeMismatch(iss, e.ID(), elemType, et)
break
}
}
}
mapExprs := ast.MatchDescendants(root, ast.KindMatcher(ast.MapKind))
for _, mapExpr := range mapExprs {
if inExemptFunction(mapExpr, exemptedFunctions) {
continue
}
m := mapExpr.AsMap()
entries := m.Entries()
var keyType, valType *Type
for _, e := range entries {
key, val := e.Key(), e.Value()
kt, vt := key.Type(), val.Type()
if e.IsOptional() {
vt = vt.Parameters()[0]
}
if keyType == nil && valType == nil {
keyType, valType = kt, vt
continue
}
if !keyType.IsEquivalentType(kt) {
v.typeMismatch(iss, key.ID(), keyType, kt)
}
if !valType.IsEquivalentType(vt) {
v.typeMismatch(iss, val.ID(), valType, vt)
}
}
}
}
func inExemptFunction(e ast.NavigableExpr, exemptFunctions []string) bool {
if parent, found := e.Parent(); found {
if parent.Kind() == ast.CallKind {
fnName := parent.AsCall().FunctionName()
for _, exempt := range exemptFunctions {
if exempt == fnName {
return true
}
}
}
if parent.Kind() == ast.ListKind || parent.Kind() == ast.MapKind {
return inExemptFunction(parent, exemptFunctions)
}
}
return false
}
func isOptionalIndex(i int, optIndices []int32) bool {
for _, optInd := range optIndices {
if i == int(optInd) {
return true
}
}
return false
}
func (homogeneousAggregateLiteralValidator) typeMismatch(iss *Issues, id int64, expected, actual *Type) {
iss.ReportErrorAtID(id, "expected type '%s' but found '%s'", FormatCELType(expected), FormatCELType(actual))
}
type nestingLimitValidator struct {
limit int
}
func (v nestingLimitValidator) Name() string {
return "cel.lib.std.validate.comprehension_nesting_limit"
}
func (v nestingLimitValidator) Validate(e *Env, _ ValidatorConfig, a *ast.CheckedAST, iss *Issues) {
root := ast.NavigateCheckedAST(a)
comprehensions := ast.MatchDescendants(root, ast.KindMatcher(ast.ComprehensionKind))
if len(comprehensions) <= v.limit {
return
}
for _, comp := range comprehensions {
count := 0
e := comp
hasParent := true
for hasParent {
// When the expression is not a comprehension, continue to the next ancestor.
if e.Kind() != ast.ComprehensionKind {
e, hasParent = e.Parent()
continue
}
// When the comprehension has an empty range, continue to the next ancestor
// as this comprehension does not have any associated cost.
iterRange := e.AsComprehension().IterRange()
if iterRange.Kind() == ast.ListKind && iterRange.AsList().Size() == 0 {
e, hasParent = e.Parent()
continue
}
// Otherwise check the nesting limit.
count++
if count > v.limit {
iss.ReportErrorAtID(comp.ID(), "comprehension exceeds nesting limit")
break
}
e, hasParent = e.Parent()
}
}
}

View File

@ -11,9 +11,11 @@ go_library(
"cost.go",
"env.go",
"errors.go",
"format.go",
"mapping.go",
"options.go",
"printer.go",
"scopes.go",
"standard.go",
"types.go",
],
@ -22,10 +24,13 @@ go_library(
deps = [
"//checker/decls:go_default_library",
"//common:go_default_library",
"//common/ast:go_default_library",
"//common/containers:go_default_library",
"//common/debug:go_default_library",
"//common/decls:go_default_library",
"//common/operators:go_default_library",
"//common/overloads:go_default_library",
"//common/stdlib:go_default_library",
"//common/types:go_default_library",
"//common/types/pb:go_default_library",
"//common/types/ref:go_default_library",
@ -44,6 +49,7 @@ go_test(
"checker_test.go",
"cost_test.go",
"env_test.go",
"format_test.go",
],
embed = [
":go_default_library",

View File

@ -18,15 +18,13 @@ package checker
import (
"fmt"
"reflect"
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/decls"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/types/ref"
"google.golang.org/protobuf/proto"
"github.com/google/cel-go/common/types"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
@ -37,8 +35,8 @@ type checker struct {
mappings *mapping
freeTypeVarCounter int
sourceInfo *exprpb.SourceInfo
types map[int64]*exprpb.Type
references map[int64]*exprpb.Reference
types map[int64]*types.Type
references map[int64]*ast.ReferenceInfo
}
// Check performs type checking, giving a typed AST.
@ -47,40 +45,38 @@ type checker struct {
// descriptions of protocol buffers, and a registry for errors.
// Returns a CheckedExpr proto, which might not be usable if
// there are errors in the error registry.
func Check(parsedExpr *exprpb.ParsedExpr,
source common.Source,
env *Env) (*exprpb.CheckedExpr, *common.Errors) {
func Check(parsedExpr *exprpb.ParsedExpr, source common.Source, env *Env) (*ast.CheckedAST, *common.Errors) {
errs := common.NewErrors(source)
c := checker{
env: env,
errors: &typeErrors{common.NewErrors(source)},
errors: &typeErrors{errs: errs},
mappings: newMapping(),
freeTypeVarCounter: 0,
sourceInfo: parsedExpr.GetSourceInfo(),
types: make(map[int64]*exprpb.Type),
references: make(map[int64]*exprpb.Reference),
types: make(map[int64]*types.Type),
references: make(map[int64]*ast.ReferenceInfo),
}
c.check(parsedExpr.GetExpr())
// Walk over the final type map substituting any type parameters either by their bound value or
// by DYN.
m := make(map[int64]*exprpb.Type)
for k, v := range c.types {
m[k] = substitute(c.mappings, v, true)
m := make(map[int64]*types.Type)
for id, t := range c.types {
m[id] = substitute(c.mappings, t, true)
}
return &exprpb.CheckedExpr{
return &ast.CheckedAST{
Expr: parsedExpr.GetExpr(),
SourceInfo: parsedExpr.GetSourceInfo(),
TypeMap: m,
ReferenceMap: c.references,
}, c.errors.Errors
}, errs
}
func (c *checker) check(e *exprpb.Expr) {
if e == nil {
return
}
switch e.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
literal := e.GetConstExpr()
@ -113,53 +109,51 @@ func (c *checker) check(e *exprpb.Expr) {
case *exprpb.Expr_ComprehensionExpr:
c.checkComprehension(e)
default:
c.errors.ReportError(
c.location(e), "Unrecognized ast type: %v", reflect.TypeOf(e))
c.errors.unexpectedASTType(e.GetId(), c.location(e), e)
}
}
func (c *checker) checkInt64Literal(e *exprpb.Expr) {
c.setType(e, decls.Int)
c.setType(e, types.IntType)
}
func (c *checker) checkUint64Literal(e *exprpb.Expr) {
c.setType(e, decls.Uint)
c.setType(e, types.UintType)
}
func (c *checker) checkStringLiteral(e *exprpb.Expr) {
c.setType(e, decls.String)
c.setType(e, types.StringType)
}
func (c *checker) checkBytesLiteral(e *exprpb.Expr) {
c.setType(e, decls.Bytes)
c.setType(e, types.BytesType)
}
func (c *checker) checkDoubleLiteral(e *exprpb.Expr) {
c.setType(e, decls.Double)
c.setType(e, types.DoubleType)
}
func (c *checker) checkBoolLiteral(e *exprpb.Expr) {
c.setType(e, decls.Bool)
c.setType(e, types.BoolType)
}
func (c *checker) checkNullLiteral(e *exprpb.Expr) {
c.setType(e, decls.Null)
c.setType(e, types.NullType)
}
func (c *checker) checkIdent(e *exprpb.Expr) {
identExpr := e.GetIdentExpr()
// Check to see if the identifier is declared.
if ident := c.env.LookupIdent(identExpr.GetName()); ident != nil {
c.setType(e, ident.GetIdent().GetType())
c.setReference(e, newIdentReference(ident.GetName(), ident.GetIdent().GetValue()))
c.setType(e, ident.Type())
c.setReference(e, ast.NewIdentReference(ident.Name(), ident.Value()))
// Overwrite the identifier with its fully qualified name.
identExpr.Name = ident.GetName()
identExpr.Name = ident.Name()
return
}
c.setType(e, decls.Error)
c.errors.undeclaredReference(
c.location(e), c.env.container.Name(), identExpr.GetName())
c.setType(e, types.ErrorType)
c.errors.undeclaredReference(e.GetId(), c.location(e), c.env.container.Name(), identExpr.GetName())
}
func (c *checker) checkSelect(e *exprpb.Expr) {
@ -174,9 +168,9 @@ func (c *checker) checkSelect(e *exprpb.Expr) {
// Rewrite the node to be a variable reference to the resolved fully-qualified
// variable name.
c.setType(e, ident.GetIdent().GetType())
c.setReference(e, newIdentReference(ident.GetName(), ident.GetIdent().GetValue()))
identName := ident.GetName()
c.setType(e, ident.Type())
c.setReference(e, ast.NewIdentReference(ident.Name(), ident.Value()))
identName := ident.Name()
e.ExprKind = &exprpb.Expr_IdentExpr{
IdentExpr: &exprpb.Expr_Ident{
Name: identName,
@ -188,7 +182,7 @@ func (c *checker) checkSelect(e *exprpb.Expr) {
resultType := c.checkSelectField(e, sel.GetOperand(), sel.GetField(), false)
if sel.TestOnly {
resultType = decls.Bool
resultType = types.BoolType
}
c.setType(e, substitute(c.mappings, resultType, false))
}
@ -200,16 +194,17 @@ func (c *checker) checkOptSelect(e *exprpb.Expr) {
field := call.GetArgs()[1]
fieldName, isString := maybeUnwrapString(field)
if !isString {
c.errors.ReportError(c.location(field), "unsupported optional field selection: %v", field)
c.errors.notAnOptionalFieldSelection(field.GetId(), c.location(field), field)
return
}
// Perform type-checking using the field selection logic.
resultType := c.checkSelectField(e, operand, fieldName, true)
c.setType(e, substitute(c.mappings, resultType, false))
c.setReference(e, ast.NewFunctionReference("select_optional_field"))
}
func (c *checker) checkSelectField(e, operand *exprpb.Expr, field string, optional bool) *exprpb.Type {
func (c *checker) checkSelectField(e, operand *exprpb.Expr, field string, optional bool) *types.Type {
// Interpret as field selection, first traversing down the operand.
c.check(operand)
operandType := substitute(c.mappings, c.getType(operand), false)
@ -218,38 +213,37 @@ func (c *checker) checkSelectField(e, operand *exprpb.Expr, field string, option
targetType, isOpt := maybeUnwrapOptional(operandType)
// Assume error type by default as most types do not support field selection.
resultType := decls.Error
switch kindOf(targetType) {
case kindMap:
resultType := types.ErrorType
switch targetType.Kind() {
case types.MapKind:
// Maps yield their value type as the selection result type.
mapType := targetType.GetMapType()
resultType = mapType.GetValueType()
case kindObject:
resultType = targetType.Parameters()[1]
case types.StructKind:
// Objects yield their field type declaration as the selection result type, but only if
// the field is defined.
messageType := targetType
if fieldType, found := c.lookupFieldType(c.location(e), messageType.GetMessageType(), field); found {
resultType = fieldType.Type
if fieldType, found := c.lookupFieldType(e.GetId(), messageType.TypeName(), field); found {
resultType = fieldType
}
case kindTypeParam:
case types.TypeParamKind:
// Set the operand type to DYN to prevent assignment to a potentially incorrect type
// at a later point in type-checking. The isAssignable call will update the type
// substitutions for the type param under the covers.
c.isAssignable(decls.Dyn, targetType)
c.isAssignable(types.DynType, targetType)
// Also, set the result type to DYN.
resultType = decls.Dyn
resultType = types.DynType
default:
// Dynamic / error values are treated as DYN type. Errors are handled this way as well
// in order to allow forward progress on the check.
if !isDynOrError(targetType) {
c.errors.typeDoesNotSupportFieldSelection(c.location(e), targetType)
c.errors.typeDoesNotSupportFieldSelection(e.GetId(), c.location(e), targetType)
}
resultType = decls.Dyn
resultType = types.DynType
}
// If the target type was optional coming in, then the result must be optional going out.
if isOpt || optional {
return decls.NewOptionalType(resultType)
return types.NewOptionalType(resultType)
}
return resultType
}
@ -277,15 +271,14 @@ func (c *checker) checkCall(e *exprpb.Expr) {
// Check for the existence of the function.
fn := c.env.LookupFunction(fnName)
if fn == nil {
c.errors.undeclaredReference(
c.location(e), c.env.container.Name(), fnName)
c.setType(e, decls.Error)
c.errors.undeclaredReference(e.GetId(), c.location(e), c.env.container.Name(), fnName)
c.setType(e, types.ErrorType)
return
}
// Overwrite the function name with its fully qualified resolved name.
call.Function = fn.GetName()
call.Function = fn.Name()
// Check to see whether the overload resolves.
c.resolveOverloadOrError(c.location(e), e, fn, nil, args)
c.resolveOverloadOrError(e, fn, nil, args)
return
}
@ -303,8 +296,8 @@ func (c *checker) checkCall(e *exprpb.Expr) {
// be an inaccurate representation of the desired evaluation behavior.
// Overwrite with fully-qualified resolved function name sans receiver target.
call.Target = nil
call.Function = fn.GetName()
c.resolveOverloadOrError(c.location(e), e, fn, nil, args)
call.Function = fn.Name()
c.resolveOverloadOrError(e, fn, nil, args)
return
}
}
@ -314,22 +307,21 @@ func (c *checker) checkCall(e *exprpb.Expr) {
fn := c.env.LookupFunction(fnName)
// Function found, attempt overload resolution.
if fn != nil {
c.resolveOverloadOrError(c.location(e), e, fn, target, args)
c.resolveOverloadOrError(e, fn, target, args)
return
}
// Function name not declared, record error.
c.errors.undeclaredReference(c.location(e), c.env.container.Name(), fnName)
c.setType(e, types.ErrorType)
c.errors.undeclaredReference(e.GetId(), c.location(e), c.env.container.Name(), fnName)
}
func (c *checker) resolveOverloadOrError(
loc common.Location,
e *exprpb.Expr,
fn *exprpb.Decl, target *exprpb.Expr, args []*exprpb.Expr) {
e *exprpb.Expr, fn *decls.FunctionDecl, target *exprpb.Expr, args []*exprpb.Expr) {
// Attempt to resolve the overload.
resolution := c.resolveOverload(loc, fn, target, args)
resolution := c.resolveOverload(e, fn, target, args)
// No such overload, error noted in the resolveOverload call, type recorded here.
if resolution == nil {
c.setType(e, decls.Error)
c.setType(e, types.ErrorType)
return
}
// Overload found.
@ -338,10 +330,9 @@ func (c *checker) resolveOverloadOrError(
}
func (c *checker) resolveOverload(
loc common.Location,
fn *exprpb.Decl, target *exprpb.Expr, args []*exprpb.Expr) *overloadResolution {
call *exprpb.Expr, fn *decls.FunctionDecl, target *exprpb.Expr, args []*exprpb.Expr) *overloadResolution {
var argTypes []*exprpb.Type
var argTypes []*types.Type
if target != nil {
argTypes = append(argTypes, c.getType(target))
}
@ -349,55 +340,75 @@ func (c *checker) resolveOverload(
argTypes = append(argTypes, c.getType(arg))
}
var resultType *exprpb.Type
var checkedRef *exprpb.Reference
for _, overload := range fn.GetFunction().GetOverloads() {
var resultType *types.Type
var checkedRef *ast.ReferenceInfo
for _, overload := range fn.OverloadDecls() {
// Determine whether the overload is currently considered.
if c.env.isOverloadDisabled(overload.GetOverloadId()) {
if c.env.isOverloadDisabled(overload.ID()) {
continue
}
// Ensure the call style for the overload matches.
if (target == nil && overload.GetIsInstanceFunction()) ||
(target != nil && !overload.GetIsInstanceFunction()) {
if (target == nil && overload.IsMemberFunction()) ||
(target != nil && !overload.IsMemberFunction()) {
// not a compatible call style.
continue
}
overloadType := decls.NewFunctionType(overload.ResultType, overload.Params...)
if len(overload.GetTypeParams()) > 0 {
// Alternative type-checking behavior when the logical operators are compacted into
// variadic AST representations.
if fn.Name() == operators.LogicalAnd || fn.Name() == operators.LogicalOr {
checkedRef = ast.NewFunctionReference(overload.ID())
for i, argType := range argTypes {
if !c.isAssignable(argType, types.BoolType) {
c.errors.typeMismatch(
args[i].GetId(),
c.locationByID(args[i].GetId()),
types.BoolType,
argType)
resultType = types.ErrorType
}
}
if isError(resultType) {
return nil
}
return newResolution(checkedRef, types.BoolType)
}
overloadType := newFunctionType(overload.ResultType(), overload.ArgTypes()...)
typeParams := overload.TypeParams()
if len(typeParams) != 0 {
// Instantiate overload's type with fresh type variables.
substitutions := newMapping()
for _, typePar := range overload.GetTypeParams() {
substitutions.add(decls.NewTypeParamType(typePar), c.newTypeVar())
for _, typePar := range typeParams {
substitutions.add(types.NewTypeParamType(typePar), c.newTypeVar())
}
overloadType = substitute(substitutions, overloadType, false)
}
candidateArgTypes := overloadType.GetFunction().GetArgTypes()
candidateArgTypes := overloadType.Parameters()[1:]
if c.isAssignableList(argTypes, candidateArgTypes) {
if checkedRef == nil {
checkedRef = newFunctionReference(overload.GetOverloadId())
checkedRef = ast.NewFunctionReference(overload.ID())
} else {
checkedRef.OverloadId = append(checkedRef.GetOverloadId(), overload.GetOverloadId())
checkedRef.AddOverload(overload.ID())
}
// First matching overload, determines result type.
fnResultType := substitute(c.mappings, overloadType.GetFunction().GetResultType(), false)
fnResultType := substitute(c.mappings, overloadType.Parameters()[0], false)
if resultType == nil {
resultType = fnResultType
} else if !isDyn(resultType) && !proto.Equal(fnResultType, resultType) {
resultType = decls.Dyn
} else if !isDyn(resultType) && !fnResultType.IsExactType(resultType) {
resultType = types.DynType
}
}
}
if resultType == nil {
for i, arg := range argTypes {
argTypes[i] = substitute(c.mappings, arg, true)
for i, argType := range argTypes {
argTypes[i] = substitute(c.mappings, argType, true)
}
c.errors.noMatchingOverload(loc, fn.GetName(), argTypes, target != nil)
resultType = decls.Error
c.errors.noMatchingOverload(call.GetId(), c.location(call), fn.Name(), argTypes, target != nil)
return nil
}
@ -406,7 +417,7 @@ func (c *checker) resolveOverload(
func (c *checker) checkCreateList(e *exprpb.Expr) {
create := e.GetListExpr()
var elemsType *exprpb.Type
var elemsType *types.Type
optionalIndices := create.GetOptionalIndices()
optionals := make(map[int32]bool, len(optionalIndices))
for _, optInd := range optionalIndices {
@ -419,16 +430,16 @@ func (c *checker) checkCreateList(e *exprpb.Expr) {
var isOptional bool
elemType, isOptional = maybeUnwrapOptional(elemType)
if !isOptional && !isDyn(elemType) {
c.errors.typeMismatch(c.location(e), decls.NewOptionalType(elemType), elemType)
c.errors.typeMismatch(e.GetId(), c.location(e), types.NewOptionalType(elemType), elemType)
}
}
elemsType = c.joinTypes(c.location(e), elemsType, elemType)
elemsType = c.joinTypes(e, elemsType, elemType)
}
if elemsType == nil {
// If the list is empty, assign free type var to elem type.
elemsType = c.newTypeVar()
}
c.setType(e, decls.NewListType(elemsType))
c.setType(e, types.NewListType(elemsType))
}
func (c *checker) checkCreateStruct(e *exprpb.Expr) {
@ -442,12 +453,12 @@ func (c *checker) checkCreateStruct(e *exprpb.Expr) {
func (c *checker) checkCreateMap(e *exprpb.Expr) {
mapVal := e.GetStructExpr()
var mapKeyType *exprpb.Type
var mapValueType *exprpb.Type
var mapKeyType *types.Type
var mapValueType *types.Type
for _, ent := range mapVal.GetEntries() {
key := ent.GetMapKey()
c.check(key)
mapKeyType = c.joinTypes(c.location(key), mapKeyType, c.getType(key))
mapKeyType = c.joinTypes(key, mapKeyType, c.getType(key))
val := ent.GetValue()
c.check(val)
@ -456,50 +467,54 @@ func (c *checker) checkCreateMap(e *exprpb.Expr) {
var isOptional bool
valType, isOptional = maybeUnwrapOptional(valType)
if !isOptional && !isDyn(valType) {
c.errors.typeMismatch(c.location(val), decls.NewOptionalType(valType), valType)
c.errors.typeMismatch(val.GetId(), c.location(val), types.NewOptionalType(valType), valType)
}
}
mapValueType = c.joinTypes(c.location(val), mapValueType, valType)
mapValueType = c.joinTypes(val, mapValueType, valType)
}
if mapKeyType == nil {
// If the map is empty, assign free type variables to typeKey and value type.
mapKeyType = c.newTypeVar()
mapValueType = c.newTypeVar()
}
c.setType(e, decls.NewMapType(mapKeyType, mapValueType))
c.setType(e, types.NewMapType(mapKeyType, mapValueType))
}
func (c *checker) checkCreateMessage(e *exprpb.Expr) {
msgVal := e.GetStructExpr()
// Determine the type of the message.
messageType := decls.Error
decl := c.env.LookupIdent(msgVal.GetMessageName())
if decl == nil {
resultType := types.ErrorType
ident := c.env.LookupIdent(msgVal.GetMessageName())
if ident == nil {
c.errors.undeclaredReference(
c.location(e), c.env.container.Name(), msgVal.GetMessageName())
e.GetId(), c.location(e), c.env.container.Name(), msgVal.GetMessageName())
c.setType(e, types.ErrorType)
return
}
// Ensure the type name is fully qualified in the AST.
msgVal.MessageName = decl.GetName()
c.setReference(e, newIdentReference(decl.GetName(), nil))
ident := decl.GetIdent()
identKind := kindOf(ident.GetType())
if identKind != kindError {
if identKind != kindType {
c.errors.notAType(c.location(e), ident.GetType())
typeName := ident.Name()
msgVal.MessageName = typeName
c.setReference(e, ast.NewIdentReference(ident.Name(), nil))
identKind := ident.Type().Kind()
if identKind != types.ErrorKind {
if identKind != types.TypeKind {
c.errors.notAType(e.GetId(), c.location(e), ident.Type().DeclaredTypeName())
} else {
messageType = ident.GetType().GetType()
if kindOf(messageType) != kindObject {
c.errors.notAMessageType(c.location(e), messageType)
messageType = decls.Error
resultType = ident.Type().Parameters()[0]
// Backwards compatibility test between well-known types and message types
// In this context, the type is being instantiated by its protobuf name which
// is not ideal or recommended, but some users expect this to work.
if isWellKnownType(resultType) {
typeName = getWellKnownTypeName(resultType)
} else if resultType.Kind() == types.StructKind {
typeName = resultType.DeclaredTypeName()
} else {
c.errors.notAMessageType(e.GetId(), c.location(e), resultType.DeclaredTypeName())
resultType = types.ErrorType
}
}
}
if isObjectWellKnownType(messageType) {
c.setType(e, getObjectWellKnownType(messageType))
} else {
c.setType(e, messageType)
}
c.setType(e, resultType)
// Check the field initializers.
for _, ent := range msgVal.GetEntries() {
@ -507,10 +522,10 @@ func (c *checker) checkCreateMessage(e *exprpb.Expr) {
value := ent.GetValue()
c.check(value)
fieldType := decls.Error
ft, found := c.lookupFieldType(c.locationByID(ent.GetId()), messageType.GetMessageType(), field)
fieldType := types.ErrorType
ft, found := c.lookupFieldType(ent.GetId(), typeName, field)
if found {
fieldType = ft.Type
fieldType = ft
}
valType := c.getType(value)
@ -518,11 +533,11 @@ func (c *checker) checkCreateMessage(e *exprpb.Expr) {
var isOptional bool
valType, isOptional = maybeUnwrapOptional(valType)
if !isOptional && !isDyn(valType) {
c.errors.typeMismatch(c.location(value), decls.NewOptionalType(valType), valType)
c.errors.typeMismatch(value.GetId(), c.location(value), types.NewOptionalType(valType), valType)
}
}
if !c.isAssignable(fieldType, valType) {
c.errors.fieldTypeMismatch(c.locationByID(ent.Id), field, fieldType, valType)
c.errors.fieldTypeMismatch(ent.GetId(), c.locationByID(ent.GetId()), field, fieldType, valType)
}
}
}
@ -533,36 +548,36 @@ func (c *checker) checkComprehension(e *exprpb.Expr) {
c.check(comp.GetAccuInit())
accuType := c.getType(comp.GetAccuInit())
rangeType := substitute(c.mappings, c.getType(comp.GetIterRange()), false)
var varType *exprpb.Type
var varType *types.Type
switch kindOf(rangeType) {
case kindList:
varType = rangeType.GetListType().GetElemType()
case kindMap:
switch rangeType.Kind() {
case types.ListKind:
varType = rangeType.Parameters()[0]
case types.MapKind:
// Ranges over the keys.
varType = rangeType.GetMapType().GetKeyType()
case kindDyn, kindError, kindTypeParam:
varType = rangeType.Parameters()[0]
case types.DynKind, types.ErrorKind, types.TypeParamKind:
// Set the range type to DYN to prevent assignment to a potentially incorrect type
// at a later point in type-checking. The isAssignable call will update the type
// substitutions for the type param under the covers.
c.isAssignable(decls.Dyn, rangeType)
c.isAssignable(types.DynType, rangeType)
// Set the range iteration variable to type DYN as well.
varType = decls.Dyn
varType = types.DynType
default:
c.errors.notAComprehensionRange(c.location(comp.GetIterRange()), rangeType)
varType = decls.Error
c.errors.notAComprehensionRange(comp.GetIterRange().GetId(), c.location(comp.GetIterRange()), rangeType)
varType = types.ErrorType
}
// Create a scope for the comprehension since it has a local accumulation variable.
// This scope will contain the accumulation variable used to compute the result.
c.env = c.env.enterScope()
c.env.Add(decls.NewVar(comp.GetAccuVar(), accuType))
c.env.AddIdents(decls.NewVariable(comp.GetAccuVar(), accuType))
// Create a block scope for the loop.
c.env = c.env.enterScope()
c.env.Add(decls.NewVar(comp.GetIterVar(), varType))
c.env.AddIdents(decls.NewVariable(comp.GetIterVar(), varType))
// Check the variable references in the condition and step.
c.check(comp.GetLoopCondition())
c.assertType(comp.GetLoopCondition(), decls.Bool)
c.assertType(comp.GetLoopCondition(), types.BoolType)
c.check(comp.GetLoopStep())
c.assertType(comp.GetLoopStep(), accuType)
// Exit the loop's block scope before checking the result.
@ -574,9 +589,7 @@ func (c *checker) checkComprehension(e *exprpb.Expr) {
}
// Checks compatibility of joined types, and returns the most general common type.
func (c *checker) joinTypes(loc common.Location,
previous *exprpb.Type,
current *exprpb.Type) *exprpb.Type {
func (c *checker) joinTypes(e *exprpb.Expr, previous, current *types.Type) *types.Type {
if previous == nil {
return current
}
@ -584,23 +597,23 @@ func (c *checker) joinTypes(loc common.Location,
return mostGeneral(previous, current)
}
if c.dynAggregateLiteralElementTypesEnabled() {
return decls.Dyn
return types.DynType
}
c.errors.typeMismatch(loc, previous, current)
return decls.Error
c.errors.typeMismatch(e.GetId(), c.location(e), previous, current)
return types.ErrorType
}
func (c *checker) dynAggregateLiteralElementTypesEnabled() bool {
return c.env.aggLitElemType == dynElementType
}
func (c *checker) newTypeVar() *exprpb.Type {
func (c *checker) newTypeVar() *types.Type {
id := c.freeTypeVarCounter
c.freeTypeVarCounter++
return decls.NewTypeParamType(fmt.Sprintf("_var%d", id))
return types.NewTypeParamType(fmt.Sprintf("_var%d", id))
}
func (c *checker) isAssignable(t1 *exprpb.Type, t2 *exprpb.Type) bool {
func (c *checker) isAssignable(t1, t2 *types.Type) bool {
subs := isAssignable(c.mappings, t1, t2)
if subs != nil {
c.mappings = subs
@ -610,7 +623,7 @@ func (c *checker) isAssignable(t1 *exprpb.Type, t2 *exprpb.Type) bool {
return false
}
func (c *checker) isAssignableList(l1 []*exprpb.Type, l2 []*exprpb.Type) bool {
func (c *checker) isAssignableList(l1, l2 []*types.Type) bool {
subs := isAssignableList(c.mappings, l1, l2)
if subs != nil {
c.mappings = subs
@ -620,57 +633,52 @@ func (c *checker) isAssignableList(l1 []*exprpb.Type, l2 []*exprpb.Type) bool {
return false
}
func (c *checker) lookupFieldType(l common.Location, messageType string, fieldName string) (*ref.FieldType, bool) {
if _, found := c.env.provider.FindType(messageType); !found {
// This should not happen, anyway, report an error.
c.errors.unexpectedFailedResolution(l, messageType)
return nil, false
func maybeUnwrapString(e *exprpb.Expr) (string, bool) {
switch e.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
literal := e.GetConstExpr()
switch literal.GetConstantKind().(type) {
case *exprpb.Constant_StringValue:
return literal.GetStringValue(), true
}
}
if ft, found := c.env.provider.FindFieldType(messageType, fieldName); found {
return ft, found
}
c.errors.undefinedField(l, fieldName)
return nil, false
return "", false
}
func (c *checker) setType(e *exprpb.Expr, t *exprpb.Type) {
if old, found := c.types[e.GetId()]; found && !proto.Equal(old, t) {
c.errors.ReportError(c.location(e),
"(Incompatible) Type already exists for expression: %v(%d) old:%v, new:%v", e, e.GetId(), old, t)
func (c *checker) setType(e *exprpb.Expr, t *types.Type) {
if old, found := c.types[e.GetId()]; found && !old.IsExactType(t) {
c.errors.incompatibleType(e.GetId(), c.location(e), e, old, t)
return
}
c.types[e.GetId()] = t
}
func (c *checker) getType(e *exprpb.Expr) *exprpb.Type {
func (c *checker) getType(e *exprpb.Expr) *types.Type {
return c.types[e.GetId()]
}
func (c *checker) setReference(e *exprpb.Expr, r *exprpb.Reference) {
if old, found := c.references[e.GetId()]; found && !proto.Equal(old, r) {
c.errors.ReportError(c.location(e),
"Reference already exists for expression: %v(%d) old:%v, new:%v", e, e.GetId(), old, r)
func (c *checker) setReference(e *exprpb.Expr, r *ast.ReferenceInfo) {
if old, found := c.references[e.GetId()]; found && !old.Equals(r) {
c.errors.referenceRedefinition(e.GetId(), c.location(e), e, old, r)
return
}
c.references[e.GetId()] = r
}
func (c *checker) assertType(e *exprpb.Expr, t *exprpb.Type) {
func (c *checker) assertType(e *exprpb.Expr, t *types.Type) {
if !c.isAssignable(t, c.getType(e)) {
c.errors.typeMismatch(c.location(e), t, c.getType(e))
c.errors.typeMismatch(e.GetId(), c.location(e), t, c.getType(e))
}
}
type overloadResolution struct {
Reference *exprpb.Reference
Type *exprpb.Type
Type *types.Type
Reference *ast.ReferenceInfo
}
func newResolution(checkedRef *exprpb.Reference, t *exprpb.Type) *overloadResolution {
func newResolution(r *ast.ReferenceInfo, t *types.Type) *overloadResolution {
return &overloadResolution{
Reference: checkedRef,
Reference: r,
Type: t,
}
}
@ -697,10 +705,56 @@ func (c *checker) locationByID(id int64) common.Location {
return common.NoLocation
}
func newIdentReference(name string, value *exprpb.Constant) *exprpb.Reference {
return &exprpb.Reference{Name: name, Value: value}
func (c *checker) lookupFieldType(exprID int64, structType, fieldName string) (*types.Type, bool) {
if _, found := c.env.provider.FindStructType(structType); !found {
// This should not happen, anyway, report an error.
c.errors.unexpectedFailedResolution(exprID, c.locationByID(exprID), structType)
return nil, false
}
if ft, found := c.env.provider.FindStructFieldType(structType, fieldName); found {
return ft.Type, found
}
c.errors.undefinedField(exprID, c.locationByID(exprID), fieldName)
return nil, false
}
func newFunctionReference(overloads ...string) *exprpb.Reference {
return &exprpb.Reference{OverloadId: overloads}
func isWellKnownType(t *types.Type) bool {
switch t.Kind() {
case types.AnyKind, types.TimestampKind, types.DurationKind, types.DynKind, types.NullTypeKind:
return true
case types.BoolKind, types.BytesKind, types.DoubleKind, types.IntKind, types.StringKind, types.UintKind:
return t.IsAssignableType(types.NullType)
case types.ListKind:
return t.Parameters()[0] == types.DynType
case types.MapKind:
return t.Parameters()[0] == types.StringType && t.Parameters()[1] == types.DynType
}
return false
}
func getWellKnownTypeName(t *types.Type) string {
if name, found := wellKnownTypes[t.Kind()]; found {
return name
}
return ""
}
var (
wellKnownTypes = map[types.Kind]string{
types.AnyKind: "google.protobuf.Any",
types.BoolKind: "google.protobuf.BoolValue",
types.BytesKind: "google.protobuf.BytesValue",
types.DoubleKind: "google.protobuf.DoubleValue",
types.DurationKind: "google.protobuf.Duration",
types.DynKind: "google.protobuf.Value",
types.IntKind: "google.protobuf.Int64Value",
types.ListKind: "google.protobuf.ListValue",
types.NullTypeKind: "google.protobuf.NullValue",
types.MapKind: "google.protobuf.Struct",
types.StringKind: "google.protobuf.StringValue",
types.TimestampKind: "google.protobuf.Timestamp",
types.UintKind: "google.protobuf.UInt64Value",
}
)

View File

@ -18,7 +18,9 @@ import (
"math"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
@ -54,7 +56,7 @@ type AstNode interface {
// The first path element is a variable. All subsequent path elements are one of: field name, '@items', '@keys', '@values'.
Path() []string
// Type returns the deduced type of the AstNode.
Type() *exprpb.Type
Type() *types.Type
// Expr returns the expression of the AstNode.
Expr() *exprpb.Expr
// ComputedSize returns a size estimate of the AstNode derived from information available in the CEL expression.
@ -66,7 +68,7 @@ type AstNode interface {
type astNode struct {
path []string
t *exprpb.Type
t *types.Type
expr *exprpb.Expr
derivedSize *SizeEstimate
}
@ -75,7 +77,7 @@ func (e astNode) Path() []string {
return e.path
}
func (e astNode) Type() *exprpb.Type {
func (e astNode) Type() *types.Type {
return e.t
}
@ -228,7 +230,7 @@ func addUint64NoOverflow(x, y uint64) uint64 {
// multiplyUint64NoOverflow multiplies non-negative ints. If the result is exceeds math.MaxUint64, math.MaxUint64
// is returned.
func multiplyUint64NoOverflow(x, y uint64) uint64 {
if x > 0 && y > 0 && x > math.MaxUint64/y {
if y != 0 && x > math.MaxUint64/y {
return math.MaxUint64
}
return x * y
@ -240,7 +242,11 @@ func multiplyByCostFactor(x uint64, y float64) uint64 {
if xFloat > 0 && y > 0 && xFloat > math.MaxUint64/y {
return math.MaxUint64
}
return uint64(math.Ceil(xFloat * y))
ceil := math.Ceil(xFloat * y)
if ceil >= doubleTwoTo64 {
return math.MaxUint64
}
return uint64(ceil)
}
var (
@ -258,9 +264,10 @@ type coster struct {
// iterRanges tracks the iterRange of each iterVar.
iterRanges iterRangeScopes
// computedSizes tracks the computed sizes of call results.
computedSizes map[int64]SizeEstimate
checkedExpr *exprpb.CheckedExpr
estimator CostEstimator
computedSizes map[int64]SizeEstimate
checkedAST *ast.CheckedAST
estimator CostEstimator
overloadEstimators map[string]FunctionEstimator
// presenceTestCost will either be a zero or one based on whether has() macros count against cost computations.
presenceTestCost CostEstimate
}
@ -289,6 +296,7 @@ func (vs iterRangeScopes) peek(varName string) (int64, bool) {
type CostOption func(*coster) error
// PresenceTestHasCost determines whether presence testing has a cost of one or zero.
//
// Defaults to presence test has a cost of one.
func PresenceTestHasCost(hasCost bool) CostOption {
return func(c *coster) error {
@ -301,15 +309,30 @@ func PresenceTestHasCost(hasCost bool) CostOption {
}
}
// FunctionEstimator provides a CallEstimate given the target and arguments for a specific function, overload pair.
type FunctionEstimator func(estimator CostEstimator, target *AstNode, args []AstNode) *CallEstimate
// OverloadCostEstimate binds a FunctionCoster to a specific function overload ID.
//
// When a OverloadCostEstimate is provided, it will override the cost calculation of the CostEstimator provided to
// the Cost() call.
func OverloadCostEstimate(overloadID string, functionCoster FunctionEstimator) CostOption {
return func(c *coster) error {
c.overloadEstimators[overloadID] = functionCoster
return nil
}
}
// Cost estimates the cost of the parsed and type checked CEL expression.
func Cost(checker *exprpb.CheckedExpr, estimator CostEstimator, opts ...CostOption) (CostEstimate, error) {
func Cost(checker *ast.CheckedAST, estimator CostEstimator, opts ...CostOption) (CostEstimate, error) {
c := &coster{
checkedExpr: checker,
estimator: estimator,
exprPath: map[int64][]string{},
iterRanges: map[string][]int64{},
computedSizes: map[int64]SizeEstimate{},
presenceTestCost: CostEstimate{Min: 1, Max: 1},
checkedAST: checker,
estimator: estimator,
overloadEstimators: map[string]FunctionEstimator{},
exprPath: map[int64][]string{},
iterRanges: map[string][]int64{},
computedSizes: map[int64]SizeEstimate{},
presenceTestCost: CostEstimate{Min: 1, Max: 1},
}
for _, opt := range opts {
err := opt(c)
@ -317,7 +340,7 @@ func Cost(checker *exprpb.CheckedExpr, estimator CostEstimator, opts ...CostOpti
return CostEstimate{}, err
}
}
return c.cost(checker.GetExpr()), nil
return c.cost(checker.Expr), nil
}
func (c *coster) cost(e *exprpb.Expr) CostEstimate {
@ -351,10 +374,10 @@ func (c *coster) costIdent(e *exprpb.Expr) CostEstimate {
// build and track the field path
if iterRange, ok := c.iterRanges.peek(identExpr.GetName()); ok {
switch c.checkedExpr.TypeMap[iterRange].GetTypeKind().(type) {
case *exprpb.Type_ListType_:
switch c.checkedAST.TypeMap[iterRange].Kind() {
case types.ListKind:
c.addPath(e, append(c.exprPath[iterRange], "@items"))
case *exprpb.Type_MapType_:
case types.MapKind:
c.addPath(e, append(c.exprPath[iterRange], "@keys"))
}
} else {
@ -378,8 +401,8 @@ func (c *coster) costSelect(e *exprpb.Expr) CostEstimate {
}
sum = sum.Add(c.cost(sel.GetOperand()))
targetType := c.getType(sel.GetOperand())
switch kindOf(targetType) {
case kindMap, kindObject, kindTypeParam:
switch targetType.Kind() {
case types.MapKind, types.StructKind, types.TypeParamKind:
sum = sum.Add(selectAndIdentCost)
}
@ -403,8 +426,8 @@ func (c *coster) costCall(e *exprpb.Expr) CostEstimate {
argTypes[i] = c.newAstNode(arg)
}
ref := c.checkedExpr.ReferenceMap[e.GetId()]
if ref == nil || len(ref.GetOverloadId()) == 0 {
ref := c.checkedAST.ReferenceMap[e.GetId()]
if ref == nil || len(ref.OverloadIDs) == 0 {
return CostEstimate{}
}
var targetType AstNode
@ -417,7 +440,7 @@ func (c *coster) costCall(e *exprpb.Expr) CostEstimate {
// Pick a cost estimate range that covers all the overload cost estimation ranges
fnCost := CostEstimate{Min: uint64(math.MaxUint64), Max: 0}
var resultSize *SizeEstimate
for _, overload := range ref.GetOverloadId() {
for _, overload := range ref.OverloadIDs {
overloadCost := c.functionCost(call.GetFunction(), overload, &targetType, argTypes, argCosts)
fnCost = fnCost.Union(overloadCost.CostEstimate)
if overloadCost.ResultSize != nil {
@ -530,7 +553,14 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args
}
return sum
}
if len(c.overloadEstimators) != 0 {
if estimator, found := c.overloadEstimators[overloadID]; found {
if est := estimator(c.estimator, target, args); est != nil {
callEst := *est
return CallEstimate{CostEstimate: callEst.Add(argCostSum()), ResultSize: est.ResultSize}
}
}
}
if est := c.estimator.EstimateCallCost(function, overloadID, target, args); est != nil {
callEst := *est
return CallEstimate{CostEstimate: callEst.Add(argCostSum()), ResultSize: est.ResultSize}
@ -641,8 +671,8 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args
return CallEstimate{CostEstimate: CostEstimate{Min: 1, Max: 1}.Add(argCostSum())}
}
func (c *coster) getType(e *exprpb.Expr) *exprpb.Type {
return c.checkedExpr.TypeMap[e.GetId()]
func (c *coster) getType(e *exprpb.Expr) *types.Type {
return c.checkedAST.TypeMap[e.GetId()]
}
func (c *coster) getPath(e *exprpb.Expr) []string {
@ -663,22 +693,24 @@ func (c *coster) newAstNode(e *exprpb.Expr) *astNode {
if size, ok := c.computedSizes[e.GetId()]; ok {
derivedSize = &size
}
return &astNode{path: path, t: c.getType(e), expr: e, derivedSize: derivedSize}
return &astNode{
path: path,
t: c.getType(e),
expr: e,
derivedSize: derivedSize}
}
// isScalar returns true if the given type is known to be of a constant size at
// compile time. isScalar will return false for strings (they are variable-width)
// in addition to protobuf.Any and protobuf.Value (their size is not knowable at compile time).
func isScalar(t *exprpb.Type) bool {
switch kindOf(t) {
case kindPrimitive:
if t.GetPrimitive() != exprpb.Type_STRING && t.GetPrimitive() != exprpb.Type_BYTES {
return true
}
case kindWellKnown:
if t.GetWellKnown() == exprpb.Type_DURATION || t.GetWellKnown() == exprpb.Type_TIMESTAMP {
return true
}
func isScalar(t *types.Type) bool {
switch t.Kind() {
case types.BoolKind, types.DoubleKind, types.DurationKind, types.IntKind, types.TimestampKind, types.UintKind:
return true
}
return false
}
var (
doubleTwoTo64 = math.Ldexp(1.0, 64)
)

View File

@ -9,7 +9,6 @@ go_library(
name = "go_default_library",
srcs = [
"decls.go",
"scopes.go",
],
importpath = "github.com/google/cel-go/checker/decls",
deps = [

View File

@ -18,17 +18,11 @@ import (
"fmt"
"strings"
"google.golang.org/protobuf/proto"
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/decls"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/pb"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
type aggregateLiteralElementType int
@ -76,15 +70,15 @@ var (
// which can be used to assist with type-checking.
type Env struct {
container *containers.Container
provider ref.TypeProvider
declarations *decls.Scopes
provider types.Provider
declarations *Scopes
aggLitElemType aggregateLiteralElementType
filteredOverloadIDs map[string]struct{}
}
// NewEnv returns a new *Env with the given parameters.
func NewEnv(container *containers.Container, provider ref.TypeProvider, opts ...Option) (*Env, error) {
declarations := decls.NewScopes()
func NewEnv(container *containers.Container, provider types.Provider, opts ...Option) (*Env, error) {
declarations := newScopes()
declarations.Push()
envOptions := &options{}
@ -113,24 +107,31 @@ func NewEnv(container *containers.Container, provider ref.TypeProvider, opts ...
}, nil
}
// Add adds new Decl protos to the Env.
// Returns an error for identifier redeclarations.
func (e *Env) Add(decls ...*exprpb.Decl) error {
// AddIdents configures the checker with a list of variable declarations.
//
// If there are overlapping declarations, the method will error.
func (e *Env) AddIdents(declarations ...*decls.VariableDecl) error {
errMsgs := make([]errorMsg, 0)
for _, decl := range decls {
switch decl.DeclKind.(type) {
case *exprpb.Decl_Ident:
errMsgs = append(errMsgs, e.addIdent(sanitizeIdent(decl)))
case *exprpb.Decl_Function:
errMsgs = append(errMsgs, e.setFunction(sanitizeFunction(decl))...)
}
for _, d := range declarations {
errMsgs = append(errMsgs, e.addIdent(d))
}
return formatError(errMsgs)
}
// AddFunctions configures the checker with a list of function declarations.
//
// If there are overlapping declarations, the method will error.
func (e *Env) AddFunctions(declarations ...*decls.FunctionDecl) error {
errMsgs := make([]errorMsg, 0)
for _, d := range declarations {
errMsgs = append(errMsgs, e.setFunction(d)...)
}
return formatError(errMsgs)
}
// LookupIdent returns a Decl proto for typeName as an identifier in the Env.
// Returns nil if no such identifier is found in the Env.
func (e *Env) LookupIdent(name string) *exprpb.Decl {
func (e *Env) LookupIdent(name string) *decls.VariableDecl {
for _, candidate := range e.container.ResolveCandidateNames(name) {
if ident := e.declarations.FindIdent(candidate); ident != nil {
return ident
@ -139,8 +140,8 @@ func (e *Env) LookupIdent(name string) *exprpb.Decl {
// Next try to import the name as a reference to a message type. If found,
// the declaration is added to the outest (global) scope of the
// environment, so next time we can access it faster.
if t, found := e.provider.FindType(candidate); found {
decl := decls.NewVar(candidate, t)
if t, found := e.provider.FindStructType(candidate); found {
decl := decls.NewVariable(candidate, t)
e.declarations.AddIdent(decl)
return decl
}
@ -148,11 +149,7 @@ func (e *Env) LookupIdent(name string) *exprpb.Decl {
// Next try to import this as an enum value by splitting the name in a type prefix and
// the enum inside.
if enumValue := e.provider.EnumValue(candidate); enumValue.Type() != types.ErrType {
decl := decls.NewIdent(candidate,
decls.Int,
&exprpb.Constant{
ConstantKind: &exprpb.Constant_Int64Value{
Int64Value: int64(enumValue.(types.Int))}})
decl := decls.NewConstant(candidate, types.IntType, enumValue)
e.declarations.AddIdent(decl)
return decl
}
@ -162,7 +159,7 @@ func (e *Env) LookupIdent(name string) *exprpb.Decl {
// LookupFunction returns a Decl proto for typeName as a function in env.
// Returns nil if no such function is found in env.
func (e *Env) LookupFunction(name string) *exprpb.Decl {
func (e *Env) LookupFunction(name string) *decls.FunctionDecl {
for _, candidate := range e.container.ResolveCandidateNames(name) {
if fn := e.declarations.FindFunction(candidate); fn != nil {
return fn
@ -171,88 +168,46 @@ func (e *Env) LookupFunction(name string) *exprpb.Decl {
return nil
}
// addOverload adds overload to function declaration f.
// Returns one or more errorMsg values if the overload overlaps with an existing overload or macro.
func (e *Env) addOverload(f *exprpb.Decl, overload *exprpb.Decl_FunctionDecl_Overload) []errorMsg {
errMsgs := make([]errorMsg, 0)
function := f.GetFunction()
emptyMappings := newMapping()
overloadFunction := decls.NewFunctionType(overload.GetResultType(),
overload.GetParams()...)
overloadErased := substitute(emptyMappings, overloadFunction, true)
for _, existing := range function.GetOverloads() {
existingFunction := decls.NewFunctionType(existing.GetResultType(), existing.GetParams()...)
existingErased := substitute(emptyMappings, existingFunction, true)
overlap := isAssignable(emptyMappings, overloadErased, existingErased) != nil ||
isAssignable(emptyMappings, existingErased, overloadErased) != nil
if overlap &&
overload.GetIsInstanceFunction() == existing.GetIsInstanceFunction() {
errMsgs = append(errMsgs,
overlappingOverloadError(f.Name,
overload.GetOverloadId(), overloadFunction,
existing.GetOverloadId(), existingFunction))
}
}
for _, macro := range parser.AllMacros {
if macro.Function() == f.Name &&
macro.IsReceiverStyle() == overload.GetIsInstanceFunction() &&
macro.ArgCount() == len(overload.GetParams()) {
errMsgs = append(errMsgs, overlappingMacroError(f.Name, macro.ArgCount()))
}
}
if len(errMsgs) > 0 {
return errMsgs
}
function.Overloads = append(function.GetOverloads(), overload)
return errMsgs
}
// setFunction adds the function Decl to the Env.
// Adds a function decl if one doesn't already exist, then adds all overloads from the Decl.
// If overload overlaps with an existing overload, adds to the errors in the Env instead.
func (e *Env) setFunction(decl *exprpb.Decl) []errorMsg {
errorMsgs := make([]errorMsg, 0)
overloads := decl.GetFunction().GetOverloads()
current := e.declarations.FindFunction(decl.Name)
if current == nil {
//Add the function declaration without overloads and check the overloads below.
current = decls.NewFunction(decl.Name)
} else {
existingOverloads := map[string]*exprpb.Decl_FunctionDecl_Overload{}
for _, overload := range current.GetFunction().GetOverloads() {
existingOverloads[overload.GetOverloadId()] = overload
func (e *Env) setFunction(fn *decls.FunctionDecl) []errorMsg {
errMsgs := make([]errorMsg, 0)
current := e.declarations.FindFunction(fn.Name())
if current != nil {
var err error
current, err = current.Merge(fn)
if err != nil {
return append(errMsgs, errorMsg(err.Error()))
}
newOverloads := []*exprpb.Decl_FunctionDecl_Overload{}
for _, overload := range overloads {
existing, found := existingOverloads[overload.GetOverloadId()]
if !found || !overloadsEqual(existing, overload) {
newOverloads = append(newOverloads, overload)
} else {
current = fn
}
for _, overload := range current.OverloadDecls() {
for _, macro := range parser.AllMacros {
if macro.Function() == current.Name() &&
macro.IsReceiverStyle() == overload.IsMemberFunction() &&
macro.ArgCount() == len(overload.ArgTypes()) {
errMsgs = append(errMsgs, overlappingMacroError(current.Name(), macro.ArgCount()))
}
}
overloads = newOverloads
if len(newOverloads) == 0 {
return errorMsgs
if len(errMsgs) > 0 {
return errMsgs
}
// Copy on write since we don't know where this original definition came from.
current = proto.Clone(current).(*exprpb.Decl)
}
e.declarations.SetFunction(current)
for _, overload := range overloads {
errorMsgs = append(errorMsgs, e.addOverload(current, overload)...)
}
return errorMsgs
return errMsgs
}
// addIdent adds the Decl to the declarations in the Env.
// Returns a non-empty errorMsg if the identifier is already declared in the scope.
func (e *Env) addIdent(decl *exprpb.Decl) errorMsg {
current := e.declarations.FindIdentInScope(decl.Name)
func (e *Env) addIdent(decl *decls.VariableDecl) errorMsg {
current := e.declarations.FindIdentInScope(decl.Name())
if current != nil {
if proto.Equal(current, decl) {
if current.DeclarationIsEquivalent(decl) {
return ""
}
return overlappingIdentifierError(decl.Name)
return overlappingIdentifierError(decl.Name())
}
e.declarations.AddIdent(decl)
return ""
@ -264,111 +219,9 @@ func (e *Env) isOverloadDisabled(overloadID string) bool {
return found
}
// overloadsEqual returns whether two overloads have identical signatures.
//
// type parameter names are ignored as they may be specified in any order and have no bearing on overload
// equivalence
func overloadsEqual(o1, o2 *exprpb.Decl_FunctionDecl_Overload) bool {
return o1.GetOverloadId() == o2.GetOverloadId() &&
o1.GetIsInstanceFunction() == o2.GetIsInstanceFunction() &&
paramsEqual(o1.GetParams(), o2.GetParams()) &&
proto.Equal(o1.GetResultType(), o2.GetResultType())
}
// paramsEqual returns whether two lists have equal length and all types are equal
func paramsEqual(p1, p2 []*exprpb.Type) bool {
if len(p1) != len(p2) {
return false
}
for i, a := range p1 {
b := p2[i]
if !proto.Equal(a, b) {
return false
}
}
return true
}
// sanitizeFunction replaces well-known types referenced by message name with their equivalent
// CEL built-in type instances.
func sanitizeFunction(decl *exprpb.Decl) *exprpb.Decl {
fn := decl.GetFunction()
// Determine whether the declaration requires replacements from proto-based message type
// references to well-known CEL type references.
var needsSanitizing bool
for _, o := range fn.GetOverloads() {
if isObjectWellKnownType(o.GetResultType()) {
needsSanitizing = true
break
}
for _, p := range o.GetParams() {
if isObjectWellKnownType(p) {
needsSanitizing = true
break
}
}
}
// Early return if the declaration requires no modification.
if !needsSanitizing {
return decl
}
// Sanitize all of the overloads if any overload requires an update to its type references.
overloads := make([]*exprpb.Decl_FunctionDecl_Overload, len(fn.GetOverloads()))
for i, o := range fn.GetOverloads() {
rt := o.GetResultType()
if isObjectWellKnownType(rt) {
rt = getObjectWellKnownType(rt)
}
params := make([]*exprpb.Type, len(o.GetParams()))
copy(params, o.GetParams())
for j, p := range params {
if isObjectWellKnownType(p) {
params[j] = getObjectWellKnownType(p)
}
}
// If sanitized, replace the overload definition.
if o.IsInstanceFunction {
overloads[i] =
decls.NewInstanceOverload(o.GetOverloadId(), params, rt)
} else {
overloads[i] =
decls.NewOverload(o.GetOverloadId(), params, rt)
}
}
return decls.NewFunction(decl.GetName(), overloads...)
}
// sanitizeIdent replaces the identifier's well-known types referenced by message name with
// references to CEL built-in type instances.
func sanitizeIdent(decl *exprpb.Decl) *exprpb.Decl {
id := decl.GetIdent()
t := id.GetType()
if !isObjectWellKnownType(t) {
return decl
}
return decls.NewIdent(decl.GetName(), getObjectWellKnownType(t), id.GetValue())
}
// isObjectWellKnownType returns true if the input type is an OBJECT type with a message name
// that corresponds the message name of a built-in CEL type.
func isObjectWellKnownType(t *exprpb.Type) bool {
if kindOf(t) != kindObject {
return false
}
_, found := pb.CheckedWellKnowns[t.GetMessageType()]
return found
}
// getObjectWellKnownType returns the built-in CEL type declaration for input type's message name.
func getObjectWellKnownType(t *exprpb.Type) *exprpb.Type {
return pb.CheckedWellKnowns[t.GetMessageType()]
}
// validatedDeclarations returns a reference to the validated variable and function declaration scope stack.
// must be copied before use.
func (e *Env) validatedDeclarations() *decls.Scopes {
func (e *Env) validatedDeclarations() *Scopes {
return e.declarations
}
@ -402,19 +255,6 @@ func overlappingIdentifierError(name string) errorMsg {
return errorMsg(fmt.Sprintf("overlapping identifier for name '%s'", name))
}
func overlappingOverloadError(name string,
overloadID1 string, f1 *exprpb.Type,
overloadID2 string, f2 *exprpb.Type) errorMsg {
return errorMsg(fmt.Sprintf(
"overlapping overload for name '%s' (type '%s' with overloadId: '%s' "+
"cannot be distinguished from '%s' with overloadId: '%s')",
name,
FormatCheckedType(f1),
overloadID1,
FormatCheckedType(f2),
overloadID2))
}
func overlappingMacroError(name string, argCount int) errorMsg {
return errorMsg(fmt.Sprintf(
"overlapping macro for name '%s' with %d args", name, argCount))

View File

@ -15,82 +15,78 @@
package checker
import (
"reflect"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// typeErrors is a specialization of Errors.
type typeErrors struct {
*common.Errors
errs *common.Errors
}
func (e *typeErrors) undeclaredReference(l common.Location, container string, name string) {
e.ReportError(l, "undeclared reference to '%s' (in container '%s')", name, container)
func (e *typeErrors) fieldTypeMismatch(id int64, l common.Location, name string, field, value *types.Type) {
e.errs.ReportErrorAtID(id, l, "expected type of field '%s' is '%s' but provided type is '%s'",
name, FormatCELType(field), FormatCELType(value))
}
func (e *typeErrors) typeDoesNotSupportFieldSelection(l common.Location, t *exprpb.Type) {
e.ReportError(l, "type '%s' does not support field selection", t)
func (e *typeErrors) incompatibleType(id int64, l common.Location, ex *exprpb.Expr, prev, next *types.Type) {
e.errs.ReportErrorAtID(id, l,
"incompatible type already exists for expression: %v(%d) old:%v, new:%v", ex, ex.GetId(), prev, next)
}
func (e *typeErrors) undefinedField(l common.Location, field string) {
e.ReportError(l, "undefined field '%s'", field)
func (e *typeErrors) noMatchingOverload(id int64, l common.Location, name string, args []*types.Type, isInstance bool) {
signature := formatFunctionDeclType(nil, args, isInstance)
e.errs.ReportErrorAtID(id, l, "found no matching overload for '%s' applied to '%s'", name, signature)
}
func (e *typeErrors) noMatchingOverload(l common.Location, name string, args []*exprpb.Type, isInstance bool) {
signature := formatFunction(nil, args, isInstance)
e.ReportError(l, "found no matching overload for '%s' applied to '%s'", name, signature)
func (e *typeErrors) notAComprehensionRange(id int64, l common.Location, t *types.Type) {
e.errs.ReportErrorAtID(id, l, "expression of type '%s' cannot be range of a comprehension (must be list, map, or dynamic)",
FormatCELType(t))
}
func (e *typeErrors) notAType(l common.Location, t *exprpb.Type) {
e.ReportError(l, "'%s(%v)' is not a type", FormatCheckedType(t), t)
func (e *typeErrors) notAnOptionalFieldSelection(id int64, l common.Location, field *exprpb.Expr) {
e.errs.ReportErrorAtID(id, l, "unsupported optional field selection: %v", field)
}
func (e *typeErrors) notAMessageType(l common.Location, t *exprpb.Type) {
e.ReportError(l, "'%s' is not a message type", FormatCheckedType(t))
func (e *typeErrors) notAType(id int64, l common.Location, typeName string) {
e.errs.ReportErrorAtID(id, l, "'%s' is not a type", typeName)
}
func (e *typeErrors) fieldTypeMismatch(l common.Location, name string, field *exprpb.Type, value *exprpb.Type) {
e.ReportError(l, "expected type of field '%s' is '%s' but provided type is '%s'",
name, FormatCheckedType(field), FormatCheckedType(value))
func (e *typeErrors) notAMessageType(id int64, l common.Location, typeName string) {
e.errs.ReportErrorAtID(id, l, "'%s' is not a message type", typeName)
}
func (e *typeErrors) unexpectedFailedResolution(l common.Location, typeName string) {
e.ReportError(l, "[internal] unexpected failed resolution of '%s'", typeName)
func (e *typeErrors) referenceRedefinition(id int64, l common.Location, ex *exprpb.Expr, prev, next *ast.ReferenceInfo) {
e.errs.ReportErrorAtID(id, l,
"reference already exists for expression: %v(%d) old:%v, new:%v", ex, ex.GetId(), prev, next)
}
func (e *typeErrors) notAComprehensionRange(l common.Location, t *exprpb.Type) {
e.ReportError(l, "expression of type '%s' cannot be range of a comprehension (must be list, map, or dynamic)",
FormatCheckedType(t))
func (e *typeErrors) typeDoesNotSupportFieldSelection(id int64, l common.Location, t *types.Type) {
e.errs.ReportErrorAtID(id, l, "type '%s' does not support field selection", FormatCELType(t))
}
func (e *typeErrors) typeMismatch(l common.Location, expected *exprpb.Type, actual *exprpb.Type) {
e.ReportError(l, "expected type '%s' but found '%s'",
FormatCheckedType(expected), FormatCheckedType(actual))
func (e *typeErrors) typeMismatch(id int64, l common.Location, expected, actual *types.Type) {
e.errs.ReportErrorAtID(id, l, "expected type '%s' but found '%s'",
FormatCELType(expected), FormatCELType(actual))
}
func formatFunction(resultType *exprpb.Type, argTypes []*exprpb.Type, isInstance bool) string {
result := ""
if isInstance {
target := argTypes[0]
argTypes = argTypes[1:]
result += FormatCheckedType(target)
result += "."
}
result += "("
for i, arg := range argTypes {
if i > 0 {
result += ", "
}
result += FormatCheckedType(arg)
}
result += ")"
if resultType != nil {
result += " -> "
result += FormatCheckedType(resultType)
}
return result
func (e *typeErrors) undefinedField(id int64, l common.Location, field string) {
e.errs.ReportErrorAtID(id, l, "undefined field '%s'", field)
}
func (e *typeErrors) undeclaredReference(id int64, l common.Location, container string, name string) {
e.errs.ReportErrorAtID(id, l, "undeclared reference to '%s' (in container '%s')", name, container)
}
func (e *typeErrors) unexpectedFailedResolution(id int64, l common.Location, typeName string) {
e.errs.ReportErrorAtID(id, l, "unexpected failed resolution of '%s'", typeName)
}
func (e *typeErrors) unexpectedASTType(id int64, l common.Location, ex *exprpb.Expr) {
e.errs.ReportErrorAtID(id, l, "unrecognized ast type: %v", reflect.TypeOf(ex))
}

216
vendor/github.com/google/cel-go/checker/format.go generated vendored Normal file
View File

@ -0,0 +1,216 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package checker
import (
"fmt"
"strings"
chkdecls "github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/types"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
const (
kindUnknown = iota + 1
kindError
kindFunction
kindDyn
kindPrimitive
kindWellKnown
kindWrapper
kindNull
kindAbstract
kindType
kindList
kindMap
kindObject
kindTypeParam
)
// FormatCheckedType converts a type message into a string representation.
func FormatCheckedType(t *exprpb.Type) string {
switch kindOf(t) {
case kindDyn:
return "dyn"
case kindFunction:
return formatFunctionExprType(t.GetFunction().GetResultType(),
t.GetFunction().GetArgTypes(),
false)
case kindList:
return fmt.Sprintf("list(%s)", FormatCheckedType(t.GetListType().GetElemType()))
case kindObject:
return t.GetMessageType()
case kindMap:
return fmt.Sprintf("map(%s, %s)",
FormatCheckedType(t.GetMapType().GetKeyType()),
FormatCheckedType(t.GetMapType().GetValueType()))
case kindNull:
return "null"
case kindPrimitive:
switch t.GetPrimitive() {
case exprpb.Type_UINT64:
return "uint"
case exprpb.Type_INT64:
return "int"
}
return strings.Trim(strings.ToLower(t.GetPrimitive().String()), " ")
case kindType:
if t.GetType() == nil || t.GetType().GetTypeKind() == nil {
return "type"
}
return fmt.Sprintf("type(%s)", FormatCheckedType(t.GetType()))
case kindWellKnown:
switch t.GetWellKnown() {
case exprpb.Type_ANY:
return "any"
case exprpb.Type_DURATION:
return "duration"
case exprpb.Type_TIMESTAMP:
return "timestamp"
}
case kindWrapper:
return fmt.Sprintf("wrapper(%s)",
FormatCheckedType(chkdecls.NewPrimitiveType(t.GetWrapper())))
case kindError:
return "!error!"
case kindTypeParam:
return t.GetTypeParam()
case kindAbstract:
at := t.GetAbstractType()
params := at.GetParameterTypes()
paramStrs := make([]string, len(params))
for i, p := range params {
paramStrs[i] = FormatCheckedType(p)
}
return fmt.Sprintf("%s(%s)", at.GetName(), strings.Join(paramStrs, ", "))
}
return t.String()
}
type formatter func(any) string
// FormatCELType formats a types.Type value to a string representation.
//
// The type formatting is identical to FormatCheckedType.
func FormatCELType(t any) string {
dt := t.(*types.Type)
switch dt.Kind() {
case types.AnyKind:
return "any"
case types.DurationKind:
return "duration"
case types.ErrorKind:
return "!error!"
case types.NullTypeKind:
return "null"
case types.TimestampKind:
return "timestamp"
case types.TypeParamKind:
return dt.TypeName()
case types.OpaqueKind:
if dt.TypeName() == "function" {
// There is no explicit function type in the new types representation, so information like
// whether the function is a member function is absent.
return formatFunctionDeclType(dt.Parameters()[0], dt.Parameters()[1:], false)
}
case types.UnspecifiedKind:
return ""
}
if len(dt.Parameters()) == 0 {
return dt.DeclaredTypeName()
}
paramTypeNames := make([]string, 0, len(dt.Parameters()))
for _, p := range dt.Parameters() {
paramTypeNames = append(paramTypeNames, FormatCELType(p))
}
return fmt.Sprintf("%s(%s)", dt.TypeName(), strings.Join(paramTypeNames, ", "))
}
func formatExprType(t any) string {
if t == nil {
return ""
}
return FormatCheckedType(t.(*exprpb.Type))
}
func formatFunctionExprType(resultType *exprpb.Type, argTypes []*exprpb.Type, isInstance bool) string {
return formatFunctionInternal[*exprpb.Type](resultType, argTypes, isInstance, formatExprType)
}
func formatFunctionDeclType(resultType *types.Type, argTypes []*types.Type, isInstance bool) string {
return formatFunctionInternal[*types.Type](resultType, argTypes, isInstance, FormatCELType)
}
func formatFunctionInternal[T any](resultType T, argTypes []T, isInstance bool, format formatter) string {
result := ""
if isInstance {
target := argTypes[0]
argTypes = argTypes[1:]
result += format(target)
result += "."
}
result += "("
for i, arg := range argTypes {
if i > 0 {
result += ", "
}
result += format(arg)
}
result += ")"
rt := format(resultType)
if rt != "" {
result += " -> "
result += rt
}
return result
}
// kindOf returns the kind of the type as defined in the checked.proto.
func kindOf(t *exprpb.Type) int {
if t == nil || t.TypeKind == nil {
return kindUnknown
}
switch t.GetTypeKind().(type) {
case *exprpb.Type_Error:
return kindError
case *exprpb.Type_Function:
return kindFunction
case *exprpb.Type_Dyn:
return kindDyn
case *exprpb.Type_Primitive:
return kindPrimitive
case *exprpb.Type_WellKnown:
return kindWellKnown
case *exprpb.Type_Wrapper:
return kindWrapper
case *exprpb.Type_Null:
return kindNull
case *exprpb.Type_Type:
return kindType
case *exprpb.Type_ListType_:
return kindList
case *exprpb.Type_MapType_:
return kindMap
case *exprpb.Type_MessageType:
return kindObject
case *exprpb.Type_TypeParam:
return kindTypeParam
case *exprpb.Type_AbstractType_:
return kindAbstract
}
return kindUnknown
}

View File

@ -15,25 +15,25 @@
package checker
import (
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"github.com/google/cel-go/common/types"
)
type mapping struct {
mapping map[string]*exprpb.Type
mapping map[string]*types.Type
}
func newMapping() *mapping {
return &mapping{
mapping: make(map[string]*exprpb.Type),
mapping: make(map[string]*types.Type),
}
}
func (m *mapping) add(from *exprpb.Type, to *exprpb.Type) {
m.mapping[typeKey(from)] = to
func (m *mapping) add(from, to *types.Type) {
m.mapping[FormatCELType(from)] = to
}
func (m *mapping) find(from *exprpb.Type) (*exprpb.Type, bool) {
if r, found := m.mapping[typeKey(from)]; found {
func (m *mapping) find(from *types.Type) (*types.Type, bool) {
if r, found := m.mapping[FormatCELType(from)]; found {
return r, found
}
return nil, false

View File

@ -14,12 +14,10 @@
package checker
import "github.com/google/cel-go/checker/decls"
type options struct {
crossTypeNumericComparisons bool
homogeneousAggregateLiterals bool
validatedDeclarations *decls.Scopes
validatedDeclarations *Scopes
}
// Option is a functional option for configuring the type-checker
@ -34,15 +32,6 @@ func CrossTypeNumericComparisons(enabled bool) Option {
}
}
// HomogeneousAggregateLiterals toggles support for constructing lists and maps whose elements all
// have the same type.
func HomogeneousAggregateLiterals(enabled bool) Option {
return func(opts *options) error {
opts.homogeneousAggregateLiterals = enabled
return nil
}
}
// ValidatedDeclarations provides a references to validated declarations which will be copied
// into new checker instances.
func ValidatedDeclarations(env *Env) Option {

View File

@ -15,6 +15,8 @@
package checker
import (
"sort"
"github.com/google/cel-go/common/debug"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
@ -47,6 +49,7 @@ func (a *semanticAdorner) GetMetadata(elem any) string {
if len(ref.GetOverloadId()) == 0 {
result += "^" + ref.Name
} else {
sort.Strings(ref.GetOverloadId())
for i, overload := range ref.GetOverloadId() {
if i == 0 {
result += "^"

View File

@ -12,9 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package decls
package checker
import exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
import (
"github.com/google/cel-go/common/decls"
)
// Scopes represents nested Decl sets where the Scopes value contains a Groups containing all
// identifiers in scope and an optional parent representing outer scopes.
@ -25,9 +27,9 @@ type Scopes struct {
scopes *Group
}
// NewScopes creates a new, empty Scopes.
// newScopes creates a new, empty Scopes.
// Some operations can't be safely performed until a Group is added with Push.
func NewScopes() *Scopes {
func newScopes() *Scopes {
return &Scopes{
scopes: newGroup(),
}
@ -35,7 +37,7 @@ func NewScopes() *Scopes {
// Copy creates a copy of the current Scopes values, including a copy of its parent if non-nil.
func (s *Scopes) Copy() *Scopes {
cpy := NewScopes()
cpy := newScopes()
if s == nil {
return cpy
}
@ -66,14 +68,14 @@ func (s *Scopes) Pop() *Scopes {
// AddIdent adds the ident Decl in the current scope.
// Note: If the name collides with an existing identifier in the scope, the Decl is overwritten.
func (s *Scopes) AddIdent(decl *exprpb.Decl) {
s.scopes.idents[decl.Name] = decl
func (s *Scopes) AddIdent(decl *decls.VariableDecl) {
s.scopes.idents[decl.Name()] = decl
}
// FindIdent finds the first ident Decl with a matching name in Scopes, or nil if one cannot be
// found.
// Note: The search is performed from innermost to outermost.
func (s *Scopes) FindIdent(name string) *exprpb.Decl {
func (s *Scopes) FindIdent(name string) *decls.VariableDecl {
if ident, found := s.scopes.idents[name]; found {
return ident
}
@ -86,7 +88,7 @@ func (s *Scopes) FindIdent(name string) *exprpb.Decl {
// FindIdentInScope finds the first ident Decl with a matching name in the current Scopes value, or
// nil if one does not exist.
// Note: The search is only performed on the current scope and does not search outer scopes.
func (s *Scopes) FindIdentInScope(name string) *exprpb.Decl {
func (s *Scopes) FindIdentInScope(name string) *decls.VariableDecl {
if ident, found := s.scopes.idents[name]; found {
return ident
}
@ -95,14 +97,14 @@ func (s *Scopes) FindIdentInScope(name string) *exprpb.Decl {
// SetFunction adds the function Decl to the current scope.
// Note: Any previous entry for a function in the current scope with the same name is overwritten.
func (s *Scopes) SetFunction(fn *exprpb.Decl) {
s.scopes.functions[fn.Name] = fn
func (s *Scopes) SetFunction(fn *decls.FunctionDecl) {
s.scopes.functions[fn.Name()] = fn
}
// FindFunction finds the first function Decl with a matching name in Scopes.
// The search is performed from innermost to outermost.
// Returns nil if no such function in Scopes.
func (s *Scopes) FindFunction(name string) *exprpb.Decl {
func (s *Scopes) FindFunction(name string) *decls.FunctionDecl {
if fn, found := s.scopes.functions[name]; found {
return fn
}
@ -116,16 +118,16 @@ func (s *Scopes) FindFunction(name string) *exprpb.Decl {
// Contains separate namespaces for identifier and function Decls.
// (Should be named "Scope" perhaps?)
type Group struct {
idents map[string]*exprpb.Decl
functions map[string]*exprpb.Decl
idents map[string]*decls.VariableDecl
functions map[string]*decls.FunctionDecl
}
// copy creates a new Group instance with a shallow copy of the variables and functions.
// If callers need to mutate the exprpb.Decl definitions for a Function, they should copy-on-write.
func (g *Group) copy() *Group {
cpy := &Group{
idents: make(map[string]*exprpb.Decl, len(g.idents)),
functions: make(map[string]*exprpb.Decl, len(g.functions)),
idents: make(map[string]*decls.VariableDecl, len(g.idents)),
functions: make(map[string]*decls.FunctionDecl, len(g.functions)),
}
for n, id := range g.idents {
cpy.idents[n] = id
@ -139,7 +141,7 @@ func (g *Group) copy() *Group {
// newGroup creates a new Group with empty maps for identifiers and functions.
func newGroup() *Group {
return &Group{
idents: make(map[string]*exprpb.Decl),
functions: make(map[string]*exprpb.Decl),
idents: make(map[string]*decls.VariableDecl),
functions: make(map[string]*decls.FunctionDecl),
}
}

View File

@ -15,480 +15,21 @@
package checker
import (
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/stdlib"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
var (
standardDeclarations []*exprpb.Decl
)
func init() {
// Some shortcuts we use when building declarations.
paramA := decls.NewTypeParamType("A")
typeParamAList := []string{"A"}
listOfA := decls.NewListType(paramA)
paramB := decls.NewTypeParamType("B")
typeParamABList := []string{"A", "B"}
mapOfAB := decls.NewMapType(paramA, paramB)
var idents []*exprpb.Decl
for _, t := range []*exprpb.Type{
decls.Int, decls.Uint, decls.Bool,
decls.Double, decls.Bytes, decls.String} {
idents = append(idents,
decls.NewVar(FormatCheckedType(t), decls.NewTypeType(t)))
}
idents = append(idents,
decls.NewVar("list", decls.NewTypeType(listOfA)),
decls.NewVar("map", decls.NewTypeType(mapOfAB)),
decls.NewVar("null_type", decls.NewTypeType(decls.Null)),
decls.NewVar("type", decls.NewTypeType(decls.NewTypeType(nil))))
standardDeclarations = append(standardDeclarations, idents...)
standardDeclarations = append(standardDeclarations, []*exprpb.Decl{
// Booleans
decls.NewFunction(operators.Conditional,
decls.NewParameterizedOverload(overloads.Conditional,
[]*exprpb.Type{decls.Bool, paramA, paramA}, paramA,
typeParamAList)),
decls.NewFunction(operators.LogicalAnd,
decls.NewOverload(overloads.LogicalAnd,
[]*exprpb.Type{decls.Bool, decls.Bool}, decls.Bool)),
decls.NewFunction(operators.LogicalOr,
decls.NewOverload(overloads.LogicalOr,
[]*exprpb.Type{decls.Bool, decls.Bool}, decls.Bool)),
decls.NewFunction(operators.LogicalNot,
decls.NewOverload(overloads.LogicalNot,
[]*exprpb.Type{decls.Bool}, decls.Bool)),
decls.NewFunction(operators.NotStrictlyFalse,
decls.NewOverload(overloads.NotStrictlyFalse,
[]*exprpb.Type{decls.Bool}, decls.Bool)),
decls.NewFunction(operators.Equals,
decls.NewParameterizedOverload(overloads.Equals,
[]*exprpb.Type{paramA, paramA}, decls.Bool,
typeParamAList)),
decls.NewFunction(operators.NotEquals,
decls.NewParameterizedOverload(overloads.NotEquals,
[]*exprpb.Type{paramA, paramA}, decls.Bool,
typeParamAList)),
// Algebra.
decls.NewFunction(operators.Subtract,
decls.NewOverload(overloads.SubtractInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Int),
decls.NewOverload(overloads.SubtractUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Uint),
decls.NewOverload(overloads.SubtractDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Double),
decls.NewOverload(overloads.SubtractTimestampTimestamp,
[]*exprpb.Type{decls.Timestamp, decls.Timestamp}, decls.Duration),
decls.NewOverload(overloads.SubtractTimestampDuration,
[]*exprpb.Type{decls.Timestamp, decls.Duration}, decls.Timestamp),
decls.NewOverload(overloads.SubtractDurationDuration,
[]*exprpb.Type{decls.Duration, decls.Duration}, decls.Duration)),
decls.NewFunction(operators.Multiply,
decls.NewOverload(overloads.MultiplyInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Int),
decls.NewOverload(overloads.MultiplyUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Uint),
decls.NewOverload(overloads.MultiplyDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Double)),
decls.NewFunction(operators.Divide,
decls.NewOverload(overloads.DivideInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Int),
decls.NewOverload(overloads.DivideUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Uint),
decls.NewOverload(overloads.DivideDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Double)),
decls.NewFunction(operators.Modulo,
decls.NewOverload(overloads.ModuloInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Int),
decls.NewOverload(overloads.ModuloUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Uint)),
decls.NewFunction(operators.Add,
decls.NewOverload(overloads.AddInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Int),
decls.NewOverload(overloads.AddUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Uint),
decls.NewOverload(overloads.AddDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Double),
decls.NewOverload(overloads.AddString,
[]*exprpb.Type{decls.String, decls.String}, decls.String),
decls.NewOverload(overloads.AddBytes,
[]*exprpb.Type{decls.Bytes, decls.Bytes}, decls.Bytes),
decls.NewParameterizedOverload(overloads.AddList,
[]*exprpb.Type{listOfA, listOfA}, listOfA,
typeParamAList),
decls.NewOverload(overloads.AddTimestampDuration,
[]*exprpb.Type{decls.Timestamp, decls.Duration}, decls.Timestamp),
decls.NewOverload(overloads.AddDurationTimestamp,
[]*exprpb.Type{decls.Duration, decls.Timestamp}, decls.Timestamp),
decls.NewOverload(overloads.AddDurationDuration,
[]*exprpb.Type{decls.Duration, decls.Duration}, decls.Duration)),
decls.NewFunction(operators.Negate,
decls.NewOverload(overloads.NegateInt64,
[]*exprpb.Type{decls.Int}, decls.Int),
decls.NewOverload(overloads.NegateDouble,
[]*exprpb.Type{decls.Double}, decls.Double)),
// Index.
decls.NewFunction(operators.Index,
decls.NewParameterizedOverload(overloads.IndexList,
[]*exprpb.Type{listOfA, decls.Int}, paramA,
typeParamAList),
decls.NewParameterizedOverload(overloads.IndexMap,
[]*exprpb.Type{mapOfAB, paramA}, paramB,
typeParamABList)),
// Collections.
decls.NewFunction(overloads.Size,
decls.NewInstanceOverload(overloads.SizeStringInst,
[]*exprpb.Type{decls.String}, decls.Int),
decls.NewInstanceOverload(overloads.SizeBytesInst,
[]*exprpb.Type{decls.Bytes}, decls.Int),
decls.NewParameterizedInstanceOverload(overloads.SizeListInst,
[]*exprpb.Type{listOfA}, decls.Int, typeParamAList),
decls.NewParameterizedInstanceOverload(overloads.SizeMapInst,
[]*exprpb.Type{mapOfAB}, decls.Int, typeParamABList),
decls.NewOverload(overloads.SizeString,
[]*exprpb.Type{decls.String}, decls.Int),
decls.NewOverload(overloads.SizeBytes,
[]*exprpb.Type{decls.Bytes}, decls.Int),
decls.NewParameterizedOverload(overloads.SizeList,
[]*exprpb.Type{listOfA}, decls.Int, typeParamAList),
decls.NewParameterizedOverload(overloads.SizeMap,
[]*exprpb.Type{mapOfAB}, decls.Int, typeParamABList)),
decls.NewFunction(operators.In,
decls.NewParameterizedOverload(overloads.InList,
[]*exprpb.Type{paramA, listOfA}, decls.Bool,
typeParamAList),
decls.NewParameterizedOverload(overloads.InMap,
[]*exprpb.Type{paramA, mapOfAB}, decls.Bool,
typeParamABList)),
// Deprecated 'in()' function.
decls.NewFunction(overloads.DeprecatedIn,
decls.NewParameterizedOverload(overloads.InList,
[]*exprpb.Type{paramA, listOfA}, decls.Bool,
typeParamAList),
decls.NewParameterizedOverload(overloads.InMap,
[]*exprpb.Type{paramA, mapOfAB}, decls.Bool,
typeParamABList)),
// Conversions to type.
decls.NewFunction(overloads.TypeConvertType,
decls.NewParameterizedOverload(overloads.TypeConvertType,
[]*exprpb.Type{paramA}, decls.NewTypeType(paramA), typeParamAList)),
// Conversions to int.
decls.NewFunction(overloads.TypeConvertInt,
decls.NewOverload(overloads.IntToInt, []*exprpb.Type{decls.Int}, decls.Int),
decls.NewOverload(overloads.UintToInt, []*exprpb.Type{decls.Uint}, decls.Int),
decls.NewOverload(overloads.DoubleToInt, []*exprpb.Type{decls.Double}, decls.Int),
decls.NewOverload(overloads.StringToInt, []*exprpb.Type{decls.String}, decls.Int),
decls.NewOverload(overloads.TimestampToInt, []*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewOverload(overloads.DurationToInt, []*exprpb.Type{decls.Duration}, decls.Int)),
// Conversions to uint.
decls.NewFunction(overloads.TypeConvertUint,
decls.NewOverload(overloads.UintToUint, []*exprpb.Type{decls.Uint}, decls.Uint),
decls.NewOverload(overloads.IntToUint, []*exprpb.Type{decls.Int}, decls.Uint),
decls.NewOverload(overloads.DoubleToUint, []*exprpb.Type{decls.Double}, decls.Uint),
decls.NewOverload(overloads.StringToUint, []*exprpb.Type{decls.String}, decls.Uint)),
// Conversions to double.
decls.NewFunction(overloads.TypeConvertDouble,
decls.NewOverload(overloads.DoubleToDouble, []*exprpb.Type{decls.Double}, decls.Double),
decls.NewOverload(overloads.IntToDouble, []*exprpb.Type{decls.Int}, decls.Double),
decls.NewOverload(overloads.UintToDouble, []*exprpb.Type{decls.Uint}, decls.Double),
decls.NewOverload(overloads.StringToDouble, []*exprpb.Type{decls.String}, decls.Double)),
// Conversions to bool.
decls.NewFunction(overloads.TypeConvertBool,
decls.NewOverload(overloads.BoolToBool, []*exprpb.Type{decls.Bool}, decls.Bool),
decls.NewOverload(overloads.StringToBool, []*exprpb.Type{decls.String}, decls.Bool)),
// Conversions to string.
decls.NewFunction(overloads.TypeConvertString,
decls.NewOverload(overloads.StringToString, []*exprpb.Type{decls.String}, decls.String),
decls.NewOverload(overloads.BoolToString, []*exprpb.Type{decls.Bool}, decls.String),
decls.NewOverload(overloads.IntToString, []*exprpb.Type{decls.Int}, decls.String),
decls.NewOverload(overloads.UintToString, []*exprpb.Type{decls.Uint}, decls.String),
decls.NewOverload(overloads.DoubleToString, []*exprpb.Type{decls.Double}, decls.String),
decls.NewOverload(overloads.BytesToString, []*exprpb.Type{decls.Bytes}, decls.String),
decls.NewOverload(overloads.TimestampToString, []*exprpb.Type{decls.Timestamp}, decls.String),
decls.NewOverload(overloads.DurationToString, []*exprpb.Type{decls.Duration}, decls.String)),
// Conversions to bytes.
decls.NewFunction(overloads.TypeConvertBytes,
decls.NewOverload(overloads.BytesToBytes, []*exprpb.Type{decls.Bytes}, decls.Bytes),
decls.NewOverload(overloads.StringToBytes, []*exprpb.Type{decls.String}, decls.Bytes)),
// Conversions to timestamps.
decls.NewFunction(overloads.TypeConvertTimestamp,
decls.NewOverload(overloads.TimestampToTimestamp,
[]*exprpb.Type{decls.Timestamp}, decls.Timestamp),
decls.NewOverload(overloads.StringToTimestamp,
[]*exprpb.Type{decls.String}, decls.Timestamp),
decls.NewOverload(overloads.IntToTimestamp,
[]*exprpb.Type{decls.Int}, decls.Timestamp)),
// Conversions to durations.
decls.NewFunction(overloads.TypeConvertDuration,
decls.NewOverload(overloads.DurationToDuration,
[]*exprpb.Type{decls.Duration}, decls.Duration),
decls.NewOverload(overloads.StringToDuration,
[]*exprpb.Type{decls.String}, decls.Duration),
decls.NewOverload(overloads.IntToDuration,
[]*exprpb.Type{decls.Int}, decls.Duration)),
// Conversions to Dyn.
decls.NewFunction(overloads.TypeConvertDyn,
decls.NewParameterizedOverload(overloads.ToDyn,
[]*exprpb.Type{paramA}, decls.Dyn,
typeParamAList)),
// String functions.
decls.NewFunction(overloads.Contains,
decls.NewInstanceOverload(overloads.ContainsString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool)),
decls.NewFunction(overloads.EndsWith,
decls.NewInstanceOverload(overloads.EndsWithString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool)),
decls.NewFunction(overloads.Matches,
decls.NewOverload(overloads.Matches,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool),
decls.NewInstanceOverload(overloads.MatchesString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool)),
decls.NewFunction(overloads.StartsWith,
decls.NewInstanceOverload(overloads.StartsWithString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool)),
// Date/time functions.
decls.NewFunction(overloads.TimeGetFullYear,
decls.NewInstanceOverload(overloads.TimestampToYear,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToYearWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int)),
decls.NewFunction(overloads.TimeGetMonth,
decls.NewInstanceOverload(overloads.TimestampToMonth,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToMonthWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int)),
decls.NewFunction(overloads.TimeGetDayOfYear,
decls.NewInstanceOverload(overloads.TimestampToDayOfYear,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToDayOfYearWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int)),
decls.NewFunction(overloads.TimeGetDayOfMonth,
decls.NewInstanceOverload(overloads.TimestampToDayOfMonthZeroBased,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToDayOfMonthZeroBasedWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int)),
decls.NewFunction(overloads.TimeGetDate,
decls.NewInstanceOverload(overloads.TimestampToDayOfMonthOneBased,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToDayOfMonthOneBasedWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int)),
decls.NewFunction(overloads.TimeGetDayOfWeek,
decls.NewInstanceOverload(overloads.TimestampToDayOfWeek,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToDayOfWeekWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int)),
decls.NewFunction(overloads.TimeGetHours,
decls.NewInstanceOverload(overloads.TimestampToHours,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToHoursWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int),
decls.NewInstanceOverload(overloads.DurationToHours,
[]*exprpb.Type{decls.Duration}, decls.Int)),
decls.NewFunction(overloads.TimeGetMinutes,
decls.NewInstanceOverload(overloads.TimestampToMinutes,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToMinutesWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int),
decls.NewInstanceOverload(overloads.DurationToMinutes,
[]*exprpb.Type{decls.Duration}, decls.Int)),
decls.NewFunction(overloads.TimeGetSeconds,
decls.NewInstanceOverload(overloads.TimestampToSeconds,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToSecondsWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int),
decls.NewInstanceOverload(overloads.DurationToSeconds,
[]*exprpb.Type{decls.Duration}, decls.Int)),
decls.NewFunction(overloads.TimeGetMilliseconds,
decls.NewInstanceOverload(overloads.TimestampToMilliseconds,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToMillisecondsWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int),
decls.NewInstanceOverload(overloads.DurationToMilliseconds,
[]*exprpb.Type{decls.Duration}, decls.Int)),
// Relations.
decls.NewFunction(operators.Less,
decls.NewOverload(overloads.LessBool,
[]*exprpb.Type{decls.Bool, decls.Bool}, decls.Bool),
decls.NewOverload(overloads.LessInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Bool),
decls.NewOverload(overloads.LessInt64Double,
[]*exprpb.Type{decls.Int, decls.Double}, decls.Bool),
decls.NewOverload(overloads.LessInt64Uint64,
[]*exprpb.Type{decls.Int, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.LessUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.LessUint64Double,
[]*exprpb.Type{decls.Uint, decls.Double}, decls.Bool),
decls.NewOverload(overloads.LessUint64Int64,
[]*exprpb.Type{decls.Uint, decls.Int}, decls.Bool),
decls.NewOverload(overloads.LessDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Bool),
decls.NewOverload(overloads.LessDoubleInt64,
[]*exprpb.Type{decls.Double, decls.Int}, decls.Bool),
decls.NewOverload(overloads.LessDoubleUint64,
[]*exprpb.Type{decls.Double, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.LessString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool),
decls.NewOverload(overloads.LessBytes,
[]*exprpb.Type{decls.Bytes, decls.Bytes}, decls.Bool),
decls.NewOverload(overloads.LessTimestamp,
[]*exprpb.Type{decls.Timestamp, decls.Timestamp}, decls.Bool),
decls.NewOverload(overloads.LessDuration,
[]*exprpb.Type{decls.Duration, decls.Duration}, decls.Bool)),
decls.NewFunction(operators.LessEquals,
decls.NewOverload(overloads.LessEqualsBool,
[]*exprpb.Type{decls.Bool, decls.Bool}, decls.Bool),
decls.NewOverload(overloads.LessEqualsInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Bool),
decls.NewOverload(overloads.LessEqualsInt64Double,
[]*exprpb.Type{decls.Int, decls.Double}, decls.Bool),
decls.NewOverload(overloads.LessEqualsInt64Uint64,
[]*exprpb.Type{decls.Int, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.LessEqualsUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.LessEqualsUint64Double,
[]*exprpb.Type{decls.Uint, decls.Double}, decls.Bool),
decls.NewOverload(overloads.LessEqualsUint64Int64,
[]*exprpb.Type{decls.Uint, decls.Int}, decls.Bool),
decls.NewOverload(overloads.LessEqualsDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Bool),
decls.NewOverload(overloads.LessEqualsDoubleInt64,
[]*exprpb.Type{decls.Double, decls.Int}, decls.Bool),
decls.NewOverload(overloads.LessEqualsDoubleUint64,
[]*exprpb.Type{decls.Double, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.LessEqualsString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool),
decls.NewOverload(overloads.LessEqualsBytes,
[]*exprpb.Type{decls.Bytes, decls.Bytes}, decls.Bool),
decls.NewOverload(overloads.LessEqualsTimestamp,
[]*exprpb.Type{decls.Timestamp, decls.Timestamp}, decls.Bool),
decls.NewOverload(overloads.LessEqualsDuration,
[]*exprpb.Type{decls.Duration, decls.Duration}, decls.Bool)),
decls.NewFunction(operators.Greater,
decls.NewOverload(overloads.GreaterBool,
[]*exprpb.Type{decls.Bool, decls.Bool}, decls.Bool),
decls.NewOverload(overloads.GreaterInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Bool),
decls.NewOverload(overloads.GreaterInt64Double,
[]*exprpb.Type{decls.Int, decls.Double}, decls.Bool),
decls.NewOverload(overloads.GreaterInt64Uint64,
[]*exprpb.Type{decls.Int, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.GreaterUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.GreaterUint64Double,
[]*exprpb.Type{decls.Uint, decls.Double}, decls.Bool),
decls.NewOverload(overloads.GreaterUint64Int64,
[]*exprpb.Type{decls.Uint, decls.Int}, decls.Bool),
decls.NewOverload(overloads.GreaterDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Bool),
decls.NewOverload(overloads.GreaterDoubleInt64,
[]*exprpb.Type{decls.Double, decls.Int}, decls.Bool),
decls.NewOverload(overloads.GreaterDoubleUint64,
[]*exprpb.Type{decls.Double, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.GreaterString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool),
decls.NewOverload(overloads.GreaterBytes,
[]*exprpb.Type{decls.Bytes, decls.Bytes}, decls.Bool),
decls.NewOverload(overloads.GreaterTimestamp,
[]*exprpb.Type{decls.Timestamp, decls.Timestamp}, decls.Bool),
decls.NewOverload(overloads.GreaterDuration,
[]*exprpb.Type{decls.Duration, decls.Duration}, decls.Bool)),
decls.NewFunction(operators.GreaterEquals,
decls.NewOverload(overloads.GreaterEqualsBool,
[]*exprpb.Type{decls.Bool, decls.Bool}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsInt64Double,
[]*exprpb.Type{decls.Int, decls.Double}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsInt64Uint64,
[]*exprpb.Type{decls.Int, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsUint64Double,
[]*exprpb.Type{decls.Uint, decls.Double}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsUint64Int64,
[]*exprpb.Type{decls.Uint, decls.Int}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsDoubleInt64,
[]*exprpb.Type{decls.Double, decls.Int}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsDoubleUint64,
[]*exprpb.Type{decls.Double, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsBytes,
[]*exprpb.Type{decls.Bytes, decls.Bytes}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsTimestamp,
[]*exprpb.Type{decls.Timestamp, decls.Timestamp}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsDuration,
[]*exprpb.Type{decls.Duration, decls.Duration}, decls.Bool)),
}...)
// StandardFunctions returns the Decls for all functions in the evaluator.
//
// Deprecated: prefer stdlib.FunctionExprDecls()
func StandardFunctions() []*exprpb.Decl {
return stdlib.FunctionExprDecls()
}
// StandardDeclarations returns the Decls for all functions and constants in the evaluator.
func StandardDeclarations() []*exprpb.Decl {
return standardDeclarations
// StandardTypes returns the set of type identifiers for standard library types.
//
// Deprecated: prefer stdlib.TypeExprDecls()
func StandardTypes() []*exprpb.Decl {
return stdlib.TypeExprDecls()
}

View File

@ -15,154 +15,54 @@
package checker
import (
"fmt"
"strings"
"github.com/google/cel-go/checker/decls"
"google.golang.org/protobuf/proto"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"github.com/google/cel-go/common/types"
)
const (
kindUnknown = iota + 1
kindError
kindFunction
kindDyn
kindPrimitive
kindWellKnown
kindWrapper
kindNull
kindAbstract
kindType
kindList
kindMap
kindObject
kindTypeParam
)
// FormatCheckedType converts a type message into a string representation.
func FormatCheckedType(t *exprpb.Type) string {
switch kindOf(t) {
case kindDyn:
return "dyn"
case kindFunction:
return formatFunction(t.GetFunction().GetResultType(),
t.GetFunction().GetArgTypes(),
false)
case kindList:
return fmt.Sprintf("list(%s)", FormatCheckedType(t.GetListType().GetElemType()))
case kindObject:
return t.GetMessageType()
case kindMap:
return fmt.Sprintf("map(%s, %s)",
FormatCheckedType(t.GetMapType().GetKeyType()),
FormatCheckedType(t.GetMapType().GetValueType()))
case kindNull:
return "null"
case kindPrimitive:
switch t.GetPrimitive() {
case exprpb.Type_UINT64:
return "uint"
case exprpb.Type_INT64:
return "int"
}
return strings.Trim(strings.ToLower(t.GetPrimitive().String()), " ")
case kindType:
if t.GetType() == nil {
return "type"
}
return fmt.Sprintf("type(%s)", FormatCheckedType(t.GetType()))
case kindWellKnown:
switch t.GetWellKnown() {
case exprpb.Type_ANY:
return "any"
case exprpb.Type_DURATION:
return "duration"
case exprpb.Type_TIMESTAMP:
return "timestamp"
}
case kindWrapper:
return fmt.Sprintf("wrapper(%s)",
FormatCheckedType(decls.NewPrimitiveType(t.GetWrapper())))
case kindError:
return "!error!"
case kindTypeParam:
return t.GetTypeParam()
case kindAbstract:
at := t.GetAbstractType()
params := at.GetParameterTypes()
paramStrs := make([]string, len(params))
for i, p := range params {
paramStrs[i] = FormatCheckedType(p)
}
return fmt.Sprintf("%s(%s)", at.GetName(), strings.Join(paramStrs, ", "))
}
return t.String()
}
// isDyn returns true if the input t is either type DYN or a well-known ANY message.
func isDyn(t *exprpb.Type) bool {
func isDyn(t *types.Type) bool {
// Note: object type values that are well-known and map to a DYN value in practice
// are sanitized prior to being added to the environment.
switch kindOf(t) {
case kindDyn:
switch t.Kind() {
case types.DynKind, types.AnyKind:
return true
case kindWellKnown:
return t.GetWellKnown() == exprpb.Type_ANY
default:
return false
}
}
// isDynOrError returns true if the input is either an Error, DYN, or well-known ANY message.
func isDynOrError(t *exprpb.Type) bool {
func isDynOrError(t *types.Type) bool {
return isError(t) || isDyn(t)
}
func isError(t *exprpb.Type) bool {
return kindOf(t) == kindError
func isError(t *types.Type) bool {
return t.Kind() == types.ErrorKind
}
func isOptional(t *exprpb.Type) bool {
if kindOf(t) == kindAbstract {
at := t.GetAbstractType()
return at.GetName() == "optional"
func isOptional(t *types.Type) bool {
if t.Kind() == types.OpaqueKind {
return t.TypeName() == "optional"
}
return false
}
func maybeUnwrapOptional(t *exprpb.Type) (*exprpb.Type, bool) {
func maybeUnwrapOptional(t *types.Type) (*types.Type, bool) {
if isOptional(t) {
at := t.GetAbstractType()
return at.GetParameterTypes()[0], true
return t.Parameters()[0], true
}
return t, false
}
func maybeUnwrapString(e *exprpb.Expr) (string, bool) {
switch e.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
literal := e.GetConstExpr()
switch literal.GetConstantKind().(type) {
case *exprpb.Constant_StringValue:
return literal.GetStringValue(), true
}
}
return "", false
}
// isEqualOrLessSpecific checks whether one type is equal or less specific than the other one.
// A type is less specific if it matches the other type using the DYN type.
func isEqualOrLessSpecific(t1 *exprpb.Type, t2 *exprpb.Type) bool {
kind1, kind2 := kindOf(t1), kindOf(t2)
func isEqualOrLessSpecific(t1, t2 *types.Type) bool {
kind1, kind2 := t1.Kind(), t2.Kind()
// The first type is less specific.
if isDyn(t1) || kind1 == kindTypeParam {
if isDyn(t1) || kind1 == types.TypeParamKind {
return true
}
// The first type is not less specific.
if isDyn(t2) || kind2 == kindTypeParam {
if isDyn(t2) || kind2 == types.TypeParamKind {
return false
}
// Types must be of the same kind to be equal.
@ -173,38 +73,34 @@ func isEqualOrLessSpecific(t1 *exprpb.Type, t2 *exprpb.Type) bool {
// With limited exceptions for ANY and JSON values, the types must agree and be equivalent in
// order to return true.
switch kind1 {
case kindAbstract:
a1 := t1.GetAbstractType()
a2 := t2.GetAbstractType()
if a1.GetName() != a2.GetName() ||
len(a1.GetParameterTypes()) != len(a2.GetParameterTypes()) {
case types.OpaqueKind:
if t1.TypeName() != t2.TypeName() ||
len(t1.Parameters()) != len(t2.Parameters()) {
return false
}
for i, p1 := range a1.GetParameterTypes() {
if !isEqualOrLessSpecific(p1, a2.GetParameterTypes()[i]) {
for i, p1 := range t1.Parameters() {
if !isEqualOrLessSpecific(p1, t2.Parameters()[i]) {
return false
}
}
return true
case kindList:
return isEqualOrLessSpecific(t1.GetListType().GetElemType(), t2.GetListType().GetElemType())
case kindMap:
m1 := t1.GetMapType()
m2 := t2.GetMapType()
return isEqualOrLessSpecific(m1.GetKeyType(), m2.GetKeyType()) &&
isEqualOrLessSpecific(m1.GetValueType(), m2.GetValueType())
case kindType:
case types.ListKind:
return isEqualOrLessSpecific(t1.Parameters()[0], t2.Parameters()[0])
case types.MapKind:
return isEqualOrLessSpecific(t1.Parameters()[0], t2.Parameters()[0]) &&
isEqualOrLessSpecific(t1.Parameters()[1], t2.Parameters()[1])
case types.TypeKind:
return true
default:
return proto.Equal(t1, t2)
return t1.IsExactType(t2)
}
}
// / internalIsAssignable returns true if t1 is assignable to t2.
func internalIsAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) bool {
func internalIsAssignable(m *mapping, t1, t2 *types.Type) bool {
// Process type parameters.
kind1, kind2 := kindOf(t1), kindOf(t2)
if kind2 == kindTypeParam {
kind1, kind2 := t1.Kind(), t2.Kind()
if kind2 == types.TypeParamKind {
// If t2 is a valid type substitution for t1, return true.
valid, t2HasSub := isValidTypeSubstitution(m, t1, t2)
if valid {
@ -217,7 +113,7 @@ func internalIsAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) bool {
}
// Otherwise, fall through to check whether t1 is a possible substitution for t2.
}
if kind1 == kindTypeParam {
if kind1 == types.TypeParamKind {
// Return whether t1 is a valid substitution for t2. If not, do no additional checks as the
// possible type substitutions have been searched in both directions.
valid, _ := isValidTypeSubstitution(m, t2, t1)
@ -228,40 +124,25 @@ func internalIsAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) bool {
if isDynOrError(t1) || isDynOrError(t2) {
return true
}
// Preserve the nullness checks of the legacy type-checker.
if kind1 == types.NullTypeKind {
return internalIsAssignableNull(t2)
}
if kind2 == types.NullTypeKind {
return internalIsAssignableNull(t1)
}
// Test for when the types do not need to agree, but are more specific than dyn.
switch kind1 {
case kindNull:
return internalIsAssignableNull(t2)
case kindPrimitive:
return internalIsAssignablePrimitive(t1.GetPrimitive(), t2)
case kindWrapper:
return internalIsAssignable(m, decls.NewPrimitiveType(t1.GetWrapper()), t2)
default:
if kind1 != kind2 {
return false
}
}
// Test for when the types must agree.
switch kind1 {
// ERROR, TYPE_PARAM, and DYN handled above.
case kindAbstract:
return internalIsAssignableAbstractType(m, t1.GetAbstractType(), t2.GetAbstractType())
case kindFunction:
return internalIsAssignableFunction(m, t1.GetFunction(), t2.GetFunction())
case kindList:
return internalIsAssignable(m, t1.GetListType().GetElemType(), t2.GetListType().GetElemType())
case kindMap:
return internalIsAssignableMap(m, t1.GetMapType(), t2.GetMapType())
case kindObject:
return t1.GetMessageType() == t2.GetMessageType()
case kindType:
// A type is a type is a type, any additional parameterization of the
// type cannot affect method resolution or assignability.
return true
case kindWellKnown:
return t1.GetWellKnown() == t2.GetWellKnown()
case types.BoolKind, types.BytesKind, types.DoubleKind, types.IntKind, types.StringKind, types.UintKind,
types.AnyKind, types.DurationKind, types.TimestampKind,
types.StructKind:
return t1.IsAssignableType(t2)
case types.TypeKind:
return kind2 == types.TypeKind
case types.OpaqueKind, types.ListKind, types.MapKind:
return t1.Kind() == t2.Kind() && t1.TypeName() == t2.TypeName() &&
internalIsAssignableList(m, t1.Parameters(), t2.Parameters())
default:
return false
}
@ -274,16 +155,16 @@ func internalIsAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) bool {
// - t2 has a type substitution (t2sub) equal to t1
// - t2 has a type substitution (t2sub) assignable to t1
// - t2 does not occur within t1.
func isValidTypeSubstitution(m *mapping, t1, t2 *exprpb.Type) (valid, hasSub bool) {
func isValidTypeSubstitution(m *mapping, t1, t2 *types.Type) (valid, hasSub bool) {
// Early return if the t1 and t2 are the same instance.
kind1, kind2 := kindOf(t1), kindOf(t2)
if kind1 == kind2 && (t1 == t2 || proto.Equal(t1, t2)) {
kind1, kind2 := t1.Kind(), t2.Kind()
if kind1 == kind2 && t1.IsExactType(t2) {
return true, true
}
if t2Sub, found := m.find(t2); found {
// Early return if t1 and t2Sub are the same instance as otherwise the mapping
// might mark a type as being a subtitution for itself.
if kind1 == kindOf(t2Sub) && (t1 == t2Sub || proto.Equal(t1, t2Sub)) {
if kind1 == t2Sub.Kind() && t1.IsExactType(t2Sub) {
return true, true
}
// If the types are compatible, pick the more general type and return true
@ -305,28 +186,10 @@ func isValidTypeSubstitution(m *mapping, t1, t2 *exprpb.Type) (valid, hasSub boo
return false, false
}
// internalIsAssignableAbstractType returns true if the abstract type names agree and all type
// parameters are assignable.
func internalIsAssignableAbstractType(m *mapping, a1 *exprpb.Type_AbstractType, a2 *exprpb.Type_AbstractType) bool {
return a1.GetName() == a2.GetName() &&
internalIsAssignableList(m, a1.GetParameterTypes(), a2.GetParameterTypes())
}
// internalIsAssignableFunction returns true if the function return type and arg types are
// assignable.
func internalIsAssignableFunction(m *mapping, f1 *exprpb.Type_FunctionType, f2 *exprpb.Type_FunctionType) bool {
f1ArgTypes := flattenFunctionTypes(f1)
f2ArgTypes := flattenFunctionTypes(f2)
if internalIsAssignableList(m, f1ArgTypes, f2ArgTypes) {
return true
}
return false
}
// internalIsAssignableList returns true if the element types at each index in the list are
// assignable from l1[i] to l2[i]. The list lengths must also agree for the lists to be
// assignable.
func internalIsAssignableList(m *mapping, l1 []*exprpb.Type, l2 []*exprpb.Type) bool {
func internalIsAssignableList(m *mapping, l1, l2 []*types.Type) bool {
if len(l1) != len(l2) {
return false
}
@ -338,41 +201,22 @@ func internalIsAssignableList(m *mapping, l1 []*exprpb.Type, l2 []*exprpb.Type)
return true
}
// internalIsAssignableMap returns true if map m1 may be assigned to map m2.
func internalIsAssignableMap(m *mapping, m1 *exprpb.Type_MapType, m2 *exprpb.Type_MapType) bool {
if internalIsAssignableList(m,
[]*exprpb.Type{m1.GetKeyType(), m1.GetValueType()},
[]*exprpb.Type{m2.GetKeyType(), m2.GetValueType()}) {
// internalIsAssignableNull returns true if the type is nullable.
func internalIsAssignableNull(t *types.Type) bool {
return isLegacyNullable(t) || t.IsAssignableType(types.NullType)
}
// isLegacyNullable preserves the null-ness compatibility of the original type-checker implementation.
func isLegacyNullable(t *types.Type) bool {
switch t.Kind() {
case types.OpaqueKind, types.StructKind, types.AnyKind, types.DurationKind, types.TimestampKind:
return true
}
return false
}
// internalIsAssignableNull returns true if the type is nullable.
func internalIsAssignableNull(t *exprpb.Type) bool {
switch kindOf(t) {
case kindAbstract, kindObject, kindNull, kindWellKnown, kindWrapper:
return true
default:
return false
}
}
// internalIsAssignablePrimitive returns true if the target type is the same or if it is a wrapper
// for the primitive type.
func internalIsAssignablePrimitive(p exprpb.Type_PrimitiveType, target *exprpb.Type) bool {
switch kindOf(target) {
case kindPrimitive:
return p == target.GetPrimitive()
case kindWrapper:
return p == target.GetWrapper()
default:
return false
}
}
// isAssignable returns an updated type substitution mapping if t1 is assignable to t2.
func isAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) *mapping {
func isAssignable(m *mapping, t1, t2 *types.Type) *mapping {
mCopy := m.copy()
if internalIsAssignable(mCopy, t1, t2) {
return mCopy
@ -381,7 +225,7 @@ func isAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) *mapping {
}
// isAssignableList returns an updated type substitution mapping if l1 is assignable to l2.
func isAssignableList(m *mapping, l1 []*exprpb.Type, l2 []*exprpb.Type) *mapping {
func isAssignableList(m *mapping, l1, l2 []*types.Type) *mapping {
mCopy := m.copy()
if internalIsAssignableList(mCopy, l1, l2) {
return mCopy
@ -389,44 +233,8 @@ func isAssignableList(m *mapping, l1 []*exprpb.Type, l2 []*exprpb.Type) *mapping
return nil
}
// kindOf returns the kind of the type as defined in the checked.proto.
func kindOf(t *exprpb.Type) int {
if t == nil || t.TypeKind == nil {
return kindUnknown
}
switch t.GetTypeKind().(type) {
case *exprpb.Type_Error:
return kindError
case *exprpb.Type_Function:
return kindFunction
case *exprpb.Type_Dyn:
return kindDyn
case *exprpb.Type_Primitive:
return kindPrimitive
case *exprpb.Type_WellKnown:
return kindWellKnown
case *exprpb.Type_Wrapper:
return kindWrapper
case *exprpb.Type_Null:
return kindNull
case *exprpb.Type_Type:
return kindType
case *exprpb.Type_ListType_:
return kindList
case *exprpb.Type_MapType_:
return kindMap
case *exprpb.Type_MessageType:
return kindObject
case *exprpb.Type_TypeParam:
return kindTypeParam
case *exprpb.Type_AbstractType_:
return kindAbstract
}
return kindUnknown
}
// mostGeneral returns the more general of two types which are known to unify.
func mostGeneral(t1 *exprpb.Type, t2 *exprpb.Type) *exprpb.Type {
func mostGeneral(t1, t2 *types.Type) *types.Type {
if isEqualOrLessSpecific(t1, t2) {
return t1
}
@ -436,32 +244,25 @@ func mostGeneral(t1 *exprpb.Type, t2 *exprpb.Type) *exprpb.Type {
// notReferencedIn checks whether the type doesn't appear directly or transitively within the other
// type. This is a standard requirement for type unification, commonly referred to as the "occurs
// check".
func notReferencedIn(m *mapping, t *exprpb.Type, withinType *exprpb.Type) bool {
if proto.Equal(t, withinType) {
func notReferencedIn(m *mapping, t, withinType *types.Type) bool {
if t.IsExactType(withinType) {
return false
}
withinKind := kindOf(withinType)
withinKind := withinType.Kind()
switch withinKind {
case kindTypeParam:
case types.TypeParamKind:
wtSub, found := m.find(withinType)
if !found {
return true
}
return notReferencedIn(m, t, wtSub)
case kindAbstract:
for _, pt := range withinType.GetAbstractType().GetParameterTypes() {
case types.OpaqueKind, types.ListKind, types.MapKind:
for _, pt := range withinType.Parameters() {
if !notReferencedIn(m, t, pt) {
return false
}
}
return true
case kindList:
return notReferencedIn(m, t, withinType.GetListType().GetElemType())
case kindMap:
mt := withinType.GetMapType()
return notReferencedIn(m, t, mt.GetKeyType()) && notReferencedIn(m, t, mt.GetValueType())
case kindWrapper:
return notReferencedIn(m, t, decls.NewPrimitiveType(withinType.GetWrapper()))
default:
return true
}
@ -469,39 +270,25 @@ func notReferencedIn(m *mapping, t *exprpb.Type, withinType *exprpb.Type) bool {
// substitute replaces all direct and indirect occurrences of bound type parameters. Unbound type
// parameters are replaced by DYN if typeParamToDyn is true.
func substitute(m *mapping, t *exprpb.Type, typeParamToDyn bool) *exprpb.Type {
func substitute(m *mapping, t *types.Type, typeParamToDyn bool) *types.Type {
if tSub, found := m.find(t); found {
return substitute(m, tSub, typeParamToDyn)
}
kind := kindOf(t)
if typeParamToDyn && kind == kindTypeParam {
return decls.Dyn
kind := t.Kind()
if typeParamToDyn && kind == types.TypeParamKind {
return types.DynType
}
switch kind {
case kindAbstract:
at := t.GetAbstractType()
params := make([]*exprpb.Type, len(at.GetParameterTypes()))
for i, p := range at.GetParameterTypes() {
params[i] = substitute(m, p, typeParamToDyn)
}
return decls.NewAbstractType(at.GetName(), params...)
case kindFunction:
fn := t.GetFunction()
rt := substitute(m, fn.ResultType, typeParamToDyn)
args := make([]*exprpb.Type, len(fn.GetArgTypes()))
for i, a := range fn.ArgTypes {
args[i] = substitute(m, a, typeParamToDyn)
}
return decls.NewFunctionType(rt, args...)
case kindList:
return decls.NewListType(substitute(m, t.GetListType().GetElemType(), typeParamToDyn))
case kindMap:
mt := t.GetMapType()
return decls.NewMapType(substitute(m, mt.GetKeyType(), typeParamToDyn),
substitute(m, mt.GetValueType(), typeParamToDyn))
case kindType:
if t.GetType() != nil {
return decls.NewTypeType(substitute(m, t.GetType(), typeParamToDyn))
case types.OpaqueKind:
return types.NewOpaqueType(t.TypeName(), substituteParams(m, t.Parameters(), typeParamToDyn)...)
case types.ListKind:
return types.NewListType(substitute(m, t.Parameters()[0], typeParamToDyn))
case types.MapKind:
return types.NewMapType(substitute(m, t.Parameters()[0], typeParamToDyn),
substitute(m, t.Parameters()[1], typeParamToDyn))
case types.TypeKind:
if len(t.Parameters()) > 0 {
return types.NewTypeTypeWithParam(substitute(m, t.Parameters()[0], typeParamToDyn))
}
return t
default:
@ -509,21 +296,14 @@ func substitute(m *mapping, t *exprpb.Type, typeParamToDyn bool) *exprpb.Type {
}
}
func typeKey(t *exprpb.Type) string {
return FormatCheckedType(t)
func substituteParams(m *mapping, typeParams []*types.Type, typeParamToDyn bool) []*types.Type {
subParams := make([]*types.Type, len(typeParams))
for i, tp := range typeParams {
subParams[i] = substitute(m, tp, typeParamToDyn)
}
return subParams
}
// flattenFunctionTypes takes a function with arg types T1, T2, ..., TN and result type TR
// and returns a slice containing {T1, T2, ..., TN, TR}.
func flattenFunctionTypes(f *exprpb.Type_FunctionType) []*exprpb.Type {
argTypes := f.GetArgTypes()
if len(argTypes) == 0 {
return []*exprpb.Type{f.GetResultType()}
}
flattend := make([]*exprpb.Type, len(argTypes)+1, len(argTypes)+1)
for i, at := range argTypes {
flattend[i] = at
}
flattend[len(argTypes)] = f.GetResultType()
return flattend
func newFunctionType(resultType *types.Type, argTypes ...*types.Type) *types.Type {
return types.NewOpaqueType("function", append([]*types.Type{resultType}, argTypes...)...)
}

52
vendor/github.com/google/cel-go/common/ast/BUILD.bazel generated vendored Normal file
View File

@ -0,0 +1,52 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
package(
default_visibility = [
"//cel:__subpackages__",
"//checker:__subpackages__",
"//common:__subpackages__",
"//interpreter:__subpackages__",
],
licenses = ["notice"], # Apache 2.0
)
go_library(
name = "go_default_library",
srcs = [
"ast.go",
"expr.go",
],
importpath = "github.com/google/cel-go/common/ast",
deps = [
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//types/known/structpb:go_default_library",
],
)
go_test(
name = "go_default_test",
srcs = [
"ast_test.go",
"expr_test.go",
],
embed = [
":go_default_library",
],
deps = [
"//checker:go_default_library",
"//checker/decls:go_default_library",
"//common:go_default_library",
"//common/containers:go_default_library",
"//common/decls:go_default_library",
"//common/overloads:go_default_library",
"//common/stdlib:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"//parser:go_default_library",
"//test/proto3pb:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
],
)

226
vendor/github.com/google/cel-go/common/ast/ast.go generated vendored Normal file
View File

@ -0,0 +1,226 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package ast declares data structures useful for parsed and checked abstract syntax trees
package ast
import (
"fmt"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
structpb "google.golang.org/protobuf/types/known/structpb"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// CheckedAST contains a protobuf expression and source info along with CEL-native type and reference information.
type CheckedAST struct {
Expr *exprpb.Expr
SourceInfo *exprpb.SourceInfo
TypeMap map[int64]*types.Type
ReferenceMap map[int64]*ReferenceInfo
}
// CheckedASTToCheckedExpr converts a CheckedAST to a CheckedExpr protobouf.
func CheckedASTToCheckedExpr(ast *CheckedAST) (*exprpb.CheckedExpr, error) {
refMap := make(map[int64]*exprpb.Reference, len(ast.ReferenceMap))
for id, ref := range ast.ReferenceMap {
r, err := ReferenceInfoToReferenceExpr(ref)
if err != nil {
return nil, err
}
refMap[id] = r
}
typeMap := make(map[int64]*exprpb.Type, len(ast.TypeMap))
for id, typ := range ast.TypeMap {
t, err := types.TypeToExprType(typ)
if err != nil {
return nil, err
}
typeMap[id] = t
}
return &exprpb.CheckedExpr{
Expr: ast.Expr,
SourceInfo: ast.SourceInfo,
ReferenceMap: refMap,
TypeMap: typeMap,
}, nil
}
// CheckedExprToCheckedAST converts a CheckedExpr protobuf to a CheckedAST instance.
func CheckedExprToCheckedAST(checked *exprpb.CheckedExpr) (*CheckedAST, error) {
refMap := make(map[int64]*ReferenceInfo, len(checked.GetReferenceMap()))
for id, ref := range checked.GetReferenceMap() {
r, err := ReferenceExprToReferenceInfo(ref)
if err != nil {
return nil, err
}
refMap[id] = r
}
typeMap := make(map[int64]*types.Type, len(checked.GetTypeMap()))
for id, typ := range checked.GetTypeMap() {
t, err := types.ExprTypeToType(typ)
if err != nil {
return nil, err
}
typeMap[id] = t
}
return &CheckedAST{
Expr: checked.GetExpr(),
SourceInfo: checked.GetSourceInfo(),
ReferenceMap: refMap,
TypeMap: typeMap,
}, nil
}
// ReferenceInfo contains a CEL native representation of an identifier reference which may refer to
// either a qualified identifier name, a set of overload ids, or a constant value from an enum.
type ReferenceInfo struct {
Name string
OverloadIDs []string
Value ref.Val
}
// NewIdentReference creates a ReferenceInfo instance for an identifier with an optional constant value.
func NewIdentReference(name string, value ref.Val) *ReferenceInfo {
return &ReferenceInfo{Name: name, Value: value}
}
// NewFunctionReference creates a ReferenceInfo instance for a set of function overloads.
func NewFunctionReference(overloads ...string) *ReferenceInfo {
info := &ReferenceInfo{}
for _, id := range overloads {
info.AddOverload(id)
}
return info
}
// AddOverload appends a function overload ID to the ReferenceInfo.
func (r *ReferenceInfo) AddOverload(overloadID string) {
for _, id := range r.OverloadIDs {
if id == overloadID {
return
}
}
r.OverloadIDs = append(r.OverloadIDs, overloadID)
}
// Equals returns whether two references are identical to each other.
func (r *ReferenceInfo) Equals(other *ReferenceInfo) bool {
if r.Name != other.Name {
return false
}
if len(r.OverloadIDs) != len(other.OverloadIDs) {
return false
}
if len(r.OverloadIDs) != 0 {
overloadMap := make(map[string]struct{}, len(r.OverloadIDs))
for _, id := range r.OverloadIDs {
overloadMap[id] = struct{}{}
}
for _, id := range other.OverloadIDs {
_, found := overloadMap[id]
if !found {
return false
}
}
}
if r.Value == nil && other.Value == nil {
return true
}
if r.Value == nil && other.Value != nil ||
r.Value != nil && other.Value == nil ||
r.Value.Equal(other.Value) != types.True {
return false
}
return true
}
// ReferenceInfoToReferenceExpr converts a ReferenceInfo instance to a protobuf Reference suitable for serialization.
func ReferenceInfoToReferenceExpr(info *ReferenceInfo) (*exprpb.Reference, error) {
c, err := ValToConstant(info.Value)
if err != nil {
return nil, err
}
return &exprpb.Reference{
Name: info.Name,
OverloadId: info.OverloadIDs,
Value: c,
}, nil
}
// ReferenceExprToReferenceInfo converts a protobuf Reference into a CEL-native ReferenceInfo instance.
func ReferenceExprToReferenceInfo(ref *exprpb.Reference) (*ReferenceInfo, error) {
v, err := ConstantToVal(ref.GetValue())
if err != nil {
return nil, err
}
return &ReferenceInfo{
Name: ref.GetName(),
OverloadIDs: ref.GetOverloadId(),
Value: v,
}, nil
}
// ValToConstant converts a CEL-native ref.Val to a protobuf Constant.
//
// Only simple scalar types are supported by this method.
func ValToConstant(v ref.Val) (*exprpb.Constant, error) {
if v == nil {
return nil, nil
}
switch v.Type() {
case types.BoolType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: v.Value().(bool)}}, nil
case types.BytesType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: v.Value().([]byte)}}, nil
case types.DoubleType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: v.Value().(float64)}}, nil
case types.IntType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: v.Value().(int64)}}, nil
case types.NullType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: structpb.NullValue_NULL_VALUE}}, nil
case types.StringType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: v.Value().(string)}}, nil
case types.UintType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: v.Value().(uint64)}}, nil
}
return nil, fmt.Errorf("unsupported constant kind: %v", v.Type())
}
// ConstantToVal converts a protobuf Constant to a CEL-native ref.Val.
func ConstantToVal(c *exprpb.Constant) (ref.Val, error) {
if c == nil {
return nil, nil
}
switch c.GetConstantKind().(type) {
case *exprpb.Constant_BoolValue:
return types.Bool(c.GetBoolValue()), nil
case *exprpb.Constant_BytesValue:
return types.Bytes(c.GetBytesValue()), nil
case *exprpb.Constant_DoubleValue:
return types.Double(c.GetDoubleValue()), nil
case *exprpb.Constant_Int64Value:
return types.Int(c.GetInt64Value()), nil
case *exprpb.Constant_NullValue:
return types.NullValue, nil
case *exprpb.Constant_StringValue:
return types.String(c.GetStringValue()), nil
case *exprpb.Constant_Uint64Value:
return types.Uint(c.GetUint64Value()), nil
}
return nil, fmt.Errorf("unsupported constant kind: %v", c.GetConstantKind())
}

709
vendor/github.com/google/cel-go/common/ast/expr.go generated vendored Normal file
View File

@ -0,0 +1,709 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ast
import (
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// ExprKind represents the expression node kind.
type ExprKind int
const (
// UnspecifiedKind represents an unset expression with no specified properties.
UnspecifiedKind ExprKind = iota
// LiteralKind represents a primitive scalar literal.
LiteralKind
// IdentKind represents a simple variable, constant, or type identifier.
IdentKind
// SelectKind represents a field selection expression.
SelectKind
// CallKind represents a function call.
CallKind
// ListKind represents a list literal expression.
ListKind
// MapKind represents a map literal expression.
MapKind
// StructKind represents a struct literal expression.
StructKind
// ComprehensionKind represents a comprehension expression generated by a macro.
ComprehensionKind
)
// NavigateCheckedAST converts a CheckedAST to a NavigableExpr
func NavigateCheckedAST(ast *CheckedAST) NavigableExpr {
return newNavigableExpr(nil, ast.Expr, ast.TypeMap)
}
// ExprMatcher takes a NavigableExpr in and indicates whether the value is a match.
//
// This function type should be use with the `Match` and `MatchList` calls.
type ExprMatcher func(NavigableExpr) bool
// ConstantValueMatcher returns an ExprMatcher which will return true if the input NavigableExpr
// is comprised of all constant values, such as a simple literal or even list and map literal.
func ConstantValueMatcher() ExprMatcher {
return matchIsConstantValue
}
// KindMatcher returns an ExprMatcher which will return true if the input NavigableExpr.Kind() matches
// the specified `kind`.
func KindMatcher(kind ExprKind) ExprMatcher {
return func(e NavigableExpr) bool {
return e.Kind() == kind
}
}
// FunctionMatcher returns an ExprMatcher which will match NavigableExpr nodes of CallKind type whose
// function name is equal to `funcName`.
func FunctionMatcher(funcName string) ExprMatcher {
return func(e NavigableExpr) bool {
if e.Kind() != CallKind {
return false
}
return e.AsCall().FunctionName() == funcName
}
}
// AllMatcher returns true for all descendants of a NavigableExpr, effectively flattening them into a list.
//
// Such a result would work well with subsequent MatchList calls.
func AllMatcher() ExprMatcher {
return func(NavigableExpr) bool {
return true
}
}
// MatchDescendants takes a NavigableExpr and ExprMatcher and produces a list of NavigableExpr values of the
// descendants which match.
func MatchDescendants(expr NavigableExpr, matcher ExprMatcher) []NavigableExpr {
return matchListInternal([]NavigableExpr{expr}, matcher, true)
}
// MatchSubset applies an ExprMatcher to a list of NavigableExpr values and their descendants, producing a
// subset of NavigableExpr values which match.
func MatchSubset(exprs []NavigableExpr, matcher ExprMatcher) []NavigableExpr {
visit := make([]NavigableExpr, len(exprs))
copy(visit, exprs)
return matchListInternal(visit, matcher, false)
}
func matchListInternal(visit []NavigableExpr, matcher ExprMatcher, visitDescendants bool) []NavigableExpr {
var matched []NavigableExpr
for len(visit) != 0 {
e := visit[0]
if matcher(e) {
matched = append(matched, e)
}
if visitDescendants {
visit = append(visit[1:], e.Children()...)
} else {
visit = visit[1:]
}
}
return matched
}
func matchIsConstantValue(e NavigableExpr) bool {
if e.Kind() == LiteralKind {
return true
}
if e.Kind() == StructKind || e.Kind() == MapKind || e.Kind() == ListKind {
for _, child := range e.Children() {
if !matchIsConstantValue(child) {
return false
}
}
return true
}
return false
}
// NavigableExpr represents the base navigable expression value.
//
// Depending on the `Kind()` value, the NavigableExpr may be converted to a concrete expression types
// as indicated by the `As<Kind>` methods.
//
// NavigableExpr values and their concrete expression types should be nil-safe. Conversion of an expr
// to the wrong kind should produce a nil value.
type NavigableExpr interface {
// ID of the expression as it appears in the AST
ID() int64
// Kind of the expression node. See ExprKind for the valid enum values.
Kind() ExprKind
// Type of the expression node.
Type() *types.Type
// Parent returns the parent expression node, if one exists.
Parent() (NavigableExpr, bool)
// Children returns a list of child expression nodes.
Children() []NavigableExpr
// ToExpr adapts this NavigableExpr to a protobuf representation.
ToExpr() *exprpb.Expr
// AsCall adapts the expr into a NavigableCallExpr
//
// The Kind() must be equal to a CallKind for the conversion to be well-defined.
AsCall() NavigableCallExpr
// AsComprehension adapts the expr into a NavigableComprehensionExpr.
//
// The Kind() must be equal to a ComprehensionKind for the conversion to be well-defined.
AsComprehension() NavigableComprehensionExpr
// AsIdent adapts the expr into an identifier string.
//
// The Kind() must be equal to an IdentKind for the conversion to be well-defined.
AsIdent() string
// AsLiteral adapts the expr into a constant ref.Val.
//
// The Kind() must be equal to a LiteralKind for the conversion to be well-defined.
AsLiteral() ref.Val
// AsList adapts the expr into a NavigableListExpr.
//
// The Kind() must be equal to a ListKind for the conversion to be well-defined.
AsList() NavigableListExpr
// AsMap adapts the expr into a NavigableMapExpr.
//
// The Kind() must be equal to a MapKind for the conversion to be well-defined.
AsMap() NavigableMapExpr
// AsSelect adapts the expr into a NavigableSelectExpr.
//
// The Kind() must be equal to a SelectKind for the conversion to be well-defined.
AsSelect() NavigableSelectExpr
// AsStruct adapts the expr into a NavigableStructExpr.
//
// The Kind() must be equal to a StructKind for the conversion to be well-defined.
AsStruct() NavigableStructExpr
// marker interface method
isNavigable()
}
// NavigableCallExpr defines an interface for inspecting a function call and its arugments.
type NavigableCallExpr interface {
// FunctionName returns the name of the function.
FunctionName() string
// Target returns the target of the expression if one is present.
Target() NavigableExpr
// Args returns the list of call arguments, excluding the target.
Args() []NavigableExpr
// ReturnType returns the result type of the call.
ReturnType() *types.Type
// marker interface method
isNavigable()
}
// NavigableListExpr defines an interface for inspecting a list literal expression.
type NavigableListExpr interface {
// Elements returns the list elements as navigable expressions.
Elements() []NavigableExpr
// OptionalIndicies returns the list of optional indices in the list literal.
OptionalIndices() []int32
// Size returns the number of elements in the list.
Size() int
// marker interface method
isNavigable()
}
// NavigableSelectExpr defines an interface for inspecting a select expression.
type NavigableSelectExpr interface {
// Operand returns the selection operand expression.
Operand() NavigableExpr
// FieldName returns the field name being selected from the operand.
FieldName() string
// IsTestOnly indicates whether the select expression is a presence test generated by a macro.
IsTestOnly() bool
// marker interface method
isNavigable()
}
// NavigableMapExpr defines an interface for inspecting a map expression.
type NavigableMapExpr interface {
// Entries returns the map key value pairs as NavigableEntry values.
Entries() []NavigableEntry
// Size returns the number of entries in the map.
Size() int
// marker interface method
isNavigable()
}
// NavigableEntry defines an interface for inspecting a map entry.
type NavigableEntry interface {
// Key returns the map entry key expression.
Key() NavigableExpr
// Value returns the map entry value expression.
Value() NavigableExpr
// IsOptional returns whether the entry is optional.
IsOptional() bool
// marker interface method
isNavigable()
}
// NavigableStructExpr defines an interfaces for inspecting a struct and its field initializers.
type NavigableStructExpr interface {
// TypeName returns the struct type name.
TypeName() string
// Fields returns the set of field initializers in the struct expression as NavigableField values.
Fields() []NavigableField
// marker interface method
isNavigable()
}
// NavigableField defines an interface for inspecting a struct field initialization.
type NavigableField interface {
// FieldName returns the name of the field.
FieldName() string
// Value returns the field initialization expression.
Value() NavigableExpr
// IsOptional returns whether the field is optional.
IsOptional() bool
// marker interface method
isNavigable()
}
// NavigableComprehensionExpr defines an interface for inspecting a comprehension expression.
type NavigableComprehensionExpr interface {
// IterRange returns the iteration range expression.
IterRange() NavigableExpr
// IterVar returns the iteration variable name.
IterVar() string
// AccuVar returns the accumulation variable name.
AccuVar() string
// AccuInit returns the accumulation variable initialization expression.
AccuInit() NavigableExpr
// LoopCondition returns the loop condition expression.
LoopCondition() NavigableExpr
// LoopStep returns the loop step expression.
LoopStep() NavigableExpr
// Result returns the comprehension result expression.
Result() NavigableExpr
// marker interface method
isNavigable()
}
func newNavigableExpr(parent NavigableExpr, expr *exprpb.Expr, typeMap map[int64]*types.Type) NavigableExpr {
kind, factory := kindOf(expr)
nav := &navigableExprImpl{
parent: parent,
kind: kind,
expr: expr,
typeMap: typeMap,
createChildren: factory,
}
return nav
}
type navigableExprImpl struct {
parent NavigableExpr
kind ExprKind
expr *exprpb.Expr
typeMap map[int64]*types.Type
createChildren childFactory
}
func (nav *navigableExprImpl) ID() int64 {
return nav.ToExpr().GetId()
}
func (nav *navigableExprImpl) Kind() ExprKind {
return nav.kind
}
func (nav *navigableExprImpl) Type() *types.Type {
if t, found := nav.typeMap[nav.ID()]; found {
return t
}
return types.DynType
}
func (nav *navigableExprImpl) Parent() (NavigableExpr, bool) {
if nav.parent != nil {
return nav.parent, true
}
return nil, false
}
func (nav *navigableExprImpl) Children() []NavigableExpr {
return nav.createChildren(nav)
}
func (nav *navigableExprImpl) ToExpr() *exprpb.Expr {
return nav.expr
}
func (nav *navigableExprImpl) AsCall() NavigableCallExpr {
return navigableCallImpl{navigableExprImpl: nav}
}
func (nav *navigableExprImpl) AsComprehension() NavigableComprehensionExpr {
return navigableComprehensionImpl{navigableExprImpl: nav}
}
func (nav *navigableExprImpl) AsIdent() string {
return nav.ToExpr().GetIdentExpr().GetName()
}
func (nav *navigableExprImpl) AsLiteral() ref.Val {
if nav.Kind() != LiteralKind {
return nil
}
val, err := ConstantToVal(nav.ToExpr().GetConstExpr())
if err != nil {
panic(err)
}
return val
}
func (nav *navigableExprImpl) AsList() NavigableListExpr {
return navigableListImpl{navigableExprImpl: nav}
}
func (nav *navigableExprImpl) AsMap() NavigableMapExpr {
return navigableMapImpl{navigableExprImpl: nav}
}
func (nav *navigableExprImpl) AsSelect() NavigableSelectExpr {
return navigableSelectImpl{navigableExprImpl: nav}
}
func (nav *navigableExprImpl) AsStruct() NavigableStructExpr {
return navigableStructImpl{navigableExprImpl: nav}
}
func (nav *navigableExprImpl) createChild(e *exprpb.Expr) NavigableExpr {
return newNavigableExpr(nav, e, nav.typeMap)
}
func (nav *navigableExprImpl) isNavigable() {}
type navigableCallImpl struct {
*navigableExprImpl
}
func (call navigableCallImpl) FunctionName() string {
return call.ToExpr().GetCallExpr().GetFunction()
}
func (call navigableCallImpl) Target() NavigableExpr {
t := call.ToExpr().GetCallExpr().GetTarget()
if t != nil {
return call.createChild(t)
}
return nil
}
func (call navigableCallImpl) Args() []NavigableExpr {
args := call.ToExpr().GetCallExpr().GetArgs()
navArgs := make([]NavigableExpr, len(args))
for i, a := range args {
navArgs[i] = call.createChild(a)
}
return navArgs
}
func (call navigableCallImpl) ReturnType() *types.Type {
return call.Type()
}
type navigableComprehensionImpl struct {
*navigableExprImpl
}
func (comp navigableComprehensionImpl) IterRange() NavigableExpr {
return comp.createChild(comp.ToExpr().GetComprehensionExpr().GetIterRange())
}
func (comp navigableComprehensionImpl) IterVar() string {
return comp.ToExpr().GetComprehensionExpr().GetIterVar()
}
func (comp navigableComprehensionImpl) AccuVar() string {
return comp.ToExpr().GetComprehensionExpr().GetAccuVar()
}
func (comp navigableComprehensionImpl) AccuInit() NavigableExpr {
return comp.createChild(comp.ToExpr().GetComprehensionExpr().GetAccuInit())
}
func (comp navigableComprehensionImpl) LoopCondition() NavigableExpr {
return comp.createChild(comp.ToExpr().GetComprehensionExpr().GetLoopCondition())
}
func (comp navigableComprehensionImpl) LoopStep() NavigableExpr {
return comp.createChild(comp.ToExpr().GetComprehensionExpr().GetLoopStep())
}
func (comp navigableComprehensionImpl) Result() NavigableExpr {
return comp.createChild(comp.ToExpr().GetComprehensionExpr().GetResult())
}
type navigableListImpl struct {
*navigableExprImpl
}
func (l navigableListImpl) Elements() []NavigableExpr {
return l.Children()
}
func (l navigableListImpl) OptionalIndices() []int32 {
return l.ToExpr().GetListExpr().GetOptionalIndices()
}
func (l navigableListImpl) Size() int {
return len(l.ToExpr().GetListExpr().GetElements())
}
type navigableMapImpl struct {
*navigableExprImpl
}
func (m navigableMapImpl) Entries() []NavigableEntry {
mapExpr := m.ToExpr().GetStructExpr()
entries := make([]NavigableEntry, len(mapExpr.GetEntries()))
for i, e := range mapExpr.GetEntries() {
entries[i] = navigableEntryImpl{
key: m.createChild(e.GetMapKey()),
val: m.createChild(e.GetValue()),
isOpt: e.GetOptionalEntry(),
}
}
return entries
}
func (m navigableMapImpl) Size() int {
return len(m.ToExpr().GetStructExpr().GetEntries())
}
type navigableEntryImpl struct {
key NavigableExpr
val NavigableExpr
isOpt bool
}
func (e navigableEntryImpl) Key() NavigableExpr {
return e.key
}
func (e navigableEntryImpl) Value() NavigableExpr {
return e.val
}
func (e navigableEntryImpl) IsOptional() bool {
return e.isOpt
}
func (e navigableEntryImpl) isNavigable() {}
type navigableSelectImpl struct {
*navigableExprImpl
}
func (sel navigableSelectImpl) FieldName() string {
return sel.ToExpr().GetSelectExpr().GetField()
}
func (sel navigableSelectImpl) IsTestOnly() bool {
return sel.ToExpr().GetSelectExpr().GetTestOnly()
}
func (sel navigableSelectImpl) Operand() NavigableExpr {
return sel.createChild(sel.ToExpr().GetSelectExpr().GetOperand())
}
type navigableStructImpl struct {
*navigableExprImpl
}
func (s navigableStructImpl) TypeName() string {
return s.ToExpr().GetStructExpr().GetMessageName()
}
func (s navigableStructImpl) Fields() []NavigableField {
fieldInits := s.ToExpr().GetStructExpr().GetEntries()
fields := make([]NavigableField, len(fieldInits))
for i, f := range fieldInits {
fields[i] = navigableFieldImpl{
name: f.GetFieldKey(),
val: s.createChild(f.GetValue()),
isOpt: f.GetOptionalEntry(),
}
}
return fields
}
type navigableFieldImpl struct {
name string
val NavigableExpr
isOpt bool
}
func (f navigableFieldImpl) FieldName() string {
return f.name
}
func (f navigableFieldImpl) Value() NavigableExpr {
return f.val
}
func (f navigableFieldImpl) IsOptional() bool {
return f.isOpt
}
func (f navigableFieldImpl) isNavigable() {}
func kindOf(expr *exprpb.Expr) (ExprKind, childFactory) {
switch expr.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
return LiteralKind, noopFactory
case *exprpb.Expr_IdentExpr:
return IdentKind, noopFactory
case *exprpb.Expr_SelectExpr:
return SelectKind, selectFactory
case *exprpb.Expr_CallExpr:
return CallKind, callArgFactory
case *exprpb.Expr_ListExpr:
return ListKind, listElemFactory
case *exprpb.Expr_StructExpr:
if expr.GetStructExpr().GetMessageName() != "" {
return StructKind, structEntryFactory
}
return MapKind, mapEntryFactory
case *exprpb.Expr_ComprehensionExpr:
return ComprehensionKind, comprehensionFactory
default:
return UnspecifiedKind, noopFactory
}
}
type childFactory func(*navigableExprImpl) []NavigableExpr
func noopFactory(*navigableExprImpl) []NavigableExpr {
return nil
}
func selectFactory(nav *navigableExprImpl) []NavigableExpr {
return []NavigableExpr{
nav.createChild(nav.ToExpr().GetSelectExpr().GetOperand()),
}
}
func callArgFactory(nav *navigableExprImpl) []NavigableExpr {
call := nav.ToExpr().GetCallExpr()
argCount := len(call.GetArgs())
if call.GetTarget() != nil {
argCount++
}
navExprs := make([]NavigableExpr, argCount)
i := 0
if call.GetTarget() != nil {
navExprs[i] = nav.createChild(call.GetTarget())
i++
}
for _, arg := range call.GetArgs() {
navExprs[i] = nav.createChild(arg)
i++
}
return navExprs
}
func listElemFactory(nav *navigableExprImpl) []NavigableExpr {
l := nav.ToExpr().GetListExpr()
navExprs := make([]NavigableExpr, len(l.GetElements()))
for i, e := range l.GetElements() {
navExprs[i] = nav.createChild(e)
}
return navExprs
}
func structEntryFactory(nav *navigableExprImpl) []NavigableExpr {
s := nav.ToExpr().GetStructExpr()
entries := make([]NavigableExpr, len(s.GetEntries()))
for i, e := range s.GetEntries() {
entries[i] = nav.createChild(e.GetValue())
}
return entries
}
func mapEntryFactory(nav *navigableExprImpl) []NavigableExpr {
s := nav.ToExpr().GetStructExpr()
entries := make([]NavigableExpr, len(s.GetEntries())*2)
j := 0
for _, e := range s.GetEntries() {
entries[j] = nav.createChild(e.GetMapKey())
entries[j+1] = nav.createChild(e.GetValue())
j += 2
}
return entries
}
func comprehensionFactory(nav *navigableExprImpl) []NavigableExpr {
compre := nav.ToExpr().GetComprehensionExpr()
return []NavigableExpr{
nav.createChild(compre.GetIterRange()),
nav.createChild(compre.GetAccuInit()),
nav.createChild(compre.GetLoopCondition()),
nav.createChild(compre.GetLoopStep()),
nav.createChild(compre.GetResult()),
}
}

View File

@ -0,0 +1,39 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
go_library(
name = "go_default_library",
srcs = [
"decls.go",
],
importpath = "github.com/google/cel-go/common/decls",
deps = [
"//checker/decls:go_default_library",
"//common/functions:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
],
)
go_test(
name = "go_default_test",
srcs = [
"decls_test.go",
],
embed = [":go_default_library"],
deps = [
"//checker/decls:go_default_library",
"//common/overloads:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
],
)

844
vendor/github.com/google/cel-go/common/decls/decls.go generated vendored Normal file
View File

@ -0,0 +1,844 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package decls contains function and variable declaration structs and helper methods.
package decls
import (
"fmt"
"strings"
chkdecls "github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/functions"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// NewFunction creates a new function declaration with a set of function options to configure overloads
// and function definitions (implementations).
//
// Functions are checked for name collisions and singleton redefinition.
func NewFunction(name string, opts ...FunctionOpt) (*FunctionDecl, error) {
fn := &FunctionDecl{
name: name,
overloads: map[string]*OverloadDecl{},
overloadOrdinals: []string{},
}
var err error
for _, opt := range opts {
fn, err = opt(fn)
if err != nil {
return nil, err
}
}
if len(fn.overloads) == 0 {
return nil, fmt.Errorf("function %s must have at least one overload", name)
}
return fn, nil
}
// FunctionDecl defines a function name, overload set, and optionally a singleton definition for all
// overload instances.
type FunctionDecl struct {
name string
// overloads associated with the function name.
overloads map[string]*OverloadDecl
// singleton implementation of the function for all overloads.
//
// If this option is set, an error will occur if any overloads specify a per-overload implementation
// or if another function with the same name attempts to redefine the singleton.
singleton *functions.Overload
// disableTypeGuards is a performance optimization to disable detailed runtime type checks which could
// add overhead on common operations. Setting this option true leaves error checks and argument checks
// intact.
disableTypeGuards bool
// state indicates that the binding should be provided as a declaration, as a runtime binding, or both.
state declarationState
// overloadOrdinals indicates the order in which the overload was declared.
overloadOrdinals []string
}
type declarationState int
const (
declarationStateUnset declarationState = iota
declarationDisabled
declarationEnabled
)
// Name returns the function name in human-readable terms, e.g. 'contains' of 'math.least'
func (f *FunctionDecl) Name() string {
if f == nil {
return ""
}
return f.name
}
// IsDeclarationDisabled indicates that the function implementation should be added to the dispatcher, but the
// declaration should not be exposed for use in expressions.
func (f *FunctionDecl) IsDeclarationDisabled() bool {
return f.state == declarationDisabled
}
// Merge combines an existing function declaration with another.
//
// If a function is extended, by say adding new overloads to an existing function, then it is merged with the
// prior definition of the function at which point its overloads must not collide with pre-existing overloads
// and its bindings (singleton, or per-overload) must not conflict with previous definitions either.
func (f *FunctionDecl) Merge(other *FunctionDecl) (*FunctionDecl, error) {
if f == other {
return f, nil
}
if f.Name() != other.Name() {
return nil, fmt.Errorf("cannot merge unrelated functions. %s and %s", f.Name(), other.Name())
}
merged := &FunctionDecl{
name: f.Name(),
overloads: make(map[string]*OverloadDecl, len(f.overloads)),
singleton: f.singleton,
overloadOrdinals: make([]string, len(f.overloads)),
// if one function is expecting type-guards and the other is not, then they
// must not be disabled.
disableTypeGuards: f.disableTypeGuards && other.disableTypeGuards,
// default to the current functions declaration state.
state: f.state,
}
// If the other state indicates that the declaration should be explicitly enabled or
// disabled, then update the merged state with the most recent value.
if other.state != declarationStateUnset {
merged.state = other.state
}
// baseline copy of the overloads and their ordinals
copy(merged.overloadOrdinals, f.overloadOrdinals)
for oID, o := range f.overloads {
merged.overloads[oID] = o
}
// overloads and their ordinals are added from the left
for _, oID := range other.overloadOrdinals {
o := other.overloads[oID]
err := merged.AddOverload(o)
if err != nil {
return nil, fmt.Errorf("function declaration merge failed: %v", err)
}
}
if other.singleton != nil {
if merged.singleton != nil && merged.singleton != other.singleton {
return nil, fmt.Errorf("function already has a singleton binding: %s", f.Name())
}
merged.singleton = other.singleton
}
return merged, nil
}
// AddOverload ensures that the new overload does not collide with an existing overload signature;
// however, if the function signatures are identical, the implementation may be rewritten as its
// difficult to compare functions by object identity.
func (f *FunctionDecl) AddOverload(overload *OverloadDecl) error {
if f == nil {
return fmt.Errorf("nil function cannot add overload: %s", overload.ID())
}
for oID, o := range f.overloads {
if oID != overload.ID() && o.SignatureOverlaps(overload) {
return fmt.Errorf("overload signature collision in function %s: %s collides with %s", f.Name(), oID, overload.ID())
}
if oID == overload.ID() {
if o.SignatureEquals(overload) && o.IsNonStrict() == overload.IsNonStrict() {
// Allow redefinition of an overload implementation so long as the signatures match.
f.overloads[oID] = overload
return nil
}
return fmt.Errorf("overload redefinition in function. %s: %s has multiple definitions", f.Name(), oID)
}
}
f.overloadOrdinals = append(f.overloadOrdinals, overload.ID())
f.overloads[overload.ID()] = overload
return nil
}
// OverloadDecls returns the overload declarations in the order in which they were declared.
func (f *FunctionDecl) OverloadDecls() []*OverloadDecl {
if f == nil {
return []*OverloadDecl{}
}
overloads := make([]*OverloadDecl, 0, len(f.overloads))
for _, oID := range f.overloadOrdinals {
overloads = append(overloads, f.overloads[oID])
}
return overloads
}
// Bindings produces a set of function bindings, if any are defined.
func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) {
if f == nil {
return []*functions.Overload{}, nil
}
overloads := []*functions.Overload{}
nonStrict := false
for _, oID := range f.overloadOrdinals {
o := f.overloads[oID]
if o.hasBinding() {
overload := &functions.Overload{
Operator: o.ID(),
Unary: o.guardedUnaryOp(f.Name(), f.disableTypeGuards),
Binary: o.guardedBinaryOp(f.Name(), f.disableTypeGuards),
Function: o.guardedFunctionOp(f.Name(), f.disableTypeGuards),
OperandTrait: o.OperandTrait(),
NonStrict: o.IsNonStrict(),
}
overloads = append(overloads, overload)
nonStrict = nonStrict || o.IsNonStrict()
}
}
if f.singleton != nil {
if len(overloads) != 0 {
return nil, fmt.Errorf("singleton function incompatible with specialized overloads: %s", f.Name())
}
overloads = []*functions.Overload{
{
Operator: f.Name(),
Unary: f.singleton.Unary,
Binary: f.singleton.Binary,
Function: f.singleton.Function,
OperandTrait: f.singleton.OperandTrait,
},
}
// fall-through to return single overload case.
}
if len(overloads) == 0 {
return overloads, nil
}
// Single overload. Replicate an entry for it using the function name as well.
if len(overloads) == 1 {
if overloads[0].Operator == f.Name() {
return overloads, nil
}
return append(overloads, &functions.Overload{
Operator: f.Name(),
Unary: overloads[0].Unary,
Binary: overloads[0].Binary,
Function: overloads[0].Function,
NonStrict: overloads[0].NonStrict,
OperandTrait: overloads[0].OperandTrait,
}), nil
}
// All of the defined overloads are wrapped into a top-level function which
// performs dynamic dispatch to the proper overload based on the argument types.
bindings := append([]*functions.Overload{}, overloads...)
funcDispatch := func(args ...ref.Val) ref.Val {
for _, oID := range f.overloadOrdinals {
o := f.overloads[oID]
// During dynamic dispatch over multiple functions, signature agreement checks
// are preserved in order to assist with the function resolution step.
switch len(args) {
case 1:
if o.unaryOp != nil && o.matchesRuntimeSignature( /* disableTypeGuards=*/ false, args...) {
return o.unaryOp(args[0])
}
case 2:
if o.binaryOp != nil && o.matchesRuntimeSignature( /* disableTypeGuards=*/ false, args...) {
return o.binaryOp(args[0], args[1])
}
}
if o.functionOp != nil && o.matchesRuntimeSignature( /* disableTypeGuards=*/ false, args...) {
return o.functionOp(args...)
}
// eventually this will fall through to the noSuchOverload below.
}
return MaybeNoSuchOverload(f.Name(), args...)
}
function := &functions.Overload{
Operator: f.Name(),
Function: funcDispatch,
NonStrict: nonStrict,
}
return append(bindings, function), nil
}
// MaybeNoSuchOverload determines whether to propagate an error if one is provided as an argument, or
// to return an unknown set, or to produce a new error for a missing function signature.
func MaybeNoSuchOverload(funcName string, args ...ref.Val) ref.Val {
argTypes := make([]string, len(args))
var unk *types.Unknown = nil
for i, arg := range args {
if types.IsError(arg) {
return arg
}
if types.IsUnknown(arg) {
unk = types.MergeUnknowns(arg.(*types.Unknown), unk)
}
argTypes[i] = arg.Type().TypeName()
}
if unk != nil {
return unk
}
signature := strings.Join(argTypes, ", ")
return types.NewErr("no such overload: %s(%s)", funcName, signature)
}
// FunctionOpt defines a functional option for mutating a function declaration.
type FunctionOpt func(*FunctionDecl) (*FunctionDecl, error)
// DisableTypeGuards disables automatically generated function invocation guards on direct overload calls.
// Type guards remain on during dynamic dispatch for parsed-only expressions.
func DisableTypeGuards(value bool) FunctionOpt {
return func(fn *FunctionDecl) (*FunctionDecl, error) {
fn.disableTypeGuards = value
return fn, nil
}
}
// DisableDeclaration indicates that the function declaration should be disabled, but the runtime function
// binding should be provided. Marking a function as runtime-only is a safe way to manage deprecations
// of function declarations while still preserving the runtime behavior for previously compiled expressions.
func DisableDeclaration(value bool) FunctionOpt {
return func(fn *FunctionDecl) (*FunctionDecl, error) {
if value {
fn.state = declarationDisabled
} else {
fn.state = declarationEnabled
}
return fn, nil
}
}
// SingletonUnaryBinding creates a singleton function definition to be used for all function overloads.
//
// Note, this approach works well if operand is expected to have a specific trait which it implements,
// e.g. traits.ContainerType. Otherwise, prefer per-overload function bindings.
func SingletonUnaryBinding(fn functions.UnaryOp, traits ...int) FunctionOpt {
trait := 0
for _, t := range traits {
trait = trait | t
}
return func(f *FunctionDecl) (*FunctionDecl, error) {
if f.singleton != nil {
return nil, fmt.Errorf("function already has a singleton binding: %s", f.Name())
}
f.singleton = &functions.Overload{
Operator: f.Name(),
Unary: fn,
OperandTrait: trait,
}
return f, nil
}
}
// SingletonBinaryBinding creates a singleton function definition to be used with all function overloads.
//
// Note, this approach works well if operand is expected to have a specific trait which it implements,
// e.g. traits.ContainerType. Otherwise, prefer per-overload function bindings.
func SingletonBinaryBinding(fn functions.BinaryOp, traits ...int) FunctionOpt {
trait := 0
for _, t := range traits {
trait = trait | t
}
return func(f *FunctionDecl) (*FunctionDecl, error) {
if f.singleton != nil {
return nil, fmt.Errorf("function already has a singleton binding: %s", f.Name())
}
f.singleton = &functions.Overload{
Operator: f.Name(),
Binary: fn,
OperandTrait: trait,
}
return f, nil
}
}
// SingletonFunctionBinding creates a singleton function definition to be used with all function overloads.
//
// Note, this approach works well if operand is expected to have a specific trait which it implements,
// e.g. traits.ContainerType. Otherwise, prefer per-overload function bindings.
func SingletonFunctionBinding(fn functions.FunctionOp, traits ...int) FunctionOpt {
trait := 0
for _, t := range traits {
trait = trait | t
}
return func(f *FunctionDecl) (*FunctionDecl, error) {
if f.singleton != nil {
return nil, fmt.Errorf("function already has a singleton binding: %s", f.Name())
}
f.singleton = &functions.Overload{
Operator: f.Name(),
Function: fn,
OperandTrait: trait,
}
return f, nil
}
}
// Overload defines a new global overload with an overload id, argument types, and result type. Through the
// use of OverloadOpt options, the overload may also be configured with a binding, an operand trait, and to
// be non-strict.
//
// Note: function bindings should be commonly configured with Overload instances whereas operand traits and
// strict-ness should be rare occurrences.
func Overload(overloadID string,
args []*types.Type, resultType *types.Type,
opts ...OverloadOpt) FunctionOpt {
return newOverload(overloadID, false, args, resultType, opts...)
}
// MemberOverload defines a new receiver-style overload (or member function) with an overload id, argument types,
// and result type. Through the use of OverloadOpt options, the overload may also be configured with a binding,
// an operand trait, and to be non-strict.
//
// Note: function bindings should be commonly configured with Overload instances whereas operand traits and
// strict-ness should be rare occurrences.
func MemberOverload(overloadID string,
args []*types.Type, resultType *types.Type,
opts ...OverloadOpt) FunctionOpt {
return newOverload(overloadID, true, args, resultType, opts...)
}
func newOverload(overloadID string,
memberFunction bool, args []*types.Type, resultType *types.Type,
opts ...OverloadOpt) FunctionOpt {
return func(f *FunctionDecl) (*FunctionDecl, error) {
overload, err := newOverloadInternal(overloadID, memberFunction, args, resultType, opts...)
if err != nil {
return nil, err
}
err = f.AddOverload(overload)
if err != nil {
return nil, err
}
return f, nil
}
}
func newOverloadInternal(overloadID string,
memberFunction bool, args []*types.Type, resultType *types.Type,
opts ...OverloadOpt) (*OverloadDecl, error) {
overload := &OverloadDecl{
id: overloadID,
argTypes: args,
resultType: resultType,
isMemberFunction: memberFunction,
}
var err error
for _, opt := range opts {
overload, err = opt(overload)
if err != nil {
return nil, err
}
}
return overload, nil
}
// OverloadDecl contains the definition of a single overload id with a specific signature, and an optional
// implementation.
type OverloadDecl struct {
id string
argTypes []*types.Type
resultType *types.Type
isMemberFunction bool
// nonStrict indicates that the function will accept error and unknown arguments as inputs.
nonStrict bool
// operandTrait indicates whether the member argument should have a specific type-trait.
//
// This is useful for creating overloads which operate on a type-interface rather than a concrete type.
operandTrait int
// Function implementation options. Optional, but encouraged.
// unaryOp is a function binding that takes a single argument.
unaryOp functions.UnaryOp
// binaryOp is a function binding that takes two arguments.
binaryOp functions.BinaryOp
// functionOp is a catch-all for zero-arity and three-plus arity functions.
functionOp functions.FunctionOp
}
// ID mirrors the overload signature and provides a unique id which may be referenced within the type-checker
// and interpreter to optimize performance.
//
// The ID format is usually one of two styles:
// global: <functionName>_<argType>_<argTypeN>
// member: <memberType>_<functionName>_<argType>_<argTypeN>
func (o *OverloadDecl) ID() string {
if o == nil {
return ""
}
return o.id
}
// ArgTypes contains the set of argument types expected by the overload.
//
// For member functions ArgTypes[0] represents the member operand type.
func (o *OverloadDecl) ArgTypes() []*types.Type {
if o == nil {
return emptyArgs
}
return o.argTypes
}
// IsMemberFunction indicates whether the overload is a member function
func (o *OverloadDecl) IsMemberFunction() bool {
if o == nil {
return false
}
return o.isMemberFunction
}
// IsNonStrict returns whether the overload accepts errors and unknown values as arguments.
func (o *OverloadDecl) IsNonStrict() bool {
if o == nil {
return false
}
return o.nonStrict
}
// OperandTrait returns the trait mask of the first operand to the overload call, e.g.
// `traits.Indexer`
func (o *OverloadDecl) OperandTrait() int {
if o == nil {
return 0
}
return o.operandTrait
}
// ResultType indicates the output type from calling the function.
func (o *OverloadDecl) ResultType() *types.Type {
if o == nil {
// *types.Type is nil-safe
return nil
}
return o.resultType
}
// TypeParams returns the type parameter names associated with the overload.
func (o *OverloadDecl) TypeParams() []string {
typeParams := map[string]struct{}{}
collectParamNames(typeParams, o.ResultType())
for _, arg := range o.ArgTypes() {
collectParamNames(typeParams, arg)
}
params := make([]string, 0, len(typeParams))
for param := range typeParams {
params = append(params, param)
}
return params
}
// SignatureEquals determines whether the incoming overload declaration signature is equal to the current signature.
//
// Result type, operand trait, and strict-ness are not considered as part of signature equality.
func (o *OverloadDecl) SignatureEquals(other *OverloadDecl) bool {
if o == other {
return true
}
if o.ID() != other.ID() || o.IsMemberFunction() != other.IsMemberFunction() || len(o.ArgTypes()) != len(other.ArgTypes()) {
return false
}
for i, at := range o.ArgTypes() {
oat := other.ArgTypes()[i]
if !at.IsEquivalentType(oat) {
return false
}
}
return o.ResultType().IsEquivalentType(other.ResultType())
}
// SignatureOverlaps indicates whether two functions have non-equal, but overloapping function signatures.
//
// For example, list(dyn) collides with list(string) since the 'dyn' type can contain a 'string' type.
func (o *OverloadDecl) SignatureOverlaps(other *OverloadDecl) bool {
if o.IsMemberFunction() != other.IsMemberFunction() || len(o.ArgTypes()) != len(other.ArgTypes()) {
return false
}
argsOverlap := true
for i, argType := range o.ArgTypes() {
otherArgType := other.ArgTypes()[i]
argsOverlap = argsOverlap &&
(argType.IsAssignableType(otherArgType) ||
otherArgType.IsAssignableType(argType))
}
return argsOverlap
}
// hasBinding indicates whether the overload already has a definition.
func (o *OverloadDecl) hasBinding() bool {
return o != nil && (o.unaryOp != nil || o.binaryOp != nil || o.functionOp != nil)
}
// guardedUnaryOp creates an invocation guard around the provided unary operator, if one is defined.
func (o *OverloadDecl) guardedUnaryOp(funcName string, disableTypeGuards bool) functions.UnaryOp {
if o.unaryOp == nil {
return nil
}
return func(arg ref.Val) ref.Val {
if !o.matchesRuntimeUnarySignature(disableTypeGuards, arg) {
return MaybeNoSuchOverload(funcName, arg)
}
return o.unaryOp(arg)
}
}
// guardedBinaryOp creates an invocation guard around the provided binary operator, if one is defined.
func (o *OverloadDecl) guardedBinaryOp(funcName string, disableTypeGuards bool) functions.BinaryOp {
if o.binaryOp == nil {
return nil
}
return func(arg1, arg2 ref.Val) ref.Val {
if !o.matchesRuntimeBinarySignature(disableTypeGuards, arg1, arg2) {
return MaybeNoSuchOverload(funcName, arg1, arg2)
}
return o.binaryOp(arg1, arg2)
}
}
// guardedFunctionOp creates an invocation guard around the provided variadic function binding, if one is provided.
func (o *OverloadDecl) guardedFunctionOp(funcName string, disableTypeGuards bool) functions.FunctionOp {
if o.functionOp == nil {
return nil
}
return func(args ...ref.Val) ref.Val {
if !o.matchesRuntimeSignature(disableTypeGuards, args...) {
return MaybeNoSuchOverload(funcName, args...)
}
return o.functionOp(args...)
}
}
// matchesRuntimeUnarySignature indicates whether the argument type is runtime assiganble to the overload's expected argument.
func (o *OverloadDecl) matchesRuntimeUnarySignature(disableTypeGuards bool, arg ref.Val) bool {
return matchRuntimeArgType(o.IsNonStrict(), disableTypeGuards, o.ArgTypes()[0], arg) &&
matchOperandTrait(o.OperandTrait(), arg)
}
// matchesRuntimeBinarySignature indicates whether the argument types are runtime assiganble to the overload's expected arguments.
func (o *OverloadDecl) matchesRuntimeBinarySignature(disableTypeGuards bool, arg1, arg2 ref.Val) bool {
return matchRuntimeArgType(o.IsNonStrict(), disableTypeGuards, o.ArgTypes()[0], arg1) &&
matchRuntimeArgType(o.IsNonStrict(), disableTypeGuards, o.ArgTypes()[1], arg2) &&
matchOperandTrait(o.OperandTrait(), arg1)
}
// matchesRuntimeSignature indicates whether the argument types are runtime assiganble to the overload's expected arguments.
func (o *OverloadDecl) matchesRuntimeSignature(disableTypeGuards bool, args ...ref.Val) bool {
if len(args) != len(o.ArgTypes()) {
return false
}
if len(args) == 0 {
return true
}
for i, arg := range args {
if !matchRuntimeArgType(o.IsNonStrict(), disableTypeGuards, o.ArgTypes()[i], arg) {
return false
}
}
return matchOperandTrait(o.OperandTrait(), args[0])
}
func matchRuntimeArgType(nonStrict, disableTypeGuards bool, argType *types.Type, arg ref.Val) bool {
if nonStrict && (disableTypeGuards || types.IsUnknownOrError(arg)) {
return true
}
if types.IsUnknownOrError(arg) {
return false
}
return disableTypeGuards || argType.IsAssignableRuntimeType(arg)
}
func matchOperandTrait(trait int, arg ref.Val) bool {
return trait == 0 || arg.Type().HasTrait(trait) || types.IsUnknownOrError(arg)
}
// OverloadOpt is a functional option for configuring a function overload.
type OverloadOpt func(*OverloadDecl) (*OverloadDecl, error)
// UnaryBinding provides the implementation of a unary overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func UnaryBinding(binding functions.UnaryOp) OverloadOpt {
return func(o *OverloadDecl) (*OverloadDecl, error) {
if o.hasBinding() {
return nil, fmt.Errorf("overload already has a binding: %s", o.ID())
}
if len(o.ArgTypes()) != 1 {
return nil, fmt.Errorf("unary function bound to non-unary overload: %s", o.ID())
}
o.unaryOp = binding
return o, nil
}
}
// BinaryBinding provides the implementation of a binary overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func BinaryBinding(binding functions.BinaryOp) OverloadOpt {
return func(o *OverloadDecl) (*OverloadDecl, error) {
if o.hasBinding() {
return nil, fmt.Errorf("overload already has a binding: %s", o.ID())
}
if len(o.ArgTypes()) != 2 {
return nil, fmt.Errorf("binary function bound to non-binary overload: %s", o.ID())
}
o.binaryOp = binding
return o, nil
}
}
// FunctionBinding provides the implementation of a variadic overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func FunctionBinding(binding functions.FunctionOp) OverloadOpt {
return func(o *OverloadDecl) (*OverloadDecl, error) {
if o.hasBinding() {
return nil, fmt.Errorf("overload already has a binding: %s", o.ID())
}
o.functionOp = binding
return o, nil
}
}
// OverloadIsNonStrict enables the function to be called with error and unknown argument values.
//
// Note: do not use this option unless absoluately necessary as it should be an uncommon feature.
func OverloadIsNonStrict() OverloadOpt {
return func(o *OverloadDecl) (*OverloadDecl, error) {
o.nonStrict = true
return o, nil
}
}
// OverloadOperandTrait configures a set of traits which the first argument to the overload must implement in order to be
// successfully invoked.
func OverloadOperandTrait(trait int) OverloadOpt {
return func(o *OverloadDecl) (*OverloadDecl, error) {
o.operandTrait = trait
return o, nil
}
}
// NewConstant creates a new constant declaration.
func NewConstant(name string, t *types.Type, v ref.Val) *VariableDecl {
return &VariableDecl{name: name, varType: t, value: v}
}
// NewVariable creates a new variable declaration.
func NewVariable(name string, t *types.Type) *VariableDecl {
return &VariableDecl{name: name, varType: t}
}
// VariableDecl defines a variable declaration which may optionally have a constant value.
type VariableDecl struct {
name string
varType *types.Type
value ref.Val
}
// Name returns the fully-qualified variable name
func (v *VariableDecl) Name() string {
if v == nil {
return ""
}
return v.name
}
// Type returns the types.Type value associated with the variable.
func (v *VariableDecl) Type() *types.Type {
if v == nil {
// types.Type is nil-safe
return nil
}
return v.varType
}
// Value returns the constant value associated with the declaration.
func (v *VariableDecl) Value() ref.Val {
if v == nil {
return nil
}
return v.value
}
// DeclarationIsEquivalent returns true if one variable declaration has the same name and same type as the input.
func (v *VariableDecl) DeclarationIsEquivalent(other *VariableDecl) bool {
if v == other {
return true
}
return v.Name() == other.Name() && v.Type().IsEquivalentType(other.Type())
}
// VariableDeclToExprDecl converts a go-native variable declaration into a protobuf-type variable declaration.
func VariableDeclToExprDecl(v *VariableDecl) (*exprpb.Decl, error) {
varType, err := types.TypeToExprType(v.Type())
if err != nil {
return nil, err
}
return chkdecls.NewVar(v.Name(), varType), nil
}
// TypeVariable creates a new type identifier for use within a types.Provider
func TypeVariable(t *types.Type) *VariableDecl {
return NewVariable(t.TypeName(), types.NewTypeTypeWithParam(t))
}
// FunctionDeclToExprDecl converts a go-native function declaration into a protobuf-typed function declaration.
func FunctionDeclToExprDecl(f *FunctionDecl) (*exprpb.Decl, error) {
overloads := make([]*exprpb.Decl_FunctionDecl_Overload, len(f.overloads))
for i, oID := range f.overloadOrdinals {
o := f.overloads[oID]
paramNames := map[string]struct{}{}
argTypes := make([]*exprpb.Type, len(o.ArgTypes()))
for j, a := range o.ArgTypes() {
collectParamNames(paramNames, a)
at, err := types.TypeToExprType(a)
if err != nil {
return nil, err
}
argTypes[j] = at
}
collectParamNames(paramNames, o.ResultType())
resultType, err := types.TypeToExprType(o.ResultType())
if err != nil {
return nil, err
}
if len(paramNames) == 0 {
if o.IsMemberFunction() {
overloads[i] = chkdecls.NewInstanceOverload(oID, argTypes, resultType)
} else {
overloads[i] = chkdecls.NewOverload(oID, argTypes, resultType)
}
} else {
params := []string{}
for pn := range paramNames {
params = append(params, pn)
}
if o.IsMemberFunction() {
overloads[i] = chkdecls.NewParameterizedInstanceOverload(oID, argTypes, resultType, params)
} else {
overloads[i] = chkdecls.NewParameterizedOverload(oID, argTypes, resultType, params)
}
}
}
return chkdecls.NewFunction(f.Name(), overloads...), nil
}
func collectParamNames(paramNames map[string]struct{}, arg *types.Type) {
if arg.Kind() == types.TypeParamKind {
paramNames[arg.TypeName()] = struct{}{}
}
for _, param := range arg.Parameters() {
collectParamNames(paramNames, param)
}
}
var (
emptyArgs = []*types.Type{}
)

View File

@ -22,10 +22,16 @@ import (
"golang.org/x/text/width"
)
// Error type which references a location within source and a message.
// NewError creates an error associated with an expression id with the given message at the given location.
func NewError(id int64, message string, location Location) *Error {
return &Error{Message: message, Location: location, ExprID: id}
}
// Error type which references an expression id, a location within source, and a message.
type Error struct {
Location Location
Message string
ExprID int64
}
const (

View File

@ -22,7 +22,7 @@ import (
// Errors type which contains a list of errors observed during parsing.
type Errors struct {
errors []Error
errors []*Error
source Source
numErrors int
maxErrorsToReport int
@ -31,7 +31,7 @@ type Errors struct {
// NewErrors creates a new instance of the Errors type.
func NewErrors(source Source) *Errors {
return &Errors{
errors: []Error{},
errors: []*Error{},
source: source,
maxErrorsToReport: 100,
}
@ -39,11 +39,17 @@ func NewErrors(source Source) *Errors {
// ReportError records an error at a source location.
func (e *Errors) ReportError(l Location, format string, args ...any) {
e.ReportErrorAtID(0, l, format, args...)
}
// ReportErrorAtID records an error at a source location and expression id.
func (e *Errors) ReportErrorAtID(id int64, l Location, format string, args ...any) {
e.numErrors++
if e.numErrors > e.maxErrorsToReport {
return
}
err := Error{
err := &Error{
ExprID: id,
Location: l,
Message: fmt.Sprintf(format, args...),
}
@ -51,12 +57,12 @@ func (e *Errors) ReportError(l Location, format string, args ...any) {
}
// GetErrors returns the list of observed errors.
func (e *Errors) GetErrors() []Error {
func (e *Errors) GetErrors() []*Error {
return e.errors[:]
}
// Append creates a new Errors object with the current and input errors.
func (e *Errors) Append(errs []Error) *Errors {
func (e *Errors) Append(errs []*Error) *Errors {
return &Errors{
errors: append(e.errors, errs...),
source: e.source,

View File

@ -0,0 +1,17 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
go_library(
name = "go_default_library",
srcs = [
"functions.go",
],
importpath = "github.com/google/cel-go/common/functions",
deps = [
"//common/types/ref:go_default_library",
],
)

View File

@ -0,0 +1,61 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package functions defines the standard builtin functions supported by the interpreter
package functions
import "github.com/google/cel-go/common/types/ref"
// Overload defines a named overload of a function, indicating an operand trait
// which must be present on the first argument to the overload as well as one
// of either a unary, binary, or function implementation.
//
// The majority of operators within the expression language are unary or binary
// and the specializations simplify the call contract for implementers of
// types with operator overloads. Any added complexity is assumed to be handled
// by the generic FunctionOp.
type Overload struct {
// Operator name as written in an expression or defined within
// operators.go.
Operator string
// Operand trait used to dispatch the call. The zero-value indicates a
// global function overload or that one of the Unary / Binary / Function
// definitions should be used to execute the call.
OperandTrait int
// Unary defines the overload with a UnaryOp implementation. May be nil.
Unary UnaryOp
// Binary defines the overload with a BinaryOp implementation. May be nil.
Binary BinaryOp
// Function defines the overload with a FunctionOp implementation. May be
// nil.
Function FunctionOp
// NonStrict specifies whether the Overload will tolerate arguments that
// are types.Err or types.Unknown.
NonStrict bool
}
// UnaryOp is a function that takes a single value and produces an output.
type UnaryOp func(value ref.Val) ref.Val
// BinaryOp is a function that takes two values and produces an output.
type BinaryOp func(lhs ref.Val, rhs ref.Val) ref.Val
// FunctionOp is a function with accepts zero or more arguments and produces
// a value or error as a result.
type FunctionOp func(values ...ref.Val) ref.Val

View File

@ -64,7 +64,6 @@ type sourceImpl struct {
runes.Buffer
description string
lineOffsets []int32
idOffsets map[int64]int32
}
var _ runes.Buffer = &sourceImpl{}
@ -92,7 +91,6 @@ func NewStringSource(contents string, description string) Source {
Buffer: runes.NewBuffer(contents),
description: description,
lineOffsets: offsets,
idOffsets: map[int64]int32{},
}
}
@ -102,7 +100,6 @@ func NewInfoSource(info *exprpb.SourceInfo) Source {
Buffer: runes.NewBuffer(""),
description: info.GetLocation(),
lineOffsets: info.GetLineOffsets(),
idOffsets: info.GetPositions(),
}
}

View File

@ -0,0 +1,25 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
go_library(
name = "go_default_library",
srcs = [
"standard.go",
],
importpath = "github.com/google/cel-go/common/stdlib",
deps = [
"//checker/decls:go_default_library",
"//common/decls:go_default_library",
"//common/functions:go_default_library",
"//common/operators:go_default_library",
"//common/overloads:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
],
)

View File

@ -0,0 +1,661 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package stdlib contains all of the standard library function declarations and definitions for CEL.
package stdlib
import (
"github.com/google/cel-go/common/decls"
"github.com/google/cel-go/common/functions"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
var (
stdFunctions []*decls.FunctionDecl
stdFnDecls []*exprpb.Decl
stdTypes []*decls.VariableDecl
stdTypeDecls []*exprpb.Decl
)
func init() {
paramA := types.NewTypeParamType("A")
paramB := types.NewTypeParamType("B")
listOfA := types.NewListType(paramA)
mapOfAB := types.NewMapType(paramA, paramB)
stdTypes = []*decls.VariableDecl{
decls.TypeVariable(types.BoolType),
decls.TypeVariable(types.BytesType),
decls.TypeVariable(types.DoubleType),
decls.TypeVariable(types.DurationType),
decls.TypeVariable(types.IntType),
decls.TypeVariable(listOfA),
decls.TypeVariable(mapOfAB),
decls.TypeVariable(types.NullType),
decls.TypeVariable(types.StringType),
decls.TypeVariable(types.TimestampType),
decls.TypeVariable(types.TypeType),
decls.TypeVariable(types.UintType),
}
stdTypeDecls = make([]*exprpb.Decl, 0, len(stdTypes))
for _, stdType := range stdTypes {
typeVar, err := decls.VariableDeclToExprDecl(stdType)
if err != nil {
panic(err)
}
stdTypeDecls = append(stdTypeDecls, typeVar)
}
stdFunctions = []*decls.FunctionDecl{
// Logical operators. Special-cased within the interpreter.
// Note, the singleton binding prevents extensions from overriding the operator behavior.
function(operators.Conditional,
decls.Overload(overloads.Conditional, argTypes(types.BoolType, paramA, paramA), paramA,
decls.OverloadIsNonStrict()),
decls.SingletonFunctionBinding(noFunctionOverrides)),
function(operators.LogicalAnd,
decls.Overload(overloads.LogicalAnd, argTypes(types.BoolType, types.BoolType), types.BoolType,
decls.OverloadIsNonStrict()),
decls.SingletonBinaryBinding(noBinaryOverrides)),
function(operators.LogicalOr,
decls.Overload(overloads.LogicalOr, argTypes(types.BoolType, types.BoolType), types.BoolType,
decls.OverloadIsNonStrict()),
decls.SingletonBinaryBinding(noBinaryOverrides)),
function(operators.LogicalNot,
decls.Overload(overloads.LogicalNot, argTypes(types.BoolType), types.BoolType),
decls.SingletonUnaryBinding(func(val ref.Val) ref.Val {
b, ok := val.(types.Bool)
if !ok {
return types.MaybeNoSuchOverloadErr(val)
}
return b.Negate()
})),
// Comprehension short-circuiting related function
function(operators.NotStrictlyFalse,
decls.Overload(overloads.NotStrictlyFalse, argTypes(types.BoolType), types.BoolType,
decls.OverloadIsNonStrict(),
decls.UnaryBinding(notStrictlyFalse))),
// Deprecated: __not_strictly_false__
function(operators.OldNotStrictlyFalse,
decls.DisableDeclaration(true), // safe deprecation
decls.Overload(operators.OldNotStrictlyFalse, argTypes(types.BoolType), types.BoolType,
decls.OverloadIsNonStrict(),
decls.UnaryBinding(notStrictlyFalse))),
// Equality / inequality. Special-cased in the interpreter
function(operators.Equals,
decls.Overload(overloads.Equals, argTypes(paramA, paramA), types.BoolType),
decls.SingletonBinaryBinding(noBinaryOverrides)),
function(operators.NotEquals,
decls.Overload(overloads.NotEquals, argTypes(paramA, paramA), types.BoolType),
decls.SingletonBinaryBinding(noBinaryOverrides)),
// Mathematical operators
function(operators.Add,
decls.Overload(overloads.AddBytes,
argTypes(types.BytesType, types.BytesType), types.BytesType),
decls.Overload(overloads.AddDouble,
argTypes(types.DoubleType, types.DoubleType), types.DoubleType),
decls.Overload(overloads.AddDurationDuration,
argTypes(types.DurationType, types.DurationType), types.DurationType),
decls.Overload(overloads.AddDurationTimestamp,
argTypes(types.DurationType, types.TimestampType), types.TimestampType),
decls.Overload(overloads.AddTimestampDuration,
argTypes(types.TimestampType, types.DurationType), types.TimestampType),
decls.Overload(overloads.AddInt64,
argTypes(types.IntType, types.IntType), types.IntType),
decls.Overload(overloads.AddList,
argTypes(listOfA, listOfA), listOfA),
decls.Overload(overloads.AddString,
argTypes(types.StringType, types.StringType), types.StringType),
decls.Overload(overloads.AddUint64,
argTypes(types.UintType, types.UintType), types.UintType),
decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val {
return lhs.(traits.Adder).Add(rhs)
}, traits.AdderType)),
function(operators.Divide,
decls.Overload(overloads.DivideDouble,
argTypes(types.DoubleType, types.DoubleType), types.DoubleType),
decls.Overload(overloads.DivideInt64,
argTypes(types.IntType, types.IntType), types.IntType),
decls.Overload(overloads.DivideUint64,
argTypes(types.UintType, types.UintType), types.UintType),
decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val {
return lhs.(traits.Divider).Divide(rhs)
}, traits.DividerType)),
function(operators.Modulo,
decls.Overload(overloads.ModuloInt64,
argTypes(types.IntType, types.IntType), types.IntType),
decls.Overload(overloads.ModuloUint64,
argTypes(types.UintType, types.UintType), types.UintType),
decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val {
return lhs.(traits.Modder).Modulo(rhs)
}, traits.ModderType)),
function(operators.Multiply,
decls.Overload(overloads.MultiplyDouble,
argTypes(types.DoubleType, types.DoubleType), types.DoubleType),
decls.Overload(overloads.MultiplyInt64,
argTypes(types.IntType, types.IntType), types.IntType),
decls.Overload(overloads.MultiplyUint64,
argTypes(types.UintType, types.UintType), types.UintType),
decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val {
return lhs.(traits.Multiplier).Multiply(rhs)
}, traits.MultiplierType)),
function(operators.Negate,
decls.Overload(overloads.NegateDouble, argTypes(types.DoubleType), types.DoubleType),
decls.Overload(overloads.NegateInt64, argTypes(types.IntType), types.IntType),
decls.SingletonUnaryBinding(func(val ref.Val) ref.Val {
if types.IsBool(val) {
return types.MaybeNoSuchOverloadErr(val)
}
return val.(traits.Negater).Negate()
}, traits.NegatorType)),
function(operators.Subtract,
decls.Overload(overloads.SubtractDouble,
argTypes(types.DoubleType, types.DoubleType), types.DoubleType),
decls.Overload(overloads.SubtractDurationDuration,
argTypes(types.DurationType, types.DurationType), types.DurationType),
decls.Overload(overloads.SubtractInt64,
argTypes(types.IntType, types.IntType), types.IntType),
decls.Overload(overloads.SubtractTimestampDuration,
argTypes(types.TimestampType, types.DurationType), types.TimestampType),
decls.Overload(overloads.SubtractTimestampTimestamp,
argTypes(types.TimestampType, types.TimestampType), types.DurationType),
decls.Overload(overloads.SubtractUint64,
argTypes(types.UintType, types.UintType), types.UintType),
decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val {
return lhs.(traits.Subtractor).Subtract(rhs)
}, traits.SubtractorType)),
// Relations operators
function(operators.Less,
decls.Overload(overloads.LessBool,
argTypes(types.BoolType, types.BoolType), types.BoolType),
decls.Overload(overloads.LessInt64,
argTypes(types.IntType, types.IntType), types.BoolType),
decls.Overload(overloads.LessInt64Double,
argTypes(types.IntType, types.DoubleType), types.BoolType),
decls.Overload(overloads.LessInt64Uint64,
argTypes(types.IntType, types.UintType), types.BoolType),
decls.Overload(overloads.LessUint64,
argTypes(types.UintType, types.UintType), types.BoolType),
decls.Overload(overloads.LessUint64Double,
argTypes(types.UintType, types.DoubleType), types.BoolType),
decls.Overload(overloads.LessUint64Int64,
argTypes(types.UintType, types.IntType), types.BoolType),
decls.Overload(overloads.LessDouble,
argTypes(types.DoubleType, types.DoubleType), types.BoolType),
decls.Overload(overloads.LessDoubleInt64,
argTypes(types.DoubleType, types.IntType), types.BoolType),
decls.Overload(overloads.LessDoubleUint64,
argTypes(types.DoubleType, types.UintType), types.BoolType),
decls.Overload(overloads.LessString,
argTypes(types.StringType, types.StringType), types.BoolType),
decls.Overload(overloads.LessBytes,
argTypes(types.BytesType, types.BytesType), types.BoolType),
decls.Overload(overloads.LessTimestamp,
argTypes(types.TimestampType, types.TimestampType), types.BoolType),
decls.Overload(overloads.LessDuration,
argTypes(types.DurationType, types.DurationType), types.BoolType),
decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val {
cmp := lhs.(traits.Comparer).Compare(rhs)
if cmp == types.IntNegOne {
return types.True
}
if cmp == types.IntOne || cmp == types.IntZero {
return types.False
}
return cmp
}, traits.ComparerType)),
function(operators.LessEquals,
decls.Overload(overloads.LessEqualsBool,
argTypes(types.BoolType, types.BoolType), types.BoolType),
decls.Overload(overloads.LessEqualsInt64,
argTypes(types.IntType, types.IntType), types.BoolType),
decls.Overload(overloads.LessEqualsInt64Double,
argTypes(types.IntType, types.DoubleType), types.BoolType),
decls.Overload(overloads.LessEqualsInt64Uint64,
argTypes(types.IntType, types.UintType), types.BoolType),
decls.Overload(overloads.LessEqualsUint64,
argTypes(types.UintType, types.UintType), types.BoolType),
decls.Overload(overloads.LessEqualsUint64Double,
argTypes(types.UintType, types.DoubleType), types.BoolType),
decls.Overload(overloads.LessEqualsUint64Int64,
argTypes(types.UintType, types.IntType), types.BoolType),
decls.Overload(overloads.LessEqualsDouble,
argTypes(types.DoubleType, types.DoubleType), types.BoolType),
decls.Overload(overloads.LessEqualsDoubleInt64,
argTypes(types.DoubleType, types.IntType), types.BoolType),
decls.Overload(overloads.LessEqualsDoubleUint64,
argTypes(types.DoubleType, types.UintType), types.BoolType),
decls.Overload(overloads.LessEqualsString,
argTypes(types.StringType, types.StringType), types.BoolType),
decls.Overload(overloads.LessEqualsBytes,
argTypes(types.BytesType, types.BytesType), types.BoolType),
decls.Overload(overloads.LessEqualsTimestamp,
argTypes(types.TimestampType, types.TimestampType), types.BoolType),
decls.Overload(overloads.LessEqualsDuration,
argTypes(types.DurationType, types.DurationType), types.BoolType),
decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val {
cmp := lhs.(traits.Comparer).Compare(rhs)
if cmp == types.IntNegOne || cmp == types.IntZero {
return types.True
}
if cmp == types.IntOne {
return types.False
}
return cmp
}, traits.ComparerType)),
function(operators.Greater,
decls.Overload(overloads.GreaterBool,
argTypes(types.BoolType, types.BoolType), types.BoolType),
decls.Overload(overloads.GreaterInt64,
argTypes(types.IntType, types.IntType), types.BoolType),
decls.Overload(overloads.GreaterInt64Double,
argTypes(types.IntType, types.DoubleType), types.BoolType),
decls.Overload(overloads.GreaterInt64Uint64,
argTypes(types.IntType, types.UintType), types.BoolType),
decls.Overload(overloads.GreaterUint64,
argTypes(types.UintType, types.UintType), types.BoolType),
decls.Overload(overloads.GreaterUint64Double,
argTypes(types.UintType, types.DoubleType), types.BoolType),
decls.Overload(overloads.GreaterUint64Int64,
argTypes(types.UintType, types.IntType), types.BoolType),
decls.Overload(overloads.GreaterDouble,
argTypes(types.DoubleType, types.DoubleType), types.BoolType),
decls.Overload(overloads.GreaterDoubleInt64,
argTypes(types.DoubleType, types.IntType), types.BoolType),
decls.Overload(overloads.GreaterDoubleUint64,
argTypes(types.DoubleType, types.UintType), types.BoolType),
decls.Overload(overloads.GreaterString,
argTypes(types.StringType, types.StringType), types.BoolType),
decls.Overload(overloads.GreaterBytes,
argTypes(types.BytesType, types.BytesType), types.BoolType),
decls.Overload(overloads.GreaterTimestamp,
argTypes(types.TimestampType, types.TimestampType), types.BoolType),
decls.Overload(overloads.GreaterDuration,
argTypes(types.DurationType, types.DurationType), types.BoolType),
decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val {
cmp := lhs.(traits.Comparer).Compare(rhs)
if cmp == types.IntOne {
return types.True
}
if cmp == types.IntNegOne || cmp == types.IntZero {
return types.False
}
return cmp
}, traits.ComparerType)),
function(operators.GreaterEquals,
decls.Overload(overloads.GreaterEqualsBool,
argTypes(types.BoolType, types.BoolType), types.BoolType),
decls.Overload(overloads.GreaterEqualsInt64,
argTypes(types.IntType, types.IntType), types.BoolType),
decls.Overload(overloads.GreaterEqualsInt64Double,
argTypes(types.IntType, types.DoubleType), types.BoolType),
decls.Overload(overloads.GreaterEqualsInt64Uint64,
argTypes(types.IntType, types.UintType), types.BoolType),
decls.Overload(overloads.GreaterEqualsUint64,
argTypes(types.UintType, types.UintType), types.BoolType),
decls.Overload(overloads.GreaterEqualsUint64Double,
argTypes(types.UintType, types.DoubleType), types.BoolType),
decls.Overload(overloads.GreaterEqualsUint64Int64,
argTypes(types.UintType, types.IntType), types.BoolType),
decls.Overload(overloads.GreaterEqualsDouble,
argTypes(types.DoubleType, types.DoubleType), types.BoolType),
decls.Overload(overloads.GreaterEqualsDoubleInt64,
argTypes(types.DoubleType, types.IntType), types.BoolType),
decls.Overload(overloads.GreaterEqualsDoubleUint64,
argTypes(types.DoubleType, types.UintType), types.BoolType),
decls.Overload(overloads.GreaterEqualsString,
argTypes(types.StringType, types.StringType), types.BoolType),
decls.Overload(overloads.GreaterEqualsBytes,
argTypes(types.BytesType, types.BytesType), types.BoolType),
decls.Overload(overloads.GreaterEqualsTimestamp,
argTypes(types.TimestampType, types.TimestampType), types.BoolType),
decls.Overload(overloads.GreaterEqualsDuration,
argTypes(types.DurationType, types.DurationType), types.BoolType),
decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val {
cmp := lhs.(traits.Comparer).Compare(rhs)
if cmp == types.IntOne || cmp == types.IntZero {
return types.True
}
if cmp == types.IntNegOne {
return types.False
}
return cmp
}, traits.ComparerType)),
// Indexing
function(operators.Index,
decls.Overload(overloads.IndexList, argTypes(listOfA, types.IntType), paramA),
decls.Overload(overloads.IndexMap, argTypes(mapOfAB, paramA), paramB),
decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val {
return lhs.(traits.Indexer).Get(rhs)
}, traits.IndexerType)),
// Collections operators
function(operators.In,
decls.Overload(overloads.InList, argTypes(paramA, listOfA), types.BoolType),
decls.Overload(overloads.InMap, argTypes(paramA, mapOfAB), types.BoolType),
decls.SingletonBinaryBinding(inAggregate)),
function(operators.OldIn,
decls.DisableDeclaration(true), // safe deprecation
decls.Overload(overloads.InList, argTypes(paramA, listOfA), types.BoolType),
decls.Overload(overloads.InMap, argTypes(paramA, mapOfAB), types.BoolType),
decls.SingletonBinaryBinding(inAggregate)),
function(overloads.DeprecatedIn,
decls.DisableDeclaration(true), // safe deprecation
decls.Overload(overloads.InList, argTypes(paramA, listOfA), types.BoolType),
decls.Overload(overloads.InMap, argTypes(paramA, mapOfAB), types.BoolType),
decls.SingletonBinaryBinding(inAggregate)),
function(overloads.Size,
decls.Overload(overloads.SizeBytes, argTypes(types.BytesType), types.IntType),
decls.MemberOverload(overloads.SizeBytesInst, argTypes(types.BytesType), types.IntType),
decls.Overload(overloads.SizeList, argTypes(listOfA), types.IntType),
decls.MemberOverload(overloads.SizeListInst, argTypes(listOfA), types.IntType),
decls.Overload(overloads.SizeMap, argTypes(mapOfAB), types.IntType),
decls.MemberOverload(overloads.SizeMapInst, argTypes(mapOfAB), types.IntType),
decls.Overload(overloads.SizeString, argTypes(types.StringType), types.IntType),
decls.MemberOverload(overloads.SizeStringInst, argTypes(types.StringType), types.IntType),
decls.SingletonUnaryBinding(func(val ref.Val) ref.Val {
return val.(traits.Sizer).Size()
}, traits.SizerType)),
// Type conversions
function(overloads.TypeConvertType,
decls.Overload(overloads.TypeConvertType, argTypes(paramA), types.NewTypeTypeWithParam(paramA)),
decls.SingletonUnaryBinding(convertToType(types.TypeType))),
// Bool conversions
function(overloads.TypeConvertBool,
decls.Overload(overloads.BoolToBool, argTypes(types.BoolType), types.BoolType,
decls.UnaryBinding(identity)),
decls.Overload(overloads.StringToBool, argTypes(types.StringType), types.BoolType,
decls.UnaryBinding(convertToType(types.BoolType)))),
// Bytes conversions
function(overloads.TypeConvertBytes,
decls.Overload(overloads.BytesToBytes, argTypes(types.BytesType), types.BytesType,
decls.UnaryBinding(identity)),
decls.Overload(overloads.StringToBytes, argTypes(types.StringType), types.BytesType,
decls.UnaryBinding(convertToType(types.BytesType)))),
// Double conversions
function(overloads.TypeConvertDouble,
decls.Overload(overloads.DoubleToDouble, argTypes(types.DoubleType), types.DoubleType,
decls.UnaryBinding(identity)),
decls.Overload(overloads.IntToDouble, argTypes(types.IntType), types.DoubleType,
decls.UnaryBinding(convertToType(types.DoubleType))),
decls.Overload(overloads.StringToDouble, argTypes(types.StringType), types.DoubleType,
decls.UnaryBinding(convertToType(types.DoubleType))),
decls.Overload(overloads.UintToDouble, argTypes(types.UintType), types.DoubleType,
decls.UnaryBinding(convertToType(types.DoubleType)))),
// Duration conversions
function(overloads.TypeConvertDuration,
decls.Overload(overloads.DurationToDuration, argTypes(types.DurationType), types.DurationType,
decls.UnaryBinding(identity)),
decls.Overload(overloads.IntToDuration, argTypes(types.IntType), types.DurationType,
decls.UnaryBinding(convertToType(types.DurationType))),
decls.Overload(overloads.StringToDuration, argTypes(types.StringType), types.DurationType,
decls.UnaryBinding(convertToType(types.DurationType)))),
// Dyn conversions
function(overloads.TypeConvertDyn,
decls.Overload(overloads.ToDyn, argTypes(paramA), types.DynType),
decls.SingletonUnaryBinding(identity)),
// Int conversions
function(overloads.TypeConvertInt,
decls.Overload(overloads.IntToInt, argTypes(types.IntType), types.IntType,
decls.UnaryBinding(identity)),
decls.Overload(overloads.DoubleToInt, argTypes(types.DoubleType), types.IntType,
decls.UnaryBinding(convertToType(types.IntType))),
decls.Overload(overloads.DurationToInt, argTypes(types.DurationType), types.IntType,
decls.UnaryBinding(convertToType(types.IntType))),
decls.Overload(overloads.StringToInt, argTypes(types.StringType), types.IntType,
decls.UnaryBinding(convertToType(types.IntType))),
decls.Overload(overloads.TimestampToInt, argTypes(types.TimestampType), types.IntType,
decls.UnaryBinding(convertToType(types.IntType))),
decls.Overload(overloads.UintToInt, argTypes(types.UintType), types.IntType,
decls.UnaryBinding(convertToType(types.IntType))),
),
// String conversions
function(overloads.TypeConvertString,
decls.Overload(overloads.StringToString, argTypes(types.StringType), types.StringType,
decls.UnaryBinding(identity)),
decls.Overload(overloads.BoolToString, argTypes(types.BoolType), types.StringType,
decls.UnaryBinding(convertToType(types.StringType))),
decls.Overload(overloads.BytesToString, argTypes(types.BytesType), types.StringType,
decls.UnaryBinding(convertToType(types.StringType))),
decls.Overload(overloads.DoubleToString, argTypes(types.DoubleType), types.StringType,
decls.UnaryBinding(convertToType(types.StringType))),
decls.Overload(overloads.DurationToString, argTypes(types.DurationType), types.StringType,
decls.UnaryBinding(convertToType(types.StringType))),
decls.Overload(overloads.IntToString, argTypes(types.IntType), types.StringType,
decls.UnaryBinding(convertToType(types.StringType))),
decls.Overload(overloads.TimestampToString, argTypes(types.TimestampType), types.StringType,
decls.UnaryBinding(convertToType(types.StringType))),
decls.Overload(overloads.UintToString, argTypes(types.UintType), types.StringType,
decls.UnaryBinding(convertToType(types.StringType)))),
// Timestamp conversions
function(overloads.TypeConvertTimestamp,
decls.Overload(overloads.TimestampToTimestamp, argTypes(types.TimestampType), types.TimestampType,
decls.UnaryBinding(identity)),
decls.Overload(overloads.IntToTimestamp, argTypes(types.IntType), types.TimestampType,
decls.UnaryBinding(convertToType(types.TimestampType))),
decls.Overload(overloads.StringToTimestamp, argTypes(types.StringType), types.TimestampType,
decls.UnaryBinding(convertToType(types.TimestampType)))),
// Uint conversions
function(overloads.TypeConvertUint,
decls.Overload(overloads.UintToUint, argTypes(types.UintType), types.UintType,
decls.UnaryBinding(identity)),
decls.Overload(overloads.DoubleToUint, argTypes(types.DoubleType), types.UintType,
decls.UnaryBinding(convertToType(types.UintType))),
decls.Overload(overloads.IntToUint, argTypes(types.IntType), types.UintType,
decls.UnaryBinding(convertToType(types.UintType))),
decls.Overload(overloads.StringToUint, argTypes(types.StringType), types.UintType,
decls.UnaryBinding(convertToType(types.UintType)))),
// String functions
function(overloads.Contains,
decls.MemberOverload(overloads.ContainsString,
argTypes(types.StringType, types.StringType), types.BoolType,
decls.BinaryBinding(types.StringContains)),
decls.DisableTypeGuards(true)),
function(overloads.EndsWith,
decls.MemberOverload(overloads.EndsWithString,
argTypes(types.StringType, types.StringType), types.BoolType,
decls.BinaryBinding(types.StringEndsWith)),
decls.DisableTypeGuards(true)),
function(overloads.StartsWith,
decls.MemberOverload(overloads.StartsWithString,
argTypes(types.StringType, types.StringType), types.BoolType,
decls.BinaryBinding(types.StringStartsWith)),
decls.DisableTypeGuards(true)),
function(overloads.Matches,
decls.Overload(overloads.Matches, argTypes(types.StringType, types.StringType), types.BoolType),
decls.MemberOverload(overloads.MatchesString,
argTypes(types.StringType, types.StringType), types.BoolType),
decls.SingletonBinaryBinding(func(str, pat ref.Val) ref.Val {
return str.(traits.Matcher).Match(pat)
}, traits.MatcherType)),
// Timestamp / duration functions
function(overloads.TimeGetFullYear,
decls.MemberOverload(overloads.TimestampToYear,
argTypes(types.TimestampType), types.IntType),
decls.MemberOverload(overloads.TimestampToYearWithTz,
argTypes(types.TimestampType, types.StringType), types.IntType)),
function(overloads.TimeGetMonth,
decls.MemberOverload(overloads.TimestampToMonth,
argTypes(types.TimestampType), types.IntType),
decls.MemberOverload(overloads.TimestampToMonthWithTz,
argTypes(types.TimestampType, types.StringType), types.IntType)),
function(overloads.TimeGetDayOfYear,
decls.MemberOverload(overloads.TimestampToDayOfYear,
argTypes(types.TimestampType), types.IntType),
decls.MemberOverload(overloads.TimestampToDayOfYearWithTz,
argTypes(types.TimestampType, types.StringType), types.IntType)),
function(overloads.TimeGetDayOfMonth,
decls.MemberOverload(overloads.TimestampToDayOfMonthZeroBased,
argTypes(types.TimestampType), types.IntType),
decls.MemberOverload(overloads.TimestampToDayOfMonthZeroBasedWithTz,
argTypes(types.TimestampType, types.StringType), types.IntType)),
function(overloads.TimeGetDate,
decls.MemberOverload(overloads.TimestampToDayOfMonthOneBased,
argTypes(types.TimestampType), types.IntType),
decls.MemberOverload(overloads.TimestampToDayOfMonthOneBasedWithTz,
argTypes(types.TimestampType, types.StringType), types.IntType)),
function(overloads.TimeGetDayOfWeek,
decls.MemberOverload(overloads.TimestampToDayOfWeek,
argTypes(types.TimestampType), types.IntType),
decls.MemberOverload(overloads.TimestampToDayOfWeekWithTz,
argTypes(types.TimestampType, types.StringType), types.IntType)),
function(overloads.TimeGetHours,
decls.MemberOverload(overloads.TimestampToHours,
argTypes(types.TimestampType), types.IntType),
decls.MemberOverload(overloads.TimestampToHoursWithTz,
argTypes(types.TimestampType, types.StringType), types.IntType),
decls.MemberOverload(overloads.DurationToHours,
argTypes(types.DurationType), types.IntType)),
function(overloads.TimeGetMinutes,
decls.MemberOverload(overloads.TimestampToMinutes,
argTypes(types.TimestampType), types.IntType),
decls.MemberOverload(overloads.TimestampToMinutesWithTz,
argTypes(types.TimestampType, types.StringType), types.IntType),
decls.MemberOverload(overloads.DurationToMinutes,
argTypes(types.DurationType), types.IntType)),
function(overloads.TimeGetSeconds,
decls.MemberOverload(overloads.TimestampToSeconds,
argTypes(types.TimestampType), types.IntType),
decls.MemberOverload(overloads.TimestampToSecondsWithTz,
argTypes(types.TimestampType, types.StringType), types.IntType),
decls.MemberOverload(overloads.DurationToSeconds,
argTypes(types.DurationType), types.IntType)),
function(overloads.TimeGetMilliseconds,
decls.MemberOverload(overloads.TimestampToMilliseconds,
argTypes(types.TimestampType), types.IntType),
decls.MemberOverload(overloads.TimestampToMillisecondsWithTz,
argTypes(types.TimestampType, types.StringType), types.IntType),
decls.MemberOverload(overloads.DurationToMilliseconds,
argTypes(types.DurationType), types.IntType)),
}
stdFnDecls = make([]*exprpb.Decl, 0, len(stdFunctions))
for _, fn := range stdFunctions {
if fn.IsDeclarationDisabled() {
continue
}
ed, err := decls.FunctionDeclToExprDecl(fn)
if err != nil {
panic(err)
}
stdFnDecls = append(stdFnDecls, ed)
}
}
// Functions returns the set of standard library function declarations and definitions for CEL.
func Functions() []*decls.FunctionDecl {
return stdFunctions
}
// FunctionExprDecls returns the legacy style protobuf-typed declarations for all functions and overloads
// in the CEL standard environment.
//
// Deprecated: use Functions
func FunctionExprDecls() []*exprpb.Decl {
return stdFnDecls
}
// Types returns the set of standard library types for CEL.
func Types() []*decls.VariableDecl {
return stdTypes
}
// TypeExprDecls returns the legacy style protobuf-typed declarations for all types in the CEL
// standard environment.
//
// Deprecated: use Types
func TypeExprDecls() []*exprpb.Decl {
return stdTypeDecls
}
func notStrictlyFalse(value ref.Val) ref.Val {
if types.IsBool(value) {
return value
}
return types.True
}
func inAggregate(lhs ref.Val, rhs ref.Val) ref.Val {
if rhs.Type().HasTrait(traits.ContainerType) {
return rhs.(traits.Container).Contains(lhs)
}
return types.ValOrErr(rhs, "no such overload")
}
func function(name string, opts ...decls.FunctionOpt) *decls.FunctionDecl {
fn, err := decls.NewFunction(name, opts...)
if err != nil {
panic(err)
}
return fn
}
func argTypes(args ...*types.Type) []*types.Type {
return args
}
func noBinaryOverrides(rhs, lhs ref.Val) ref.Val {
return types.NoSuchOverloadErr()
}
func noFunctionOverrides(args ...ref.Val) ref.Val {
return types.NoSuchOverloadErr()
}
func identity(val ref.Val) ref.Val {
return val
}
func convertToType(t ref.Type) functions.UnaryOp {
return func(val ref.Val) ref.Val {
return val.ConvertToType(t)
}
}

View File

@ -27,20 +27,20 @@ go_library(
"provider.go",
"string.go",
"timestamp.go",
"type.go",
"types.go",
"uint.go",
"unknown.go",
"util.go",
],
importpath = "github.com/google/cel-go/common/types",
deps = [
"//checker/decls:go_default_library",
"//common/overloads:go_default_library",
"//common/types/pb:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"@com_github_stoewer_go_strcase//:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_rpc//status:go_default_library",
"@org_golang_google_protobuf//encoding/protojson:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//reflect/protoreflect:go_default_library",
@ -71,8 +71,9 @@ go_test(
"provider_test.go",
"string_test.go",
"timestamp_test.go",
"type_test.go",
"types_test.go",
"uint_test.go",
"unknown_test.go",
"util_test.go",
],
embed = [":go_default_library"],

View File

@ -20,7 +20,6 @@ import (
"strconv"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb"
@ -31,11 +30,6 @@ import (
type Bool bool
var (
// BoolType singleton.
BoolType = NewTypeValue("bool",
traits.ComparerType,
traits.NegatorType)
// boolWrapperType golang reflected type for protobuf bool wrapper type.
boolWrapperType = reflect.TypeOf(&wrapperspb.BoolValue{})
)

View File

@ -22,7 +22,6 @@ import (
"unicode/utf8"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb"
@ -34,12 +33,6 @@ import (
type Bytes []byte
var (
// BytesType singleton.
BytesType = NewTypeValue("bytes",
traits.AdderType,
traits.ComparerType,
traits.SizerType)
// byteWrapperType golang reflected type for protobuf bytes wrapper type.
byteWrapperType = reflect.TypeOf(&wrapperspb.BytesValue{})
)

View File

@ -20,7 +20,6 @@ import (
"reflect"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb"
@ -32,15 +31,6 @@ import (
type Double float64
var (
// DoubleType singleton.
DoubleType = NewTypeValue("double",
traits.AdderType,
traits.ComparerType,
traits.DividerType,
traits.MultiplierType,
traits.NegatorType,
traits.SubtractorType)
// doubleWrapperType reflected type for protobuf double wrapper type.
doubleWrapperType = reflect.TypeOf(&wrapperspb.DoubleValue{})

View File

@ -22,7 +22,6 @@ import (
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
anypb "google.golang.org/protobuf/types/known/anypb"
dpb "google.golang.org/protobuf/types/known/durationpb"
@ -41,13 +40,14 @@ func durationOf(d time.Duration) Duration {
}
var (
// DurationType singleton.
DurationType = NewTypeValue("google.protobuf.Duration",
traits.AdderType,
traits.ComparerType,
traits.NegatorType,
traits.ReceiverType,
traits.SubtractorType)
durationValueType = reflect.TypeOf(&dpb.Duration{})
durationZeroArgOverloads = map[string]func(ref.Val) ref.Val{
overloads.TimeGetHours: DurationGetHours,
overloads.TimeGetMinutes: DurationGetMinutes,
overloads.TimeGetSeconds: DurationGetSeconds,
overloads.TimeGetMilliseconds: DurationGetMilliseconds,
}
)
// Add implements traits.Adder.Add.
@ -156,7 +156,7 @@ func (d Duration) Negate() ref.Val {
func (d Duration) Receive(function string, overload string, args []ref.Val) ref.Val {
if len(args) == 0 {
if f, found := durationZeroArgOverloads[function]; found {
return f(d.Duration)
return f(d)
}
}
return NoSuchOverloadErr()
@ -185,20 +185,38 @@ func (d Duration) Value() any {
return d.Duration
}
var (
durationValueType = reflect.TypeOf(&dpb.Duration{})
// DurationGetHours returns the duration in hours.
func DurationGetHours(val ref.Val) ref.Val {
dur, ok := val.(Duration)
if !ok {
return MaybeNoSuchOverloadErr(val)
}
return Int(dur.Hours())
}
durationZeroArgOverloads = map[string]func(time.Duration) ref.Val{
overloads.TimeGetHours: func(dur time.Duration) ref.Val {
return Int(dur.Hours())
},
overloads.TimeGetMinutes: func(dur time.Duration) ref.Val {
return Int(dur.Minutes())
},
overloads.TimeGetSeconds: func(dur time.Duration) ref.Val {
return Int(dur.Seconds())
},
overloads.TimeGetMilliseconds: func(dur time.Duration) ref.Val {
return Int(dur.Milliseconds())
}}
)
// DurationGetMinutes returns duration in minutes.
func DurationGetMinutes(val ref.Val) ref.Val {
dur, ok := val.(Duration)
if !ok {
return MaybeNoSuchOverloadErr(val)
}
return Int(dur.Minutes())
}
// DurationGetSeconds returns duration in seconds.
func DurationGetSeconds(val ref.Val) ref.Val {
dur, ok := val.(Duration)
if !ok {
return MaybeNoSuchOverloadErr(val)
}
return Int(dur.Seconds())
}
// DurationGetMilliseconds returns duration in milliseconds.
func DurationGetMilliseconds(val ref.Val) ref.Val {
dur, ok := val.(Duration)
if !ok {
return MaybeNoSuchOverloadErr(val)
}
return Int(dur.Milliseconds())
}

View File

@ -35,7 +35,7 @@ type Err struct {
var (
// ErrType singleton.
ErrType = NewTypeValue("error")
ErrType = NewOpaqueType("error")
// errDivideByZero is an error indicating a division by zero of an integer value.
errDivideByZero = errors.New("division by zero")
@ -129,6 +129,11 @@ func (e *Err) Is(target error) bool {
return e.error.Error() == target.Error()
}
// Unwrap implements errors.Unwrap.
func (e *Err) Unwrap() error {
return e.error
}
// IsError returns whether the input element ref.Type or ref.Val is equal to
// the ErrType singleton.
func IsError(val ref.Val) bool {

View File

@ -22,7 +22,6 @@ import (
"time"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb"
@ -41,16 +40,6 @@ const (
)
var (
// IntType singleton.
IntType = NewTypeValue("int",
traits.AdderType,
traits.ComparerType,
traits.DividerType,
traits.ModderType,
traits.MultiplierType,
traits.NegatorType,
traits.SubtractorType)
// int32WrapperType reflected type for protobuf int32 wrapper type.
int32WrapperType = reflect.TypeOf(&wrapperspb.Int32Value{})

View File

@ -24,7 +24,7 @@ import (
var (
// IteratorType singleton.
IteratorType = NewTypeValue("iterator", traits.IteratorType)
IteratorType = NewObjectType("iterator", traits.IteratorType)
)
// baseIterator is the basis for list, map, and object iterators.

View File

@ -29,25 +29,15 @@ import (
structpb "google.golang.org/protobuf/types/known/structpb"
)
var (
// ListType singleton.
ListType = NewTypeValue("list",
traits.AdderType,
traits.ContainerType,
traits.IndexerType,
traits.IterableType,
traits.SizerType)
)
// NewDynamicList returns a traits.Lister with heterogenous elements.
// value should be an array of "native" types, i.e. any type that
// NativeToValue() can convert to a ref.Val.
func NewDynamicList(adapter ref.TypeAdapter, value any) traits.Lister {
func NewDynamicList(adapter Adapter, value any) traits.Lister {
refValue := reflect.ValueOf(value)
return &baseList{
TypeAdapter: adapter,
value: value,
size: refValue.Len(),
Adapter: adapter,
value: value,
size: refValue.Len(),
get: func(i int) any {
return refValue.Index(i).Interface()
},
@ -55,56 +45,56 @@ func NewDynamicList(adapter ref.TypeAdapter, value any) traits.Lister {
}
// NewStringList returns a traits.Lister containing only strings.
func NewStringList(adapter ref.TypeAdapter, elems []string) traits.Lister {
func NewStringList(adapter Adapter, elems []string) traits.Lister {
return &baseList{
TypeAdapter: adapter,
value: elems,
size: len(elems),
get: func(i int) any { return elems[i] },
Adapter: adapter,
value: elems,
size: len(elems),
get: func(i int) any { return elems[i] },
}
}
// NewRefValList returns a traits.Lister with ref.Val elements.
//
// This type specialization is used with list literals within CEL expressions.
func NewRefValList(adapter ref.TypeAdapter, elems []ref.Val) traits.Lister {
func NewRefValList(adapter Adapter, elems []ref.Val) traits.Lister {
return &baseList{
TypeAdapter: adapter,
value: elems,
size: len(elems),
get: func(i int) any { return elems[i] },
Adapter: adapter,
value: elems,
size: len(elems),
get: func(i int) any { return elems[i] },
}
}
// NewProtoList returns a traits.Lister based on a pb.List instance.
func NewProtoList(adapter ref.TypeAdapter, list protoreflect.List) traits.Lister {
func NewProtoList(adapter Adapter, list protoreflect.List) traits.Lister {
return &baseList{
TypeAdapter: adapter,
value: list,
size: list.Len(),
get: func(i int) any { return list.Get(i).Interface() },
Adapter: adapter,
value: list,
size: list.Len(),
get: func(i int) any { return list.Get(i).Interface() },
}
}
// NewJSONList returns a traits.Lister based on structpb.ListValue instance.
func NewJSONList(adapter ref.TypeAdapter, l *structpb.ListValue) traits.Lister {
func NewJSONList(adapter Adapter, l *structpb.ListValue) traits.Lister {
vals := l.GetValues()
return &baseList{
TypeAdapter: adapter,
value: l,
size: len(vals),
get: func(i int) any { return vals[i] },
Adapter: adapter,
value: l,
size: len(vals),
get: func(i int) any { return vals[i] },
}
}
// NewMutableList creates a new mutable list whose internal state can be modified.
func NewMutableList(adapter ref.TypeAdapter) traits.MutableLister {
func NewMutableList(adapter Adapter) traits.MutableLister {
var mutableValues []ref.Val
l := &mutableList{
baseList: &baseList{
TypeAdapter: adapter,
value: mutableValues,
size: 0,
Adapter: adapter,
value: mutableValues,
size: 0,
},
mutableValues: mutableValues,
}
@ -116,9 +106,9 @@ func NewMutableList(adapter ref.TypeAdapter) traits.MutableLister {
// baseList points to a list containing elements of any type.
// The `value` is an array of native values, and refValue is its reflection object.
// The `ref.TypeAdapter` enables native type to CEL type conversions.
// The `Adapter` enables native type to CEL type conversions.
type baseList struct {
ref.TypeAdapter
Adapter
value any
// size indicates the number of elements within the list.
@ -143,9 +133,9 @@ func (l *baseList) Add(other ref.Val) ref.Val {
return l
}
return &concatList{
TypeAdapter: l.TypeAdapter,
prevList: l,
nextList: otherList}
Adapter: l.Adapter,
prevList: l,
nextList: otherList}
}
// Contains implements the traits.Container interface method.
@ -322,13 +312,13 @@ func (l *mutableList) Add(other ref.Val) ref.Val {
func (l *mutableList) ToImmutableList() traits.Lister {
// The reference to internal state is guaranteed to be safe as this call is only performed
// when mutations have been completed.
return NewRefValList(l.TypeAdapter, l.mutableValues)
return NewRefValList(l.Adapter, l.mutableValues)
}
// concatList combines two list implementations together into a view.
// The `ref.TypeAdapter` enables native type to CEL type conversions.
// The `Adapter` enables native type to CEL type conversions.
type concatList struct {
ref.TypeAdapter
Adapter
value any
prevList traits.Lister
nextList traits.Lister
@ -347,9 +337,9 @@ func (l *concatList) Add(other ref.Val) ref.Val {
return l
}
return &concatList{
TypeAdapter: l.TypeAdapter,
prevList: l,
nextList: otherList}
Adapter: l.Adapter,
prevList: l,
nextList: otherList}
}
// Contains implements the traits.Container interface method.
@ -376,7 +366,7 @@ func (l *concatList) Contains(elem ref.Val) ref.Val {
// ConvertToNative implements the ref.Val interface method.
func (l *concatList) ConvertToNative(typeDesc reflect.Type) (any, error) {
combined := NewDynamicList(l.TypeAdapter, l.Value().([]any))
combined := NewDynamicList(l.Adapter, l.Value().([]any))
return combined.ConvertToNative(typeDesc)
}

View File

@ -32,10 +32,10 @@ import (
)
// NewDynamicMap returns a traits.Mapper value with dynamic key, value pairs.
func NewDynamicMap(adapter ref.TypeAdapter, value any) traits.Mapper {
func NewDynamicMap(adapter Adapter, value any) traits.Mapper {
refValue := reflect.ValueOf(value)
return &baseMap{
TypeAdapter: adapter,
Adapter: adapter,
mapAccessor: newReflectMapAccessor(adapter, refValue),
value: value,
size: refValue.Len(),
@ -46,10 +46,10 @@ func NewDynamicMap(adapter ref.TypeAdapter, value any) traits.Mapper {
// encoded in protocol buffer form.
//
// The `adapter` argument provides type adaptation capabilities from proto to CEL.
func NewJSONStruct(adapter ref.TypeAdapter, value *structpb.Struct) traits.Mapper {
func NewJSONStruct(adapter Adapter, value *structpb.Struct) traits.Mapper {
fields := value.GetFields()
return &baseMap{
TypeAdapter: adapter,
Adapter: adapter,
mapAccessor: newJSONStructAccessor(adapter, fields),
value: value,
size: len(fields),
@ -57,9 +57,9 @@ func NewJSONStruct(adapter ref.TypeAdapter, value *structpb.Struct) traits.Mappe
}
// NewRefValMap returns a specialized traits.Mapper with CEL valued keys and values.
func NewRefValMap(adapter ref.TypeAdapter, value map[ref.Val]ref.Val) traits.Mapper {
func NewRefValMap(adapter Adapter, value map[ref.Val]ref.Val) traits.Mapper {
return &baseMap{
TypeAdapter: adapter,
Adapter: adapter,
mapAccessor: newRefValMapAccessor(value),
value: value,
size: len(value),
@ -67,9 +67,9 @@ func NewRefValMap(adapter ref.TypeAdapter, value map[ref.Val]ref.Val) traits.Map
}
// NewStringInterfaceMap returns a specialized traits.Mapper with string keys and interface values.
func NewStringInterfaceMap(adapter ref.TypeAdapter, value map[string]any) traits.Mapper {
func NewStringInterfaceMap(adapter Adapter, value map[string]any) traits.Mapper {
return &baseMap{
TypeAdapter: adapter,
Adapter: adapter,
mapAccessor: newStringIfaceMapAccessor(adapter, value),
value: value,
size: len(value),
@ -77,9 +77,9 @@ func NewStringInterfaceMap(adapter ref.TypeAdapter, value map[string]any) traits
}
// NewStringStringMap returns a specialized traits.Mapper with string keys and values.
func NewStringStringMap(adapter ref.TypeAdapter, value map[string]string) traits.Mapper {
func NewStringStringMap(adapter Adapter, value map[string]string) traits.Mapper {
return &baseMap{
TypeAdapter: adapter,
Adapter: adapter,
mapAccessor: newStringMapAccessor(value),
value: value,
size: len(value),
@ -87,22 +87,13 @@ func NewStringStringMap(adapter ref.TypeAdapter, value map[string]string) traits
}
// NewProtoMap returns a specialized traits.Mapper for handling protobuf map values.
func NewProtoMap(adapter ref.TypeAdapter, value *pb.Map) traits.Mapper {
func NewProtoMap(adapter Adapter, value *pb.Map) traits.Mapper {
return &protoMap{
TypeAdapter: adapter,
value: value,
Adapter: adapter,
value: value,
}
}
var (
// MapType singleton.
MapType = NewTypeValue("map",
traits.ContainerType,
traits.IndexerType,
traits.IterableType,
traits.SizerType)
)
// mapAccessor is a private interface for finding values within a map and iterating over the keys.
// This interface implements portions of the API surface area required by the traits.Mapper
// interface.
@ -121,7 +112,7 @@ type mapAccessor interface {
// Since CEL is side-effect free, the base map represents an immutable object.
type baseMap struct {
// TypeAdapter used to convert keys and values accessed within the map.
ref.TypeAdapter
Adapter
// mapAccessor interface implementation used to find and iterate over map keys.
mapAccessor
@ -316,15 +307,15 @@ func (m *baseMap) Value() any {
return m.value
}
func newJSONStructAccessor(adapter ref.TypeAdapter, st map[string]*structpb.Value) mapAccessor {
func newJSONStructAccessor(adapter Adapter, st map[string]*structpb.Value) mapAccessor {
return &jsonStructAccessor{
TypeAdapter: adapter,
st: st,
Adapter: adapter,
st: st,
}
}
type jsonStructAccessor struct {
ref.TypeAdapter
Adapter
st map[string]*structpb.Value
}
@ -359,17 +350,17 @@ func (a *jsonStructAccessor) Iterator() traits.Iterator {
}
}
func newReflectMapAccessor(adapter ref.TypeAdapter, value reflect.Value) mapAccessor {
func newReflectMapAccessor(adapter Adapter, value reflect.Value) mapAccessor {
keyType := value.Type().Key()
return &reflectMapAccessor{
TypeAdapter: adapter,
refValue: value,
keyType: keyType,
Adapter: adapter,
refValue: value,
keyType: keyType,
}
}
type reflectMapAccessor struct {
ref.TypeAdapter
Adapter
refValue reflect.Value
keyType reflect.Type
}
@ -427,9 +418,9 @@ func (m *reflectMapAccessor) findInternal(key ref.Val) (ref.Val, bool) {
// Iterator creates a Golang reflection based traits.Iterator.
func (m *reflectMapAccessor) Iterator() traits.Iterator {
return &mapIterator{
TypeAdapter: m.TypeAdapter,
mapKeys: m.refValue.MapRange(),
len: m.refValue.Len(),
Adapter: m.Adapter,
mapKeys: m.refValue.MapRange(),
len: m.refValue.Len(),
}
}
@ -480,9 +471,9 @@ func (a *refValMapAccessor) Find(key ref.Val) (ref.Val, bool) {
// Iterator produces a new traits.Iterator which iterates over the map keys via Golang reflection.
func (a *refValMapAccessor) Iterator() traits.Iterator {
return &mapIterator{
TypeAdapter: DefaultTypeAdapter,
mapKeys: reflect.ValueOf(a.mapVal).MapRange(),
len: len(a.mapVal),
Adapter: DefaultTypeAdapter,
mapKeys: reflect.ValueOf(a.mapVal).MapRange(),
len: len(a.mapVal),
}
}
@ -524,15 +515,15 @@ func (a *stringMapAccessor) Iterator() traits.Iterator {
}
}
func newStringIfaceMapAccessor(adapter ref.TypeAdapter, mapVal map[string]any) mapAccessor {
func newStringIfaceMapAccessor(adapter Adapter, mapVal map[string]any) mapAccessor {
return &stringIfaceMapAccessor{
TypeAdapter: adapter,
mapVal: mapVal,
Adapter: adapter,
mapVal: mapVal,
}
}
type stringIfaceMapAccessor struct {
ref.TypeAdapter
Adapter
mapVal map[string]any
}
@ -569,7 +560,7 @@ func (a *stringIfaceMapAccessor) Iterator() traits.Iterator {
// protoMap is a specialized, separate implementation of the traits.Mapper interfaces tailored to
// accessing protoreflect.Map values.
type protoMap struct {
ref.TypeAdapter
Adapter
value *pb.Map
}
@ -772,9 +763,9 @@ func (m *protoMap) Iterator() traits.Iterator {
return true
})
return &protoMapIterator{
TypeAdapter: m.TypeAdapter,
mapKeys: mapKeys,
len: m.value.Len(),
Adapter: m.Adapter,
mapKeys: mapKeys,
len: m.value.Len(),
}
}
@ -795,7 +786,7 @@ func (m *protoMap) Value() any {
type mapIterator struct {
*baseIterator
ref.TypeAdapter
Adapter
mapKeys *reflect.MapIter
cursor int
len int
@ -818,7 +809,7 @@ func (it *mapIterator) Next() ref.Val {
type protoMapIterator struct {
*baseIterator
ref.TypeAdapter
Adapter
mapKeys []protoreflect.MapKey
cursor int
len int

View File

@ -30,8 +30,6 @@ import (
type Null structpb.NullValue
var (
// NullType singleton.
NullType = NewTypeValue("null_type")
// NullValue singleton.
NullValue = Null(structpb.NullValue_NULL_VALUE)

View File

@ -29,10 +29,10 @@ import (
)
type protoObj struct {
ref.TypeAdapter
Adapter
value proto.Message
typeDesc *pb.TypeDescription
typeValue *TypeValue
typeValue ref.Val
}
// NewObject returns an object based on a proto.Message value which handles
@ -42,15 +42,15 @@ type protoObj struct {
// Note: the type value is pulled from the list of registered types within the
// type provider. If the proto type is not registered within the type provider,
// then this will result in an error within the type adapter / provider.
func NewObject(adapter ref.TypeAdapter,
func NewObject(adapter Adapter,
typeDesc *pb.TypeDescription,
typeValue *TypeValue,
typeValue ref.Val,
value proto.Message) ref.Val {
return &protoObj{
TypeAdapter: adapter,
value: value,
typeDesc: typeDesc,
typeValue: typeValue}
Adapter: adapter,
value: value,
typeDesc: typeDesc,
typeValue: typeValue}
}
func (o *protoObj) ConvertToNative(typeDesc reflect.Type) (any, error) {
@ -157,7 +157,7 @@ func (o *protoObj) Get(index ref.Val) ref.Val {
}
func (o *protoObj) Type() ref.Type {
return o.typeValue
return o.typeValue.(ref.Type)
}
func (o *protoObj) Value() any {

View File

@ -24,7 +24,7 @@ import (
var (
// OptionalType indicates the runtime type of an optional value.
OptionalType = NewTypeValue("optional")
OptionalType = NewOpaqueType("optional")
// OptionalNone is a sentinel value which is used to indicate an empty optional value.
OptionalNone = &Optional{}

View File

@ -285,7 +285,7 @@ func (fd *FieldDescription) GetFrom(target any) (any, error) {
// IsEnum returns true if the field type refers to an enum value.
func (fd *FieldDescription) IsEnum() bool {
return fd.desc.Kind() == protoreflect.EnumKind
return fd.ProtoKind() == protoreflect.EnumKind
}
// IsMap returns true if the field is of map type.
@ -295,7 +295,7 @@ func (fd *FieldDescription) IsMap() bool {
// IsMessage returns true if the field is of message type.
func (fd *FieldDescription) IsMessage() bool {
kind := fd.desc.Kind()
kind := fd.ProtoKind()
return kind == protoreflect.MessageKind || kind == protoreflect.GroupKind
}
@ -326,6 +326,11 @@ func (fd *FieldDescription) Name() string {
return string(fd.desc.Name())
}
// ProtoKind returns the protobuf reflected kind of the field.
func (fd *FieldDescription) ProtoKind() protoreflect.Kind {
return fd.desc.Kind()
}
// ReflectType returns the Golang reflect.Type for this field.
func (fd *FieldDescription) ReflectType() reflect.Type {
return fd.reflectType
@ -345,17 +350,17 @@ func (fd *FieldDescription) Zero() proto.Message {
}
func (fd *FieldDescription) typeDefToType() *exprpb.Type {
if fd.desc.Kind() == protoreflect.MessageKind || fd.desc.Kind() == protoreflect.GroupKind {
if fd.IsMessage() {
msgType := string(fd.desc.Message().FullName())
if wk, found := CheckedWellKnowns[msgType]; found {
return wk
}
return checkedMessageType(msgType)
}
if fd.desc.Kind() == protoreflect.EnumKind {
if fd.IsEnum() {
return checkedInt
}
return CheckedPrimitives[fd.desc.Kind()]
return CheckedPrimitives[fd.ProtoKind()]
}
// Map wraps the protoreflect.Map object with a key and value FieldDescription for use in
@ -463,13 +468,13 @@ func unwrapDynamic(desc description, refMsg protoreflect.Message) (any, bool, er
unwrappedAny := &anypb.Any{}
err := Merge(unwrappedAny, msg)
if err != nil {
return nil, false, err
return nil, false, fmt.Errorf("unwrap dynamic field failed: %v", err)
}
dynMsg, err := unwrappedAny.UnmarshalNew()
if err != nil {
// Allow the error to move further up the stack as it should result in an type
// conversion error if the caller does not recover it somehow.
return nil, false, err
return nil, false, fmt.Errorf("unmarshal dynamic any failed: %v", err)
}
// Attempt to unwrap the dynamic type, otherwise return the dynamic message.
unwrapped, nested, err := unwrapDynamic(desc, dynMsg.ProtoReflect())
@ -560,8 +565,10 @@ func zeroValueOf(msg proto.Message) proto.Message {
}
var (
jsonValueTypeURL = "types.googleapis.com/google.protobuf.Value"
zeroValueMap = map[string]proto.Message{
"google.protobuf.Any": &anypb.Any{},
"google.protobuf.Any": &anypb.Any{TypeUrl: jsonValueTypeURL},
"google.protobuf.Duration": &dpb.Duration{},
"google.protobuf.ListValue": &structpb.ListValue{},
"google.protobuf.Struct": &structpb.Struct{},

View File

@ -33,17 +33,64 @@ import (
tpb "google.golang.org/protobuf/types/known/timestamppb"
)
type protoTypeRegistry struct {
revTypeMap map[string]ref.Type
// Adapter converts native Go values of varying type and complexity to equivalent CEL values.
type Adapter = ref.TypeAdapter
// Provider specifies functions for creating new object instances and for resolving
// enum values by name.
type Provider interface {
// EnumValue returns the numeric value of the given enum value name.
EnumValue(enumName string) ref.Val
// FindIdent takes a qualified identifier name and returns a ref.Val if one exists.
FindIdent(identName string) (ref.Val, bool)
// FindStructType returns the Type give a qualified type name.
//
// For historical reasons, only struct types are expected to be returned through this
// method, and the type values are expected to be wrapped in a TypeType instance using
// TypeTypeWithParam(<structType>).
//
// Returns false if not found.
FindStructType(structType string) (*Type, bool)
// FieldStructFieldType returns the field type for a checked type value. Returns
// false if the field could not be found.
FindStructFieldType(structType, fieldName string) (*FieldType, bool)
// NewValue creates a new type value from a qualified name and map of field
// name to value.
//
// Note, for each value, the Val.ConvertToNative function will be invoked
// to convert the Val to the field's native type. If an error occurs during
// conversion, the NewValue will be a types.Err.
NewValue(structType string, fields map[string]ref.Val) ref.Val
}
// FieldType represents a field's type value and whether that field supports presence detection.
type FieldType struct {
// Type of the field as a CEL native type value.
Type *Type
// IsSet indicates whether the field is set on an input object.
IsSet ref.FieldTester
// GetFrom retrieves the field value on the input object, if set.
GetFrom ref.FieldGetter
}
// Registry provides type information for a set of registered types.
type Registry struct {
revTypeMap map[string]*Type
pbdb *pb.Db
}
// NewRegistry accepts a list of proto message instances and returns a type
// provider which can create new instances of the provided message or any
// message that proto depends upon in its FileDescriptor.
func NewRegistry(types ...proto.Message) (ref.TypeRegistry, error) {
p := &protoTypeRegistry{
revTypeMap: make(map[string]ref.Type),
func NewRegistry(types ...proto.Message) (*Registry, error) {
p := &Registry{
revTypeMap: make(map[string]*Type),
pbdb: pb.NewDb(),
}
err := p.RegisterType(
@ -79,18 +126,17 @@ func NewRegistry(types ...proto.Message) (ref.TypeRegistry, error) {
}
// NewEmptyRegistry returns a registry which is completely unconfigured.
func NewEmptyRegistry() ref.TypeRegistry {
return &protoTypeRegistry{
revTypeMap: make(map[string]ref.Type),
func NewEmptyRegistry() *Registry {
return &Registry{
revTypeMap: make(map[string]*Type),
pbdb: pb.NewDb(),
}
}
// Copy implements the ref.TypeRegistry interface method which copies the current state of the
// registry into its own memory space.
func (p *protoTypeRegistry) Copy() ref.TypeRegistry {
copy := &protoTypeRegistry{
revTypeMap: make(map[string]ref.Type),
// Copy copies the current state of the registry into its own memory space.
func (p *Registry) Copy() *Registry {
copy := &Registry{
revTypeMap: make(map[string]*Type),
pbdb: p.pbdb.Copy(),
}
for k, v := range p.revTypeMap {
@ -99,7 +145,8 @@ func (p *protoTypeRegistry) Copy() ref.TypeRegistry {
return copy
}
func (p *protoTypeRegistry) EnumValue(enumName string) ref.Val {
// EnumValue returns the numeric value of the given enum value name.
func (p *Registry) EnumValue(enumName string) ref.Val {
enumVal, found := p.pbdb.DescribeEnum(enumName)
if !found {
return NewErr("unknown enum name '%s'", enumName)
@ -107,9 +154,12 @@ func (p *protoTypeRegistry) EnumValue(enumName string) ref.Val {
return Int(enumVal.Value())
}
func (p *protoTypeRegistry) FindFieldType(messageType string,
fieldName string) (*ref.FieldType, bool) {
msgType, found := p.pbdb.DescribeType(messageType)
// FieldFieldType returns the field type for a checked type value. Returns false if
// the field could not be found.
//
// Deprecated: use FindStructFieldType
func (p *Registry) FindFieldType(structType, fieldName string) (*ref.FieldType, bool) {
msgType, found := p.pbdb.DescribeType(structType)
if !found {
return nil, false
}
@ -118,15 +168,32 @@ func (p *protoTypeRegistry) FindFieldType(messageType string,
return nil, false
}
return &ref.FieldType{
Type: field.CheckedType(),
IsSet: field.IsSet,
GetFrom: field.GetFrom},
true
Type: field.CheckedType(),
IsSet: field.IsSet,
GetFrom: field.GetFrom}, true
}
func (p *protoTypeRegistry) FindIdent(identName string) (ref.Val, bool) {
// FieldStructFieldType returns the field type for a checked type value. Returns
// false if the field could not be found.
func (p *Registry) FindStructFieldType(structType, fieldName string) (*FieldType, bool) {
msgType, found := p.pbdb.DescribeType(structType)
if !found {
return nil, false
}
field, found := msgType.FieldByName(fieldName)
if !found {
return nil, false
}
return &FieldType{
Type: fieldDescToCELType(field),
IsSet: field.IsSet,
GetFrom: field.GetFrom}, true
}
// FindIdent takes a qualified identifier name and returns a ref.Val if one exists.
func (p *Registry) FindIdent(identName string) (ref.Val, bool) {
if t, found := p.revTypeMap[identName]; found {
return t.(ref.Val), true
return t, true
}
if enumVal, found := p.pbdb.DescribeEnum(identName); found {
return Int(enumVal.Value()), true
@ -134,24 +201,50 @@ func (p *protoTypeRegistry) FindIdent(identName string) (ref.Val, bool) {
return nil, false
}
func (p *protoTypeRegistry) FindType(typeName string) (*exprpb.Type, bool) {
if _, found := p.pbdb.DescribeType(typeName); !found {
// FindType looks up the Type given a qualified typeName. Returns false if not found.
//
// Deprecated: use FindStructType
func (p *Registry) FindType(structType string) (*exprpb.Type, bool) {
if _, found := p.pbdb.DescribeType(structType); !found {
return nil, false
}
if typeName != "" && typeName[0] == '.' {
typeName = typeName[1:]
if structType != "" && structType[0] == '.' {
structType = structType[1:]
}
return &exprpb.Type{
TypeKind: &exprpb.Type_Type{
Type: &exprpb.Type{
TypeKind: &exprpb.Type_MessageType{
MessageType: typeName}}}}, true
MessageType: structType}}}}, true
}
func (p *protoTypeRegistry) NewValue(typeName string, fields map[string]ref.Val) ref.Val {
td, found := p.pbdb.DescribeType(typeName)
// FindStructType returns the Type give a qualified type name.
//
// For historical reasons, only struct types are expected to be returned through this
// method, and the type values are expected to be wrapped in a TypeType instance using
// TypeTypeWithParam(<structType>).
//
// Returns false if not found.
func (p *Registry) FindStructType(structType string) (*Type, bool) {
if _, found := p.pbdb.DescribeType(structType); !found {
return nil, false
}
if structType != "" && structType[0] == '.' {
structType = structType[1:]
}
return NewTypeTypeWithParam(NewObjectType(structType)), true
}
// NewValue creates a new type value from a qualified name and map of field
// name to value.
//
// Note, for each value, the Val.ConvertToNative function will be invoked
// to convert the Val to the field's native type. If an error occurs during
// conversion, the NewValue will be a types.Err.
func (p *Registry) NewValue(structType string, fields map[string]ref.Val) ref.Val {
td, found := p.pbdb.DescribeType(structType)
if !found {
return NewErr("unknown type '%s'", typeName)
return NewErr("unknown type '%s'", structType)
}
msg := td.New()
fieldMap := td.FieldMap()
@ -168,7 +261,8 @@ func (p *protoTypeRegistry) NewValue(typeName string, fields map[string]ref.Val)
return p.NativeToValue(msg.Interface())
}
func (p *protoTypeRegistry) RegisterDescriptor(fileDesc protoreflect.FileDescriptor) error {
// RegisterDescriptor registers the contents of a protocol buffer `FileDescriptor`.
func (p *Registry) RegisterDescriptor(fileDesc protoreflect.FileDescriptor) error {
fd, err := p.pbdb.RegisterDescriptor(fileDesc)
if err != nil {
return err
@ -176,7 +270,8 @@ func (p *protoTypeRegistry) RegisterDescriptor(fileDesc protoreflect.FileDescrip
return p.registerAllTypes(fd)
}
func (p *protoTypeRegistry) RegisterMessage(message proto.Message) error {
// RegisterMessage registers a protocol buffer message and its dependencies.
func (p *Registry) RegisterMessage(message proto.Message) error {
fd, err := p.pbdb.RegisterMessage(message)
if err != nil {
return err
@ -184,11 +279,32 @@ func (p *protoTypeRegistry) RegisterMessage(message proto.Message) error {
return p.registerAllTypes(fd)
}
func (p *protoTypeRegistry) RegisterType(types ...ref.Type) error {
// RegisterType registers a type value with the provider which ensures the provider is aware of how to
// map the type to an identifier.
//
// If the `ref.Type` value is a `*types.Type` it will be registered directly by its runtime type name.
// If the `ref.Type` value is not a `*types.Type` instance, a `*types.Type` instance which reflects the
// traits present on the input and the runtime type name. By default this foreign type will be treated
// as a types.StructKind. To avoid potential issues where the `ref.Type` values does not match the
// generated `*types.Type` instance, consider always using the `*types.Type` to represent type extensions
// to CEL, even when they're not based on protobuf types.
func (p *Registry) RegisterType(types ...ref.Type) error {
for _, t := range types {
p.revTypeMap[t.TypeName()] = t
celType := maybeForeignType(t)
existing, found := p.revTypeMap[t.TypeName()]
if !found {
p.revTypeMap[t.TypeName()] = celType
continue
}
if !existing.IsEquivalentType(celType) {
return fmt.Errorf("type registration conflict. found: %v, input: %v", existing, celType)
}
if existing.traitMask != celType.traitMask {
return fmt.Errorf(
"type registered with conflicting traits: %v with traits %v, input: %v",
existing.TypeName(), existing.traitMask, celType.traitMask)
}
}
// TODO: generate an error when the type name is registered more than once.
return nil
}
@ -196,7 +312,7 @@ func (p *protoTypeRegistry) RegisterType(types ...ref.Type) error {
// providing support for custom proto-based types.
//
// This method should be the inverse of ref.Val.ConvertToNative.
func (p *protoTypeRegistry) NativeToValue(value any) ref.Val {
func (p *Registry) NativeToValue(value any) ref.Val {
if val, found := nativeToValue(p, value); found {
return val
}
@ -218,7 +334,7 @@ func (p *protoTypeRegistry) NativeToValue(value any) ref.Val {
if !found {
return NewErr("unknown type: '%s'", typeName)
}
return NewObject(p, td, typeVal.(*TypeValue), v)
return NewObject(p, td, typeVal, v)
case *pb.Map:
return NewProtoMap(p, v)
case protoreflect.List:
@ -231,8 +347,13 @@ func (p *protoTypeRegistry) NativeToValue(value any) ref.Val {
return UnsupportedRefValConversionErr(value)
}
func (p *protoTypeRegistry) registerAllTypes(fd *pb.FileDescription) error {
func (p *Registry) registerAllTypes(fd *pb.FileDescription) error {
for _, typeName := range fd.GetTypeNames() {
// skip well-known type names since they're automatically sanitized
// during NewObjectType() calls.
if _, found := checkedWellKnowns[typeName]; found {
continue
}
err := p.RegisterType(NewObjectTypeValue(typeName))
if err != nil {
return err
@ -241,6 +362,28 @@ func (p *protoTypeRegistry) registerAllTypes(fd *pb.FileDescription) error {
return nil
}
func fieldDescToCELType(field *pb.FieldDescription) *Type {
if field.IsMap() {
return NewMapType(
singularFieldDescToCELType(field.KeyType),
singularFieldDescToCELType(field.ValueType))
}
if field.IsList() {
return NewListType(singularFieldDescToCELType(field))
}
return singularFieldDescToCELType(field)
}
func singularFieldDescToCELType(field *pb.FieldDescription) *Type {
if field.IsMessage() {
return NewObjectType(string(field.Descriptor().Message().FullName()))
}
if field.IsEnum() {
return IntType
}
return ProtoCELPrimitives[field.ProtoKind()]
}
// defaultTypeAdapter converts go native types to CEL values.
type defaultTypeAdapter struct{}
@ -259,7 +402,7 @@ func (a *defaultTypeAdapter) NativeToValue(value any) ref.Val {
// nativeToValue returns the converted (ref.Val, true) of a conversion is found,
// otherwise (nil, false)
func nativeToValue(a ref.TypeAdapter, value any) (ref.Val, bool) {
func nativeToValue(a Adapter, value any) (ref.Val, bool) {
switch v := value.(type) {
case nil:
return NullValue, true
@ -547,3 +690,24 @@ func fieldTypeConversionError(field *pb.FieldDescription, err error) error {
msgName := field.Descriptor().ContainingMessage().FullName()
return fmt.Errorf("field type conversion error for %v.%v value type: %v", msgName, field.Name(), err)
}
var (
// ProtoCELPrimitives provides a map from the protoreflect Kind to the equivalent CEL type.
ProtoCELPrimitives = map[protoreflect.Kind]*Type{
protoreflect.BoolKind: BoolType,
protoreflect.BytesKind: BytesType,
protoreflect.DoubleKind: DoubleType,
protoreflect.FloatKind: DoubleType,
protoreflect.Int32Kind: IntType,
protoreflect.Int64Kind: IntType,
protoreflect.Sint32Kind: IntType,
protoreflect.Sint64Kind: IntType,
protoreflect.Uint32Kind: UintType,
protoreflect.Uint64Kind: UintType,
protoreflect.Fixed32Kind: UintType,
protoreflect.Fixed64Kind: UintType,
protoreflect.Sfixed32Kind: IntType,
protoreflect.Sfixed64Kind: IntType,
protoreflect.StringKind: StringType,
}
)

View File

@ -23,34 +23,34 @@ import (
// TypeProvider specifies functions for creating new object instances and for
// resolving enum values by name.
//
// Deprecated: use types.Provider
type TypeProvider interface {
// EnumValue returns the numeric value of the given enum value name.
EnumValue(enumName string) Val
// FindIdent takes a qualified identifier name and returns a Value if one
// exists.
// FindIdent takes a qualified identifier name and returns a Value if one exists.
FindIdent(identName string) (Val, bool)
// FindType looks up the Type given a qualified typeName. Returns false
// if not found.
//
// Used during type-checking only.
// FindType looks up the Type given a qualified typeName. Returns false if not found.
FindType(typeName string) (*exprpb.Type, bool)
// FieldFieldType returns the field type for a checked type value. Returns
// false if the field could not be found.
FindFieldType(messageType string, fieldName string) (*FieldType, bool)
// FieldFieldType returns the field type for a checked type value. Returns false if
// the field could not be found.
FindFieldType(messageType, fieldName string) (*FieldType, bool)
// NewValue creates a new type value from a qualified name and map of field
// name to value.
// NewValue creates a new type value from a qualified name and map of field name
// to value.
//
// Note, for each value, the Val.ConvertToNative function will be invoked
// to convert the Val to the field's native type. If an error occurs during
// conversion, the NewValue will be a types.Err.
// Note, for each value, the Val.ConvertToNative function will be invoked to convert
// the Val to the field's native type. If an error occurs during conversion, the
// NewValue will be a types.Err.
NewValue(typeName string, fields map[string]Val) Val
}
// TypeAdapter converts native Go values of varying type and complexity to equivalent CEL values.
//
// Deprecated: use types.Adapter
type TypeAdapter interface {
// NativeToValue converts the input `value` to a CEL `ref.Val`.
NativeToValue(value any) Val
@ -60,6 +60,8 @@ type TypeAdapter interface {
// implementations support type-customization, so these features are optional. However, a
// `TypeRegistry` should be a `TypeProvider` and a `TypeAdapter` to ensure that types
// which are registered can be converted to CEL representations.
//
// Deprecated: use types.Registry
type TypeRegistry interface {
TypeAdapter
TypeProvider
@ -76,15 +78,14 @@ type TypeRegistry interface {
// If a type is provided more than once with an alternative definition, the
// call will result in an error.
RegisterType(types ...Type) error
// Copy the TypeRegistry and return a new registry whose mutable state is isolated.
Copy() TypeRegistry
}
// FieldType represents a field's type value and whether that field supports
// presence detection.
//
// Deprecated: use types.FieldType
type FieldType struct {
// Type of the field.
// Type of the field as a protobuf type value.
Type *exprpb.Type
// IsSet indicates whether the field is set on an input object.

View File

@ -24,7 +24,6 @@ import (
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb"
@ -36,18 +35,10 @@ import (
type String string
var (
// StringType singleton.
StringType = NewTypeValue("string",
traits.AdderType,
traits.ComparerType,
traits.MatcherType,
traits.ReceiverType,
traits.SizerType)
stringOneArgOverloads = map[string]func(String, ref.Val) ref.Val{
overloads.Contains: stringContains,
overloads.EndsWith: stringEndsWith,
overloads.StartsWith: stringStartsWith,
stringOneArgOverloads = map[string]func(ref.Val, ref.Val) ref.Val{
overloads.Contains: StringContains,
overloads.EndsWith: StringEndsWith,
overloads.StartsWith: StringStartsWith,
}
stringWrapperType = reflect.TypeOf(&wrapperspb.StringValue{})
@ -198,26 +189,41 @@ func (s String) Value() any {
return string(s)
}
func stringContains(s String, sub ref.Val) ref.Val {
// StringContains returns whether the string contains a substring.
func StringContains(s, sub ref.Val) ref.Val {
str, ok := s.(String)
if !ok {
return MaybeNoSuchOverloadErr(s)
}
subStr, ok := sub.(String)
if !ok {
return MaybeNoSuchOverloadErr(sub)
}
return Bool(strings.Contains(string(s), string(subStr)))
return Bool(strings.Contains(string(str), string(subStr)))
}
func stringEndsWith(s String, suf ref.Val) ref.Val {
// StringEndsWith returns whether the target string contains the input suffix.
func StringEndsWith(s, suf ref.Val) ref.Val {
str, ok := s.(String)
if !ok {
return MaybeNoSuchOverloadErr(s)
}
sufStr, ok := suf.(String)
if !ok {
return MaybeNoSuchOverloadErr(suf)
}
return Bool(strings.HasSuffix(string(s), string(sufStr)))
return Bool(strings.HasSuffix(string(str), string(sufStr)))
}
func stringStartsWith(s String, pre ref.Val) ref.Val {
// StringStartsWith returns whether the target string contains the input prefix.
func StringStartsWith(s, pre ref.Val) ref.Val {
str, ok := s.(String)
if !ok {
return MaybeNoSuchOverloadErr(s)
}
preStr, ok := pre.(String)
if !ok {
return MaybeNoSuchOverloadErr(pre)
}
return Bool(strings.HasPrefix(string(s), string(preStr)))
return Bool(strings.HasPrefix(string(str), string(preStr)))
}

View File

@ -23,7 +23,6 @@ import (
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb"
@ -53,15 +52,6 @@ const (
maxUnixTime int64 = 253402300799
)
var (
// TimestampType singleton.
TimestampType = NewTypeValue("google.protobuf.Timestamp",
traits.AdderType,
traits.ComparerType,
traits.ReceiverType,
traits.SubtractorType)
)
// Add implements traits.Adder.Add.
func (t Timestamp) Add(other ref.Val) ref.Val {
switch other.Type() {

View File

@ -1,102 +0,0 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package types
import (
"fmt"
"reflect"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
)
var (
// TypeType is the type of a TypeValue.
TypeType = NewTypeValue("type")
)
// TypeValue is an instance of a Value that describes a value's type.
type TypeValue struct {
name string
traitMask int
}
// NewTypeValue returns *TypeValue which is both a ref.Type and ref.Val.
func NewTypeValue(name string, traits ...int) *TypeValue {
traitMask := 0
for _, trait := range traits {
traitMask |= trait
}
return &TypeValue{
name: name,
traitMask: traitMask}
}
// NewObjectTypeValue returns a *TypeValue based on the input name, which is
// annotated with the traits relevant to all objects.
func NewObjectTypeValue(name string) *TypeValue {
return NewTypeValue(name,
traits.FieldTesterType,
traits.IndexerType)
}
// ConvertToNative implements ref.Val.ConvertToNative.
func (t *TypeValue) ConvertToNative(typeDesc reflect.Type) (any, error) {
// TODO: replace the internal type representation with a proto-value.
return nil, fmt.Errorf("type conversion not supported for 'type'")
}
// ConvertToType implements ref.Val.ConvertToType.
func (t *TypeValue) ConvertToType(typeVal ref.Type) ref.Val {
switch typeVal {
case TypeType:
return TypeType
case StringType:
return String(t.TypeName())
}
return NewErr("type conversion error from '%s' to '%s'", TypeType, typeVal)
}
// Equal implements ref.Val.Equal.
func (t *TypeValue) Equal(other ref.Val) ref.Val {
otherType, ok := other.(ref.Type)
return Bool(ok && t.TypeName() == otherType.TypeName())
}
// HasTrait indicates whether the type supports the given trait.
// Trait codes are defined in the traits package, e.g. see traits.AdderType.
func (t *TypeValue) HasTrait(trait int) bool {
return trait&t.traitMask == trait
}
// String implements fmt.Stringer.
func (t *TypeValue) String() string {
return t.name
}
// Type implements ref.Val.Type.
func (t *TypeValue) Type() ref.Type {
return TypeType
}
// TypeName gives the type's name as a string.
func (t *TypeValue) TypeName() string {
return t.name
}
// Value implements ref.Val.Value.
func (t *TypeValue) Value() any {
return t.name
}

806
vendor/github.com/google/cel-go/common/types/types.go generated vendored Normal file
View File

@ -0,0 +1,806 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package types
import (
"fmt"
"reflect"
"strings"
chkdecls "github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// Kind indicates a CEL type's kind which is used to differentiate quickly between simple
// and complex types.
type Kind uint
const (
// UnspecifiedKind is returned when the type is nil or its kind is not specified.
UnspecifiedKind Kind = iota
// DynKind represents a dynamic type. This kind only exists at type-check time.
DynKind
// AnyKind represents a google.protobuf.Any type. This kind only exists at type-check time.
// Prefer DynKind to AnyKind as AnyKind has a specific meaning which is based on protobuf
// well-known types.
AnyKind
// BoolKind represents a boolean type.
BoolKind
// BytesKind represents a bytes type.
BytesKind
// DoubleKind represents a double type.
DoubleKind
// DurationKind represents a CEL duration type.
DurationKind
// ErrorKind represents a CEL error type.
ErrorKind
// IntKind represents an integer type.
IntKind
// ListKind represents a list type.
ListKind
// MapKind represents a map type.
MapKind
// NullTypeKind represents a null type.
NullTypeKind
// OpaqueKind represents an abstract type which has no accessible fields.
OpaqueKind
// StringKind represents a string type.
StringKind
// StructKind represents a structured object with typed fields.
StructKind
// TimestampKind represents a a CEL time type.
TimestampKind
// TypeKind represents the CEL type.
TypeKind
// TypeParamKind represents a parameterized type whose type name will be resolved at type-check time, if possible.
TypeParamKind
// UintKind represents a uint type.
UintKind
// UnknownKind represents an unknown value type.
UnknownKind
)
var (
// AnyType represents the google.protobuf.Any type.
AnyType = &Type{
kind: AnyKind,
runtimeTypeName: "google.protobuf.Any",
traitMask: traits.FieldTesterType |
traits.IndexerType,
}
// BoolType represents the bool type.
BoolType = &Type{
kind: BoolKind,
runtimeTypeName: "bool",
traitMask: traits.ComparerType |
traits.NegatorType,
}
// BytesType represents the bytes type.
BytesType = &Type{
kind: BytesKind,
runtimeTypeName: "bytes",
traitMask: traits.AdderType |
traits.ComparerType |
traits.SizerType,
}
// DoubleType represents the double type.
DoubleType = &Type{
kind: DoubleKind,
runtimeTypeName: "double",
traitMask: traits.AdderType |
traits.ComparerType |
traits.DividerType |
traits.MultiplierType |
traits.NegatorType |
traits.SubtractorType,
}
// DurationType represents the CEL duration type.
DurationType = &Type{
kind: DurationKind,
runtimeTypeName: "google.protobuf.Duration",
traitMask: traits.AdderType |
traits.ComparerType |
traits.NegatorType |
traits.ReceiverType |
traits.SubtractorType,
}
// DynType represents a dynamic CEL type whose type will be determined at runtime from context.
DynType = &Type{
kind: DynKind,
runtimeTypeName: "dyn",
}
// ErrorType represents a CEL error value.
ErrorType = &Type{
kind: ErrorKind,
runtimeTypeName: "error",
}
// IntType represents the int type.
IntType = &Type{
kind: IntKind,
runtimeTypeName: "int",
traitMask: traits.AdderType |
traits.ComparerType |
traits.DividerType |
traits.ModderType |
traits.MultiplierType |
traits.NegatorType |
traits.SubtractorType,
}
// ListType represents the runtime list type.
ListType = NewListType(nil)
// MapType represents the runtime map type.
MapType = NewMapType(nil, nil)
// NullType represents the type of a null value.
NullType = &Type{
kind: NullTypeKind,
runtimeTypeName: "null_type",
}
// StringType represents the string type.
StringType = &Type{
kind: StringKind,
runtimeTypeName: "string",
traitMask: traits.AdderType |
traits.ComparerType |
traits.MatcherType |
traits.ReceiverType |
traits.SizerType,
}
// TimestampType represents the time type.
TimestampType = &Type{
kind: TimestampKind,
runtimeTypeName: "google.protobuf.Timestamp",
traitMask: traits.AdderType |
traits.ComparerType |
traits.ReceiverType |
traits.SubtractorType,
}
// TypeType represents a CEL type
TypeType = &Type{
kind: TypeKind,
runtimeTypeName: "type",
}
// UintType represents a uint type.
UintType = &Type{
kind: UintKind,
runtimeTypeName: "uint",
traitMask: traits.AdderType |
traits.ComparerType |
traits.DividerType |
traits.ModderType |
traits.MultiplierType |
traits.SubtractorType,
}
// UnknownType represents an unknown value type.
UnknownType = &Type{
kind: UnknownKind,
runtimeTypeName: "unknown",
}
)
var _ ref.Type = &Type{}
var _ ref.Val = &Type{}
// Type holds a reference to a runtime type with an optional type-checked set of type parameters.
type Type struct {
// kind indicates general category of the type.
kind Kind
// parameters holds the optional type-checked set of type Parameters that are used during static analysis.
parameters []*Type
// runtimeTypeName indicates the runtime type name of the type.
runtimeTypeName string
// isAssignableType function determines whether one type is assignable to this type.
// A nil value for the isAssignableType function falls back to equality of kind, runtimeType, and parameters.
isAssignableType func(other *Type) bool
// isAssignableRuntimeType function determines whether the runtime type (with erasure) is assignable to this type.
// A nil value for the isAssignableRuntimeType function falls back to the equality of the type or type name.
isAssignableRuntimeType func(other ref.Val) bool
// traitMask is a mask of flags which indicate the capabilities of the type.
traitMask int
}
// ConvertToNative implements ref.Val.ConvertToNative.
func (t *Type) ConvertToNative(typeDesc reflect.Type) (any, error) {
return nil, fmt.Errorf("type conversion not supported for 'type'")
}
// ConvertToType implements ref.Val.ConvertToType.
func (t *Type) ConvertToType(typeVal ref.Type) ref.Val {
switch typeVal {
case TypeType:
return TypeType
case StringType:
return String(t.TypeName())
}
return NewErr("type conversion error from '%s' to '%s'", TypeType, typeVal)
}
// Equal indicates whether two types have the same runtime type name.
//
// The name Equal is a bit of a misnomer, but for historical reasons, this is the
// runtime behavior. For a more accurate definition see IsType().
func (t *Type) Equal(other ref.Val) ref.Val {
otherType, ok := other.(ref.Type)
return Bool(ok && t.TypeName() == otherType.TypeName())
}
// HasTrait implements the ref.Type interface method.
func (t *Type) HasTrait(trait int) bool {
return trait&t.traitMask == trait
}
// IsExactType indicates whether the two types are exactly the same. This check also verifies type parameter type names.
func (t *Type) IsExactType(other *Type) bool {
return t.isTypeInternal(other, true)
}
// IsEquivalentType indicates whether two types are equivalent. This check ignores type parameter type names.
func (t *Type) IsEquivalentType(other *Type) bool {
return t.isTypeInternal(other, false)
}
// Kind indicates general category of the type.
func (t *Type) Kind() Kind {
if t == nil {
return UnspecifiedKind
}
return t.kind
}
// isTypeInternal checks whether the two types are equivalent or exactly the same based on the checkTypeParamName flag.
func (t *Type) isTypeInternal(other *Type, checkTypeParamName bool) bool {
if t == nil {
return false
}
if t == other {
return true
}
if t.Kind() != other.Kind() || len(t.Parameters()) != len(other.Parameters()) {
return false
}
if (checkTypeParamName || t.Kind() != TypeParamKind) && t.TypeName() != other.TypeName() {
return false
}
for i, p := range t.Parameters() {
if !p.isTypeInternal(other.Parameters()[i], checkTypeParamName) {
return false
}
}
return true
}
// IsAssignableType determines whether the current type is type-check assignable from the input fromType.
func (t *Type) IsAssignableType(fromType *Type) bool {
if t == nil {
return false
}
if t.isAssignableType != nil {
return t.isAssignableType(fromType)
}
return t.defaultIsAssignableType(fromType)
}
// IsAssignableRuntimeType determines whether the current type is runtime assignable from the input runtimeType.
//
// At runtime, parameterized types are erased and so a function which type-checks to support a map(string, string)
// will have a runtime assignable type of a map.
func (t *Type) IsAssignableRuntimeType(val ref.Val) bool {
if t == nil {
return false
}
if t.isAssignableRuntimeType != nil {
return t.isAssignableRuntimeType(val)
}
return t.defaultIsAssignableRuntimeType(val)
}
// Parameters returns the list of type parameters if set.
//
// For ListKind, Parameters()[0] represents the list element type
// For MapKind, Parameters()[0] represents the map key type, and Parameters()[1] represents the map
// value type.
func (t *Type) Parameters() []*Type {
if t == nil {
return emptyParams
}
return t.parameters
}
// DeclaredTypeName indicates the fully qualified and parameterized type-check type name.
func (t *Type) DeclaredTypeName() string {
// if the type itself is neither null, nor dyn, but is assignable to null, then it's a wrapper type.
if t.Kind() != NullTypeKind && !t.isDyn() && t.IsAssignableType(NullType) {
return fmt.Sprintf("wrapper(%s)", t.TypeName())
}
return t.TypeName()
}
// Type implements the ref.Val interface method.
func (t *Type) Type() ref.Type {
return TypeType
}
// Value implements the ref.Val interface method.
func (t *Type) Value() any {
return t.TypeName()
}
// TypeName returns the type-erased fully qualified runtime type name.
//
// TypeName implements the ref.Type interface method.
func (t *Type) TypeName() string {
if t == nil {
return ""
}
return t.runtimeTypeName
}
// String returns a human-readable definition of the type name.
func (t *Type) String() string {
if len(t.Parameters()) == 0 {
return t.DeclaredTypeName()
}
params := make([]string, len(t.Parameters()))
for i, p := range t.Parameters() {
params[i] = p.String()
}
return fmt.Sprintf("%s(%s)", t.DeclaredTypeName(), strings.Join(params, ", "))
}
// isDyn indicates whether the type is dynamic in any way.
func (t *Type) isDyn() bool {
k := t.Kind()
return k == DynKind || k == AnyKind || k == TypeParamKind
}
// defaultIsAssignableType provides the standard definition of what it means for one type to be assignable to another
// where any of the following may return a true result:
// - The from types are the same instance
// - The target type is dynamic
// - The fromType has the same kind and type name as the target type, and all parameters of the target type
//
// are IsAssignableType() from the parameters of the fromType.
func (t *Type) defaultIsAssignableType(fromType *Type) bool {
if t == fromType || t.isDyn() {
return true
}
if t.Kind() != fromType.Kind() ||
t.TypeName() != fromType.TypeName() ||
len(t.Parameters()) != len(fromType.Parameters()) {
return false
}
for i, tp := range t.Parameters() {
fp := fromType.Parameters()[i]
if !tp.IsAssignableType(fp) {
return false
}
}
return true
}
// defaultIsAssignableRuntimeType inspects the type and in the case of list and map elements, the key and element types
// to determine whether a ref.Val is assignable to the declared type for a function signature.
func (t *Type) defaultIsAssignableRuntimeType(val ref.Val) bool {
valType := val.Type()
// If the current type and value type don't agree, then return
if !(t.isDyn() || t.TypeName() == valType.TypeName()) {
return false
}
switch t.Kind() {
case ListKind:
elemType := t.Parameters()[0]
l := val.(traits.Lister)
if l.Size() == IntZero {
return true
}
it := l.Iterator()
elemVal := it.Next()
return elemType.IsAssignableRuntimeType(elemVal)
case MapKind:
keyType := t.Parameters()[0]
elemType := t.Parameters()[1]
m := val.(traits.Mapper)
if m.Size() == IntZero {
return true
}
it := m.Iterator()
keyVal := it.Next()
elemVal := m.Get(keyVal)
return keyType.IsAssignableRuntimeType(keyVal) && elemType.IsAssignableRuntimeType(elemVal)
}
return true
}
// NewListType creates an instances of a list type value with the provided element type.
func NewListType(elemType *Type) *Type {
return &Type{
kind: ListKind,
parameters: []*Type{elemType},
runtimeTypeName: "list",
traitMask: traits.AdderType |
traits.ContainerType |
traits.IndexerType |
traits.IterableType |
traits.SizerType,
}
}
// NewMapType creates an instance of a map type value with the provided key and value types.
func NewMapType(keyType, valueType *Type) *Type {
return &Type{
kind: MapKind,
parameters: []*Type{keyType, valueType},
runtimeTypeName: "map",
traitMask: traits.ContainerType |
traits.IndexerType |
traits.IterableType |
traits.SizerType,
}
}
// NewNullableType creates an instance of a nullable type with the provided wrapped type.
//
// Note: only primitive types are supported as wrapped types.
func NewNullableType(wrapped *Type) *Type {
return &Type{
kind: wrapped.Kind(),
parameters: wrapped.Parameters(),
runtimeTypeName: wrapped.TypeName(),
traitMask: wrapped.traitMask,
isAssignableType: func(other *Type) bool {
return NullType.IsAssignableType(other) || wrapped.IsAssignableType(other)
},
isAssignableRuntimeType: func(other ref.Val) bool {
return NullType.IsAssignableRuntimeType(other) || wrapped.IsAssignableRuntimeType(other)
},
}
}
// NewOptionalType creates an abstract parameterized type instance corresponding to CEL's notion of optional.
func NewOptionalType(param *Type) *Type {
return NewOpaqueType("optional", param)
}
// NewOpaqueType creates an abstract parameterized type with a given name.
func NewOpaqueType(name string, params ...*Type) *Type {
return &Type{
kind: OpaqueKind,
parameters: params,
runtimeTypeName: name,
}
}
// NewObjectType creates a type reference to an externally defined type, e.g. a protobuf message type.
//
// An object type is assumed to support field presence testing and field indexing. Additionally, the
// type may also indicate additional traits through the use of the optional traits vararg argument.
func NewObjectType(typeName string, traits ...int) *Type {
// Function sanitizes object types on the fly
if wkt, found := checkedWellKnowns[typeName]; found {
return wkt
}
traitMask := 0
for _, trait := range traits {
traitMask |= trait
}
return &Type{
kind: StructKind,
parameters: emptyParams,
runtimeTypeName: typeName,
traitMask: structTypeTraitMask | traitMask,
}
}
// NewObjectTypeValue creates a type reference to an externally defined type.
//
// Deprecated: use cel.ObjectType(typeName)
func NewObjectTypeValue(typeName string) *Type {
return NewObjectType(typeName)
}
// NewTypeValue creates an opaque type which has a set of optional type traits as defined in
// the common/types/traits package.
//
// Deprecated: use cel.ObjectType(typeName, traits)
func NewTypeValue(typeName string, traits ...int) *Type {
traitMask := 0
for _, trait := range traits {
traitMask |= trait
}
return &Type{
kind: StructKind,
parameters: emptyParams,
runtimeTypeName: typeName,
traitMask: traitMask,
}
}
// NewTypeParamType creates a parameterized type instance.
func NewTypeParamType(paramName string) *Type {
return &Type{
kind: TypeParamKind,
runtimeTypeName: paramName,
}
}
// NewTypeTypeWithParam creates a type with a type parameter.
// Used for type-checking purposes, but equivalent to TypeType otherwise.
func NewTypeTypeWithParam(param *Type) *Type {
return &Type{
kind: TypeKind,
runtimeTypeName: "type",
parameters: []*Type{param},
}
}
// TypeToExprType converts a CEL-native type representation to a protobuf CEL Type representation.
func TypeToExprType(t *Type) (*exprpb.Type, error) {
switch t.Kind() {
case AnyKind:
return chkdecls.Any, nil
case BoolKind:
return maybeWrapper(t, chkdecls.Bool), nil
case BytesKind:
return maybeWrapper(t, chkdecls.Bytes), nil
case DoubleKind:
return maybeWrapper(t, chkdecls.Double), nil
case DurationKind:
return chkdecls.Duration, nil
case DynKind:
return chkdecls.Dyn, nil
case ErrorKind:
return chkdecls.Error, nil
case IntKind:
return maybeWrapper(t, chkdecls.Int), nil
case ListKind:
if len(t.Parameters()) != 1 {
return nil, fmt.Errorf("invalid list, got %d parameters, wanted one", len(t.Parameters()))
}
et, err := TypeToExprType(t.Parameters()[0])
if err != nil {
return nil, err
}
return chkdecls.NewListType(et), nil
case MapKind:
if len(t.Parameters()) != 2 {
return nil, fmt.Errorf("invalid map, got %d parameters, wanted two", len(t.Parameters()))
}
kt, err := TypeToExprType(t.Parameters()[0])
if err != nil {
return nil, err
}
vt, err := TypeToExprType(t.Parameters()[1])
if err != nil {
return nil, err
}
return chkdecls.NewMapType(kt, vt), nil
case NullTypeKind:
return chkdecls.Null, nil
case OpaqueKind:
params := make([]*exprpb.Type, len(t.Parameters()))
for i, p := range t.Parameters() {
pt, err := TypeToExprType(p)
if err != nil {
return nil, err
}
params[i] = pt
}
return chkdecls.NewAbstractType(t.TypeName(), params...), nil
case StringKind:
return maybeWrapper(t, chkdecls.String), nil
case StructKind:
return chkdecls.NewObjectType(t.TypeName()), nil
case TimestampKind:
return chkdecls.Timestamp, nil
case TypeParamKind:
return chkdecls.NewTypeParamType(t.TypeName()), nil
case TypeKind:
if len(t.Parameters()) == 1 {
p, err := TypeToExprType(t.Parameters()[0])
if err != nil {
return nil, err
}
return chkdecls.NewTypeType(p), nil
}
return chkdecls.NewTypeType(nil), nil
case UintKind:
return maybeWrapper(t, chkdecls.Uint), nil
}
return nil, fmt.Errorf("missing type conversion to proto: %v", t)
}
// ExprTypeToType converts a protobuf CEL type representation to a CEL-native type representation.
func ExprTypeToType(t *exprpb.Type) (*Type, error) {
switch t.GetTypeKind().(type) {
case *exprpb.Type_Dyn:
return DynType, nil
case *exprpb.Type_AbstractType_:
paramTypes := make([]*Type, len(t.GetAbstractType().GetParameterTypes()))
for i, p := range t.GetAbstractType().GetParameterTypes() {
pt, err := ExprTypeToType(p)
if err != nil {
return nil, err
}
paramTypes[i] = pt
}
return NewOpaqueType(t.GetAbstractType().GetName(), paramTypes...), nil
case *exprpb.Type_ListType_:
et, err := ExprTypeToType(t.GetListType().GetElemType())
if err != nil {
return nil, err
}
return NewListType(et), nil
case *exprpb.Type_MapType_:
kt, err := ExprTypeToType(t.GetMapType().GetKeyType())
if err != nil {
return nil, err
}
vt, err := ExprTypeToType(t.GetMapType().GetValueType())
if err != nil {
return nil, err
}
return NewMapType(kt, vt), nil
case *exprpb.Type_MessageType:
return NewObjectType(t.GetMessageType()), nil
case *exprpb.Type_Null:
return NullType, nil
case *exprpb.Type_Primitive:
switch t.GetPrimitive() {
case exprpb.Type_BOOL:
return BoolType, nil
case exprpb.Type_BYTES:
return BytesType, nil
case exprpb.Type_DOUBLE:
return DoubleType, nil
case exprpb.Type_INT64:
return IntType, nil
case exprpb.Type_STRING:
return StringType, nil
case exprpb.Type_UINT64:
return UintType, nil
default:
return nil, fmt.Errorf("unsupported primitive type: %v", t)
}
case *exprpb.Type_TypeParam:
return NewTypeParamType(t.GetTypeParam()), nil
case *exprpb.Type_Type:
if t.GetType().GetTypeKind() != nil {
p, err := ExprTypeToType(t.GetType())
if err != nil {
return nil, err
}
return NewTypeTypeWithParam(p), nil
}
return TypeType, nil
case *exprpb.Type_WellKnown:
switch t.GetWellKnown() {
case exprpb.Type_ANY:
return AnyType, nil
case exprpb.Type_DURATION:
return DurationType, nil
case exprpb.Type_TIMESTAMP:
return TimestampType, nil
default:
return nil, fmt.Errorf("unsupported well-known type: %v", t)
}
case *exprpb.Type_Wrapper:
t, err := ExprTypeToType(&exprpb.Type{TypeKind: &exprpb.Type_Primitive{Primitive: t.GetWrapper()}})
if err != nil {
return nil, err
}
return NewNullableType(t), nil
case *exprpb.Type_Error:
return ErrorType, nil
default:
return nil, fmt.Errorf("unsupported type: %v", t)
}
}
func maybeWrapper(t *Type, pbType *exprpb.Type) *exprpb.Type {
if t.IsAssignableType(NullType) {
return chkdecls.NewWrapperType(pbType)
}
return pbType
}
func maybeForeignType(t ref.Type) *Type {
if celType, ok := t.(*Type); ok {
return celType
}
// Inspect the incoming type to determine its traits. The assumption will be that the incoming
// type does not have any field values; however, if the trait mask indicates that field testing
// and indexing are supported, the foreign type is marked as a struct.
traitMask := 0
for _, trait := range allTraits {
if t.HasTrait(trait) {
traitMask |= trait
}
}
// Treat the value like a struct. If it has no fields, this is harmless to denote the type
// as such since it basically becomes an opaque type by convention.
return NewObjectType(t.TypeName(), traitMask)
}
var (
checkedWellKnowns = map[string]*Type{
// Wrapper types.
"google.protobuf.BoolValue": NewNullableType(BoolType),
"google.protobuf.BytesValue": NewNullableType(BytesType),
"google.protobuf.DoubleValue": NewNullableType(DoubleType),
"google.protobuf.FloatValue": NewNullableType(DoubleType),
"google.protobuf.Int64Value": NewNullableType(IntType),
"google.protobuf.Int32Value": NewNullableType(IntType),
"google.protobuf.UInt64Value": NewNullableType(UintType),
"google.protobuf.UInt32Value": NewNullableType(UintType),
"google.protobuf.StringValue": NewNullableType(StringType),
// Well-known types.
"google.protobuf.Any": AnyType,
"google.protobuf.Duration": DurationType,
"google.protobuf.Timestamp": TimestampType,
// Json types.
"google.protobuf.ListValue": NewListType(DynType),
"google.protobuf.NullValue": NullType,
"google.protobuf.Struct": NewMapType(StringType, DynType),
"google.protobuf.Value": DynType,
}
emptyParams = []*Type{}
allTraits = []int{
traits.AdderType,
traits.ComparerType,
traits.ContainerType,
traits.DividerType,
traits.FieldTesterType,
traits.IndexerType,
traits.IterableType,
traits.IteratorType,
traits.MatcherType,
traits.ModderType,
traits.MultiplierType,
traits.NegatorType,
traits.ReceiverType,
traits.SizerType,
traits.SubtractorType,
}
structTypeTraitMask = traits.FieldTesterType | traits.IndexerType
)

View File

@ -21,7 +21,6 @@ import (
"strconv"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb"
@ -32,15 +31,6 @@ import (
type Uint uint64
var (
// UintType singleton.
UintType = NewTypeValue("uint",
traits.AdderType,
traits.ComparerType,
traits.DividerType,
traits.ModderType,
traits.MultiplierType,
traits.SubtractorType)
uint32WrapperType = reflect.TypeOf(&wrapperspb.UInt32Value{})
uint64WrapperType = reflect.TypeOf(&wrapperspb.UInt64Value{})

View File

@ -15,52 +15,312 @@
package types
import (
"fmt"
"math"
"reflect"
"sort"
"strings"
"unicode"
"github.com/google/cel-go/common/types/ref"
)
// Unknown type implementation which collects expression ids which caused the
// current value to become unknown.
type Unknown []int64
var (
// UnknownType singleton.
UnknownType = NewTypeValue("unknown")
unspecifiedAttribute = &AttributeTrail{qualifierPath: []any{}}
)
// NewAttributeTrail creates a new simple attribute from a variable name.
func NewAttributeTrail(variable string) *AttributeTrail {
if variable == "" {
return unspecifiedAttribute
}
return &AttributeTrail{variable: variable}
}
// AttributeTrail specifies a variable with an optional qualifier path. An attribute value is expected to
// correspond to an AbsoluteAttribute, meaning a field selection which starts with a top-level variable.
//
// The qualifer path elements adhere to the AttributeQualifier type constraint.
type AttributeTrail struct {
variable string
qualifierPath []any
}
// Equal returns whether two attribute values have the same variable name and qualifier paths.
func (a *AttributeTrail) Equal(other *AttributeTrail) bool {
if a.Variable() != other.Variable() || len(a.QualifierPath()) != len(other.QualifierPath()) {
return false
}
for i, q := range a.QualifierPath() {
qual := other.QualifierPath()[i]
if !qualifiersEqual(q, qual) {
return false
}
}
return true
}
func qualifiersEqual(a, b any) bool {
if a == b {
return true
}
switch numA := a.(type) {
case int64:
numB, ok := b.(uint64)
if !ok {
return false
}
return intUintEqual(numA, numB)
case uint64:
numB, ok := b.(int64)
if !ok {
return false
}
return intUintEqual(numB, numA)
default:
return false
}
}
func intUintEqual(i int64, u uint64) bool {
if i < 0 || u > math.MaxInt64 {
return false
}
return i == int64(u)
}
// Variable returns the variable name associated with the attribute.
func (a *AttributeTrail) Variable() string {
return a.variable
}
// QualifierPath returns the optional set of qualifying fields or indices applied to the variable.
func (a *AttributeTrail) QualifierPath() []any {
return a.qualifierPath
}
// String returns the string representation of the Attribute.
func (a *AttributeTrail) String() string {
if a.variable == "" {
return "<unspecified>"
}
var str strings.Builder
str.WriteString(a.variable)
for _, q := range a.qualifierPath {
switch q := q.(type) {
case bool, int64:
str.WriteString(fmt.Sprintf("[%v]", q))
case uint64:
str.WriteString(fmt.Sprintf("[%vu]", q))
case string:
if isIdentifierCharacter(q) {
str.WriteString(fmt.Sprintf(".%v", q))
} else {
str.WriteString(fmt.Sprintf("[%q]", q))
}
}
}
return str.String()
}
func isIdentifierCharacter(str string) bool {
for _, c := range str {
if unicode.IsLetter(c) || unicode.IsDigit(c) || string(c) == "_" {
continue
}
return false
}
return true
}
// AttributeQualifier constrains the possible types which may be used to qualify an attribute.
type AttributeQualifier interface {
bool | int64 | uint64 | string
}
// QualifyAttribute qualifies an attribute using a valid AttributeQualifier type.
func QualifyAttribute[T AttributeQualifier](attr *AttributeTrail, qualifier T) *AttributeTrail {
attr.qualifierPath = append(attr.qualifierPath, qualifier)
return attr
}
// Unknown type which collects expression ids which caused the current value to become unknown.
type Unknown struct {
attributeTrails map[int64][]*AttributeTrail
}
// NewUnknown creates a new unknown at a given expression id for an attribute.
//
// If the attribute is nil, the attribute value will be the `unspecifiedAttribute`.
func NewUnknown(id int64, attr *AttributeTrail) *Unknown {
if attr == nil {
attr = unspecifiedAttribute
}
return &Unknown{
attributeTrails: map[int64][]*AttributeTrail{id: {attr}},
}
}
// IDs returns the set of unknown expression ids contained by this value.
//
// Numeric identifiers are guaranteed to be in sorted order.
func (u *Unknown) IDs() []int64 {
ids := make(int64Slice, len(u.attributeTrails))
i := 0
for id := range u.attributeTrails {
ids[i] = id
i++
}
ids.Sort()
return ids
}
// GetAttributeTrails returns the attribute trails, if present, missing for a given expression id.
func (u *Unknown) GetAttributeTrails(id int64) ([]*AttributeTrail, bool) {
trails, found := u.attributeTrails[id]
return trails, found
}
// Contains returns true if the input unknown is a subset of the current unknown.
func (u *Unknown) Contains(other *Unknown) bool {
for id, otherTrails := range other.attributeTrails {
trails, found := u.attributeTrails[id]
if !found || len(otherTrails) != len(trails) {
return false
}
for _, ot := range otherTrails {
found := false
for _, t := range trails {
if t.Equal(ot) {
found = true
break
}
}
if !found {
return false
}
}
}
return true
}
// ConvertToNative implements ref.Val.ConvertToNative.
func (u Unknown) ConvertToNative(typeDesc reflect.Type) (any, error) {
func (u *Unknown) ConvertToNative(typeDesc reflect.Type) (any, error) {
return u.Value(), nil
}
// ConvertToType is an identity function since unknown values cannot be modified.
func (u Unknown) ConvertToType(typeVal ref.Type) ref.Val {
func (u *Unknown) ConvertToType(typeVal ref.Type) ref.Val {
return u
}
// Equal is an identity function since unknown values cannot be modified.
func (u Unknown) Equal(other ref.Val) ref.Val {
func (u *Unknown) Equal(other ref.Val) ref.Val {
return u
}
// String implements the Stringer interface
func (u *Unknown) String() string {
var str strings.Builder
for id, attrs := range u.attributeTrails {
if str.Len() != 0 {
str.WriteString(", ")
}
if len(attrs) == 1 {
str.WriteString(fmt.Sprintf("%v (%d)", attrs[0], id))
} else {
str.WriteString(fmt.Sprintf("%v (%d)", attrs, id))
}
}
return str.String()
}
// Type implements ref.Val.Type.
func (u Unknown) Type() ref.Type {
func (u *Unknown) Type() ref.Type {
return UnknownType
}
// Value implements ref.Val.Value.
func (u Unknown) Value() any {
return []int64(u)
func (u *Unknown) Value() any {
return u
}
// IsUnknown returns whether the element ref.Type or ref.Val is equal to the
// UnknownType singleton.
// IsUnknown returns whether the element ref.Val is in instance of *types.Unknown
func IsUnknown(val ref.Val) bool {
switch val.(type) {
case Unknown:
case *Unknown:
return true
default:
return false
}
}
// MaybeMergeUnknowns determines whether an input value and another, possibly nil, unknown will produce
// an unknown result.
//
// If the input `val` is another Unknown, then the result will be the merge of the `val` and the input
// `unk`. If the `val` is not unknown, then the result will depend on whether the input `unk` is nil.
// If both values are non-nil and unknown, then the return value will be a merge of both unknowns.
func MaybeMergeUnknowns(val ref.Val, unk *Unknown) (*Unknown, bool) {
src, isUnk := val.(*Unknown)
if !isUnk {
if unk != nil {
return unk, true
}
return unk, false
}
return MergeUnknowns(src, unk), true
}
// MergeUnknowns combines two unknown values into a new unknown value.
func MergeUnknowns(unk1, unk2 *Unknown) *Unknown {
if unk1 == nil {
return unk2
}
if unk2 == nil {
return unk1
}
out := &Unknown{
attributeTrails: make(map[int64][]*AttributeTrail, len(unk1.attributeTrails)+len(unk2.attributeTrails)),
}
for id, ats := range unk1.attributeTrails {
out.attributeTrails[id] = ats
}
for id, ats := range unk2.attributeTrails {
existing, found := out.attributeTrails[id]
if !found {
out.attributeTrails[id] = ats
continue
}
for _, at := range ats {
found := false
for _, et := range existing {
if at.Equal(et) {
found = true
break
}
}
if !found {
existing = append(existing, at)
}
}
out.attributeTrails[id] = existing
}
return out
}
// int64Slice is an implementation of the sort.Interface
type int64Slice []int64
// Len returns the number of elements in the slice.
func (x int64Slice) Len() int { return len(x) }
// Less indicates whether the value at index i is less than the value at index j.
func (x int64Slice) Less(i, j int) bool { return x[i] < x[j] }
// Swap swaps the values at indices i and j in place.
func (x int64Slice) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
// Sort is a convenience method: x.Sort() calls Sort(x).
func (x int64Slice) Sort() { sort.Sort(x) }

View File

@ -21,7 +21,7 @@ import (
// IsUnknownOrError returns whether the input element ref.Val is an ErrType or UnknownType.
func IsUnknownOrError(val ref.Val) bool {
switch val.(type) {
case Unknown, *Err:
case *Unknown, *Err:
return true
}
return false

View File

@ -9,6 +9,7 @@ go_library(
srcs = [
"encoders.go",
"guards.go",
"lists.go",
"math.go",
"native.go",
"protos.go",
@ -19,8 +20,8 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//cel:go_default_library",
"//checker:go_default_library",
"//checker/decls:go_default_library",
"//common:go_default_library",
"//common/overloads:go_default_library",
"//common/types:go_default_library",
"//common/types/pb:go_default_library",
@ -41,6 +42,7 @@ go_test(
size = "small",
srcs = [
"encoders_test.go",
"lists_test.go",
"math_test.go",
"native_test.go",
"protos_test.go",
@ -53,7 +55,6 @@ go_test(
deps = [
"//cel:go_default_library",
"//checker:go_default_library",
"//common:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",

View File

@ -149,6 +149,23 @@ Example:
proto.hasExt(msg, google.expr.proto2.test.int32_ext) // returns true || false
## Lists
Extended functions for list manipulation. As a general note, all indices are
zero-based.
### Slice
Returns a new sub-list using the indexes provided.
<list>.slice(<int>, <int>) -> <list>
Examples:
[1,2,3,4].slice(1, 3) // return [2, 3]
[1,2,3,4].slice(2, 4) // return [3 ,4]
## Sets
Sets provides set relationship tests.

View File

@ -16,7 +16,6 @@ package ext
import (
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
@ -71,7 +70,7 @@ func (celBindings) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
func celBind(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
func celBind(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) {
if !macroTargetMatchesNamespace(celNamespace, target) {
return nil, nil
}
@ -81,10 +80,7 @@ func celBind(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr)
case *exprpb.Expr_IdentExpr:
varName = varIdent.GetIdentExpr().GetName()
default:
return nil, &common.Error{
Message: "cel.bind() variable names must be simple identifers",
Location: meh.OffsetLocation(varIdent.GetId()),
}
return nil, meh.NewError(varIdent.GetId(), "cel.bind() variable names must be simple identifiers")
}
varInit := args[1]
resultExpr := args[2]

View File

@ -16,7 +16,6 @@ package ext
import (
"encoding/base64"
"reflect"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
@ -86,7 +85,3 @@ func base64DecodeString(str string) ([]byte, error) {
func base64EncodeBytes(bytes []byte) (string, error) {
return base64.StdEncoding.EncodeToString(bytes), nil
}
var (
bytesListType = reflect.TypeOf([]byte{})
)

View File

@ -17,6 +17,7 @@ package ext
import (
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)

94
vendor/github.com/google/cel-go/ext/lists.go generated vendored Normal file
View File

@ -0,0 +1,94 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ext
import (
"fmt"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
)
// Lists returns a cel.EnvOption to configure extended functions for list manipulation.
// As a general note, all indices are zero-based.
// # Slice
//
// Returns a new sub-list using the indexes provided.
//
// <list>.slice(<int>, <int>) -> <list>
//
// Examples:
//
// [1,2,3,4].slice(1, 3) // return [2, 3]
// [1,2,3,4].slice(2, 4) // return [3 ,4]
func Lists() cel.EnvOption {
return cel.Lib(listsLib{})
}
type listsLib struct{}
// LibraryName implements the SingletonLibrary interface method.
func (listsLib) LibraryName() string {
return "cel.lib.ext.lists"
}
// CompileOptions implements the Library interface method.
func (listsLib) CompileOptions() []cel.EnvOption {
listType := cel.ListType(cel.TypeParamType("T"))
return []cel.EnvOption{
cel.Function("slice",
cel.MemberOverload("list_slice",
[]*cel.Type{listType, cel.IntType, cel.IntType}, listType,
cel.FunctionBinding(func(args ...ref.Val) ref.Val {
list := args[0].(traits.Lister)
start := args[1].(types.Int)
end := args[2].(types.Int)
result, err := slice(list, start, end)
if err != nil {
return types.WrapErr(err)
}
return result
}),
),
),
}
}
// ProgramOptions implements the Library interface method.
func (listsLib) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
func slice(list traits.Lister, start, end types.Int) (ref.Val, error) {
listLength := list.Size().(types.Int)
if start < 0 || end < 0 {
return nil, fmt.Errorf("cannot slice(%d, %d), negative indexes not supported", start, end)
}
if start > end {
return nil, fmt.Errorf("cannot slice(%d, %d), start index must be less than or equal to end index", start, end)
}
if listLength < end {
return nil, fmt.Errorf("cannot slice(%d, %d), list is length %d", start, end, listLength)
}
var newList []ref.Val
for i := types.Int(start); i < end; i++ {
val := list.Get(i)
newList = append(newList, val)
}
return types.DefaultTypeAdapter.NativeToValue(newList), nil
}

View File

@ -19,10 +19,10 @@ import (
"strings"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
@ -187,24 +187,18 @@ func (mathLib) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
func mathLeast(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
func mathLeast(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) {
if !macroTargetMatchesNamespace(mathNamespace, target) {
return nil, nil
}
switch len(args) {
case 0:
return nil, &common.Error{
Message: "math.least() requires at least one argument",
Location: meh.OffsetLocation(target.GetId()),
}
return nil, meh.NewError(target.GetId(), "math.least() requires at least one argument")
case 1:
if isListLiteralWithValidArgs(args[0]) || isValidArgType(args[0]) {
return meh.GlobalCall(minFunc, args[0]), nil
}
return nil, &common.Error{
Message: "math.least() invalid single argument value",
Location: meh.OffsetLocation(args[0].GetId()),
}
return nil, meh.NewError(args[0].GetId(), "math.least() invalid single argument value")
case 2:
err := checkInvalidArgs(meh, "math.least()", args)
if err != nil {
@ -220,24 +214,18 @@ func mathLeast(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr
}
}
func mathGreatest(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
func mathGreatest(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) {
if !macroTargetMatchesNamespace(mathNamespace, target) {
return nil, nil
}
switch len(args) {
case 0:
return nil, &common.Error{
Message: "math.greatest() requires at least one argument",
Location: meh.OffsetLocation(target.GetId()),
}
return nil, meh.NewError(target.GetId(), "math.greatest() requires at least one argument")
case 1:
if isListLiteralWithValidArgs(args[0]) || isValidArgType(args[0]) {
return meh.GlobalCall(maxFunc, args[0]), nil
}
return nil, &common.Error{
Message: "math.greatest() invalid single argument value",
Location: meh.OffsetLocation(args[0].GetId()),
}
return nil, meh.NewError(args[0].GetId(), "math.greatest() invalid single argument value")
case 2:
err := checkInvalidArgs(meh, "math.greatest()", args)
if err != nil {
@ -323,14 +311,11 @@ func maxList(numList ref.Val) ref.Val {
}
}
func checkInvalidArgs(meh cel.MacroExprHelper, funcName string, args []*exprpb.Expr) *common.Error {
func checkInvalidArgs(meh cel.MacroExprHelper, funcName string, args []*exprpb.Expr) *cel.Error {
for _, arg := range args {
err := checkInvalidArgLiteral(funcName, arg)
if err != nil {
return &common.Error{
Message: err.Error(),
Location: meh.OffsetLocation(arg.GetId()),
}
return meh.NewError(arg.GetId(), err.Error())
}
}
return nil

View File

@ -24,13 +24,11 @@ import (
"google.golang.org/protobuf/reflect/protoreflect"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/pb"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
structpb "google.golang.org/protobuf/types/known/structpb"
)
@ -81,7 +79,7 @@ var (
// the time that it is invoked.
func NativeTypes(refTypes ...any) cel.EnvOption {
return func(env *cel.Env) (*cel.Env, error) {
tp, err := newNativeTypeProvider(env.TypeAdapter(), env.TypeProvider(), refTypes...)
tp, err := newNativeTypeProvider(env.CELTypeAdapter(), env.CELTypeProvider(), refTypes...)
if err != nil {
return nil, err
}
@ -93,7 +91,7 @@ func NativeTypes(refTypes ...any) cel.EnvOption {
}
}
func newNativeTypeProvider(adapter ref.TypeAdapter, provider ref.TypeProvider, refTypes ...any) (*nativeTypeProvider, error) {
func newNativeTypeProvider(adapter types.Adapter, provider types.Provider, refTypes ...any) (*nativeTypeProvider, error) {
nativeTypes := make(map[string]*nativeType, len(refTypes))
for _, refType := range refTypes {
switch rt := refType.(type) {
@ -122,18 +120,18 @@ func newNativeTypeProvider(adapter ref.TypeAdapter, provider ref.TypeProvider, r
type nativeTypeProvider struct {
nativeTypes map[string]*nativeType
baseAdapter ref.TypeAdapter
baseProvider ref.TypeProvider
baseAdapter types.Adapter
baseProvider types.Provider
}
// EnumValue proxies to the ref.TypeProvider configured at the times the NativeTypes
// EnumValue proxies to the types.Provider configured at the times the NativeTypes
// option was configured.
func (tp *nativeTypeProvider) EnumValue(enumName string) ref.Val {
return tp.baseProvider.EnumValue(enumName)
}
// FindIdent looks up natives type instances by qualified identifier, and if not found
// proxies to the composed ref.TypeProvider.
// proxies to the composed types.Provider.
func (tp *nativeTypeProvider) FindIdent(typeName string) (ref.Val, bool) {
if t, found := tp.nativeTypes[typeName]; found {
return t, true
@ -141,32 +139,35 @@ func (tp *nativeTypeProvider) FindIdent(typeName string) (ref.Val, bool) {
return tp.baseProvider.FindIdent(typeName)
}
// FindType looks up CEL type-checker type definition by qualified identifier, and if not found
// proxies to the composed ref.TypeProvider.
func (tp *nativeTypeProvider) FindType(typeName string) (*exprpb.Type, bool) {
// FindStructType looks up the CEL type definition by qualified identifier, and if not found
// proxies to the composed types.Provider.
func (tp *nativeTypeProvider) FindStructType(typeName string) (*types.Type, bool) {
if _, found := tp.nativeTypes[typeName]; found {
return decls.NewTypeType(decls.NewObjectType(typeName)), true
return types.NewTypeTypeWithParam(types.NewObjectType(typeName)), true
}
return tp.baseProvider.FindType(typeName)
if celType, found := tp.baseProvider.FindStructType(typeName); found {
return celType, true
}
return tp.baseProvider.FindStructType(typeName)
}
// FindFieldType looks up a native type's field definition, and if the type name is not a native
// type then proxies to the composed ref.TypeProvider
func (tp *nativeTypeProvider) FindFieldType(typeName, fieldName string) (*ref.FieldType, bool) {
// FindStructFieldType looks up a native type's field definition, and if the type name is not a native
// type then proxies to the composed types.Provider
func (tp *nativeTypeProvider) FindStructFieldType(typeName, fieldName string) (*types.FieldType, bool) {
t, found := tp.nativeTypes[typeName]
if !found {
return tp.baseProvider.FindFieldType(typeName, fieldName)
return tp.baseProvider.FindStructFieldType(typeName, fieldName)
}
refField, isDefined := t.hasField(fieldName)
if !found || !isDefined {
return nil, false
}
exprType, ok := convertToExprType(refField.Type)
celType, ok := convertToCelType(refField.Type)
if !ok {
return nil, false
}
return &ref.FieldType{
Type: exprType,
return &types.FieldType{
Type: celType,
IsSet: func(obj any) bool {
refVal := reflect.Indirect(reflect.ValueOf(obj))
refField := refVal.FieldByName(fieldName)
@ -243,75 +244,74 @@ func (tp *nativeTypeProvider) NativeToValue(val any) ref.Val {
}
}
// convertToExprType converts the Golang reflect.Type to a protobuf exprpb.Type.
func convertToExprType(refType reflect.Type) (*exprpb.Type, bool) {
func convertToCelType(refType reflect.Type) (*cel.Type, bool) {
switch refType.Kind() {
case reflect.Bool:
return decls.Bool, true
return cel.BoolType, true
case reflect.Float32, reflect.Float64:
return decls.Double, true
return cel.DoubleType, true
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if refType == durationType {
return decls.Duration, true
return cel.DurationType, true
}
return decls.Int, true
return cel.IntType, true
case reflect.String:
return decls.String, true
return cel.StringType, true
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return decls.Uint, true
return cel.UintType, true
case reflect.Array, reflect.Slice:
refElem := refType.Elem()
if refElem == reflect.TypeOf(byte(0)) {
return decls.Bytes, true
return cel.BytesType, true
}
elemType, ok := convertToExprType(refElem)
elemType, ok := convertToCelType(refElem)
if !ok {
return nil, false
}
return decls.NewListType(elemType), true
return cel.ListType(elemType), true
case reflect.Map:
keyType, ok := convertToExprType(refType.Key())
keyType, ok := convertToCelType(refType.Key())
if !ok {
return nil, false
}
// Ensure the key type is a int, bool, uint, string
elemType, ok := convertToExprType(refType.Elem())
elemType, ok := convertToCelType(refType.Elem())
if !ok {
return nil, false
}
return decls.NewMapType(keyType, elemType), true
return cel.MapType(keyType, elemType), true
case reflect.Struct:
if refType == timestampType {
return decls.Timestamp, true
return cel.TimestampType, true
}
return decls.NewObjectType(
return cel.ObjectType(
fmt.Sprintf("%s.%s", simplePkgAlias(refType.PkgPath()), refType.Name()),
), true
case reflect.Pointer:
if refType.Implements(pbMsgInterfaceType) {
pbMsg := reflect.New(refType.Elem()).Interface().(protoreflect.ProtoMessage)
return decls.NewObjectType(string(pbMsg.ProtoReflect().Descriptor().FullName())), true
return cel.ObjectType(string(pbMsg.ProtoReflect().Descriptor().FullName())), true
}
return convertToExprType(refType.Elem())
return convertToCelType(refType.Elem())
}
return nil, false
}
func newNativeObject(adapter ref.TypeAdapter, val any, refValue reflect.Value) ref.Val {
func newNativeObject(adapter types.Adapter, val any, refValue reflect.Value) ref.Val {
valType, err := newNativeType(refValue.Type())
if err != nil {
return types.NewErr(err.Error())
}
return &nativeObj{
TypeAdapter: adapter,
val: val,
valType: valType,
refValue: refValue,
Adapter: adapter,
val: val,
valType: valType,
refValue: refValue,
}
}
type nativeObj struct {
ref.TypeAdapter
types.Adapter
val any
valType *nativeType
refValue reflect.Value
@ -520,11 +520,11 @@ func (t *nativeType) hasField(fieldName string) (reflect.StructField, bool) {
return f, true
}
func adaptFieldValue(adapter ref.TypeAdapter, refField reflect.Value) ref.Val {
func adaptFieldValue(adapter types.Adapter, refField reflect.Value) ref.Val {
return adapter.NativeToValue(getFieldValue(adapter, refField))
}
func getFieldValue(adapter ref.TypeAdapter, refField reflect.Value) any {
func getFieldValue(adapter types.Adapter, refField reflect.Value) any {
if refField.IsZero() {
switch refField.Kind() {
case reflect.Array, reflect.Slice:

View File

@ -16,7 +16,6 @@ package ext
import (
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
@ -86,7 +85,7 @@ func (protoLib) ProgramOptions() []cel.ProgramOption {
}
// hasProtoExt generates a test-only select expression for a fully-qualified extension name on a protobuf message.
func hasProtoExt(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
func hasProtoExt(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) {
if !macroTargetMatchesNamespace(protoNamespace, target) {
return nil, nil
}
@ -98,7 +97,7 @@ func hasProtoExt(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Ex
}
// getProtoExt generates a select expression for a fully-qualified extension name on a protobuf message.
func getProtoExt(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
func getProtoExt(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) {
if !macroTargetMatchesNamespace(protoNamespace, target) {
return nil, nil
}
@ -109,7 +108,7 @@ func getProtoExt(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Ex
return meh.Select(args[0], extFieldName), nil
}
func getExtFieldName(meh cel.MacroExprHelper, expr *exprpb.Expr) (string, *common.Error) {
func getExtFieldName(meh cel.MacroExprHelper, expr *exprpb.Expr) (string, *cel.Error) {
isValid := false
extensionField := ""
switch expr.GetExprKind().(type) {
@ -117,10 +116,7 @@ func getExtFieldName(meh cel.MacroExprHelper, expr *exprpb.Expr) (string, *commo
extensionField, isValid = validateIdentifier(expr)
}
if !isValid {
return "", &common.Error{
Message: "invalid extension field",
Location: meh.OffsetLocation(expr.GetId()),
}
return "", meh.NewError(expr.GetId(), "invalid extension field")
}
return extensionField, nil
}

View File

@ -15,10 +15,14 @@
package ext
import (
"math"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/interpreter"
)
// Sets returns a cel.EnvOption to configure namespaced set relationship
@ -95,12 +99,24 @@ func (setsLib) CompileOptions() []cel.EnvOption {
cel.Function("sets.intersects",
cel.Overload("list_sets_intersects_list", []*cel.Type{listType, listType}, cel.BoolType,
cel.BinaryBinding(setsIntersects))),
cel.CostEstimatorOptions(
checker.OverloadCostEstimate("list_sets_contains_list", estimateSetsCost(1)),
checker.OverloadCostEstimate("list_sets_intersects_list", estimateSetsCost(1)),
// equivalence requires potentially two m*n comparisons to ensure each list is contained by the other
checker.OverloadCostEstimate("list_sets_equivalent_list", estimateSetsCost(2)),
),
}
}
// ProgramOptions implements the Library interface method.
func (setsLib) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
return []cel.ProgramOption{
cel.CostTrackerOptions(
interpreter.OverloadCostTracker("list_sets_contains_list", trackSetsCost(1)),
interpreter.OverloadCostTracker("list_sets_intersects_list", trackSetsCost(1)),
interpreter.OverloadCostTracker("list_sets_equivalent_list", trackSetsCost(2)),
),
}
}
func setsIntersects(listA, listB ref.Val) ref.Val {
@ -136,3 +152,46 @@ func setsEquivalent(listA, listB ref.Val) ref.Val {
}
return setsContains(listB, listA)
}
func estimateSetsCost(costFactor float64) checker.FunctionEstimator {
return func(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
if len(args) == 2 {
arg0Size := estimateSize(estimator, args[0])
arg1Size := estimateSize(estimator, args[1])
costEstimate := arg0Size.Multiply(arg1Size).MultiplyByCostFactor(costFactor).Add(callCostEstimate)
return &checker.CallEstimate{CostEstimate: costEstimate}
}
return nil
}
}
func estimateSize(estimator checker.CostEstimator, node checker.AstNode) checker.SizeEstimate {
if l := node.ComputedSize(); l != nil {
return *l
}
if l := estimator.EstimateSize(node); l != nil {
return *l
}
return checker.SizeEstimate{Min: 0, Max: math.MaxUint64}
}
func trackSetsCost(costFactor float64) interpreter.FunctionTracker {
return func(args []ref.Val, _ ref.Val) *uint64 {
lhsSize := actualSize(args[0])
rhsSize := actualSize(args[1])
cost := callCost + uint64(float64(lhsSize*rhsSize)*costFactor)
return &cost
}
}
func actualSize(value ref.Val) uint64 {
if sz, ok := value.(traits.Sizer); ok {
return uint64(sz.Size().(types.Int))
}
return 1
}
var (
callCostEstimate = checker.CostEstimate{Min: 1, Max: 1}
callCost = uint64(1)
)

View File

@ -173,7 +173,7 @@ const (
// 'TacoCat'.lowerAscii() // returns 'tacocat'
// 'TacoCÆt Xii'.lowerAscii() // returns 'tacocÆt xii'
//
// # Quote
// # Strings.Quote
//
// Introduced in version: 1
//
@ -301,26 +301,28 @@ func StringsLocale(locale string) StringsOption {
}
}
// StringsVersion configures the version of the string library. The version limits which
// functions are available. Only functions introduced below or equal to the given
// version included in the library. See the library documentation to determine
// which version a function was introduced at. If the documentation does not
// state which version a function was introduced at, it can be assumed to be
// introduced at version 0, when the library was first created.
// If this option is not set, all functions are available.
func StringsVersion(version uint32) func(lib *stringLib) *stringLib {
return func(sl *stringLib) *stringLib {
sl.version = version
return sl
// StringsVersion configures the version of the string library.
//
// The version limits which functions are available. Only functions introduced
// below or equal to the given version included in the library. If this option
// is not set, all functions are available.
//
// See the library documentation to determine which version a function was introduced.
// If the documentation does not state which version a function was introduced, it can
// be assumed to be introduced at version 0, when the library was first created.
func StringsVersion(version uint32) StringsOption {
return func(lib *stringLib) *stringLib {
lib.version = version
return lib
}
}
// CompileOptions implements the Library interface method.
func (sl *stringLib) CompileOptions() []cel.EnvOption {
func (lib *stringLib) CompileOptions() []cel.EnvOption {
formatLocale := "en_US"
if sl.locale != "" {
if lib.locale != "" {
// ensure locale is properly-formed if set
_, err := language.Parse(sl.locale)
_, err := language.Parse(lib.locale)
if err != nil {
return []cel.EnvOption{
func(e *cel.Env) (*cel.Env, error) {
@ -328,7 +330,7 @@ func (sl *stringLib) CompileOptions() []cel.EnvOption {
},
}
}
formatLocale = sl.locale
formatLocale = lib.locale
}
opts := []cel.EnvOption{
@ -432,7 +434,7 @@ func (sl *stringLib) CompileOptions() []cel.EnvOption {
return stringOrError(upperASCII(string(s)))
}))),
}
if sl.version >= 1 {
if lib.version >= 1 {
opts = append(opts, cel.Function("format",
cel.MemberOverload("string_format", []*cel.Type{cel.StringType, cel.ListType(cel.DynType)}, cel.StringType,
cel.FunctionBinding(func(args ...ref.Val) ref.Val {
@ -447,7 +449,7 @@ func (sl *stringLib) CompileOptions() []cel.EnvOption {
}))))
}
if sl.version >= 2 {
if lib.version >= 2 {
opts = append(opts,
cel.Function("join",
cel.MemberOverload("list_join", []*cel.Type{cel.ListType(cel.StringType)}, cel.StringType,

View File

@ -25,13 +25,14 @@ go_library(
importpath = "github.com/google/cel-go/interpreter",
deps = [
"//common:go_default_library",
"//common/ast:go_default_library",
"//common/containers:go_default_library",
"//common/functions:go_default_library",
"//common/operators:go_default_library",
"//common/overloads:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"//interpreter/functions:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//types/known/durationpb:go_default_library",
@ -56,12 +57,13 @@ go_test(
],
deps = [
"//checker:go_default_library",
"//checker/decls:go_default_library",
"//common/containers:go_default_library",
"//common/debug:go_default_library",
"//common/decls:go_default_library",
"//common/functions:go_default_library",
"//common/operators:go_default_library",
"//common/stdlib:go_default_library",
"//common/types:go_default_library",
"//interpreter/functions:go_default_library",
"//parser:go_default_library",
"//test:go_default_library",
"//test/proto2pb:go_default_library",

View File

@ -58,7 +58,7 @@ func (emptyActivation) Parent() Activation { return nil }
// The output of the lazy binding will overwrite the variable reference in the internal map.
//
// Values which are not represented as ref.Val types on input may be adapted to a ref.Val using
// the ref.TypeAdapter configured in the environment.
// the types.Adapter configured in the environment.
func NewActivation(bindings any) (Activation, error) {
if bindings == nil {
return nil, errors.New("bindings must be non-nil")

View File

@ -15,6 +15,8 @@
package interpreter
import (
"fmt"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
@ -177,8 +179,8 @@ func numericValueEquals(value any, celValue ref.Val) bool {
// NewPartialAttributeFactory returns an AttributeFactory implementation capable of performing
// AttributePattern matches with PartialActivation inputs.
func NewPartialAttributeFactory(container *containers.Container,
adapter ref.TypeAdapter,
provider ref.TypeProvider) AttributeFactory {
adapter types.Adapter,
provider types.Provider) AttributeFactory {
fac := NewAttributeFactory(container, adapter, provider)
return &partialAttributeFactory{
AttributeFactory: fac,
@ -191,8 +193,8 @@ func NewPartialAttributeFactory(container *containers.Container,
type partialAttributeFactory struct {
AttributeFactory
container *containers.Container
adapter ref.TypeAdapter
provider ref.TypeProvider
adapter types.Adapter
provider types.Provider
}
// AbsoluteAttribute implementation of the AttributeFactory interface which wraps the
@ -241,12 +243,15 @@ func (fac *partialAttributeFactory) matchesUnknownPatterns(
vars PartialActivation,
attrID int64,
variableNames []string,
qualifiers []Qualifier) (types.Unknown, error) {
qualifiers []Qualifier) (*types.Unknown, error) {
patterns := vars.UnknownAttributePatterns()
candidateIndices := map[int]struct{}{}
for _, variable := range variableNames {
for i, pat := range patterns {
if pat.VariableMatches(variable) {
if len(qualifiers) == 0 {
return types.NewUnknown(attrID, types.NewAttributeTrail(variable)), nil
}
candidateIndices[i] = struct{}{}
}
}
@ -255,10 +260,6 @@ func (fac *partialAttributeFactory) matchesUnknownPatterns(
if len(candidateIndices) == 0 {
return nil, nil
}
// Determine whether to return early if there are no qualifiers.
if len(qualifiers) == 0 {
return types.Unknown{attrID}, nil
}
// Resolve the attribute qualifiers into a static set. This prevents more dynamic
// Attribute resolutions than necessary when there are multiple unknown patterns
// that traverse the same Attribute-based qualifier field.
@ -300,7 +301,28 @@ func (fac *partialAttributeFactory) matchesUnknownPatterns(
}
}
if isUnk {
return types.Unknown{matchExprID}, nil
attr := types.NewAttributeTrail(pat.variable)
for i := 0; i < len(qualPats) && i < len(newQuals); i++ {
if qual, ok := newQuals[i].(ConstantQualifier); ok {
switch v := qual.Value().Value().(type) {
case bool:
types.QualifyAttribute[bool](attr, v)
case float64:
types.QualifyAttribute[int64](attr, int64(v))
case int64:
types.QualifyAttribute[int64](attr, v)
case string:
types.QualifyAttribute[string](attr, v)
case uint64:
types.QualifyAttribute[uint64](attr, v)
default:
types.QualifyAttribute[string](attr, fmt.Sprintf("%v", v))
}
} else {
types.QualifyAttribute[string](attr, "*")
}
}
return types.NewUnknown(matchExprID, attr), nil
}
}
return nil, nil

View File

@ -22,8 +22,6 @@ import (
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// AttributeFactory provides methods creating Attribute and Qualifier values.
@ -61,7 +59,7 @@ type AttributeFactory interface {
// The qualifier may consider the object type being qualified, if present. If absent, the
// qualification should be considered dynamic and the qualification should still work, though
// it may be sub-optimal.
NewQualifier(objType *exprpb.Type, qualID int64, val any, opt bool) (Qualifier, error)
NewQualifier(objType *types.Type, qualID int64, val any, opt bool) (Qualifier, error)
}
// Qualifier marker interface for designating different qualifier values and where they appear
@ -131,7 +129,7 @@ type NamespacedAttribute interface {
// NewAttributeFactory returns a default AttributeFactory which is produces Attribute values
// capable of resolving types by simple names and qualify the values using the supported qualifier
// types: bool, int, string, and uint.
func NewAttributeFactory(cont *containers.Container, a ref.TypeAdapter, p ref.TypeProvider) AttributeFactory {
func NewAttributeFactory(cont *containers.Container, a types.Adapter, p types.Provider) AttributeFactory {
return &attrFactory{
container: cont,
adapter: a,
@ -141,8 +139,8 @@ func NewAttributeFactory(cont *containers.Container, a ref.TypeAdapter, p ref.Ty
type attrFactory struct {
container *containers.Container
adapter ref.TypeAdapter
provider ref.TypeProvider
adapter types.Adapter
provider types.Provider
}
// AbsoluteAttribute refers to a variable value and an optional qualifier path.
@ -199,13 +197,13 @@ func (r *attrFactory) RelativeAttribute(id int64, operand Interpretable) Attribu
}
// NewQualifier is an implementation of the AttributeFactory interface.
func (r *attrFactory) NewQualifier(objType *exprpb.Type, qualID int64, val any, opt bool) (Qualifier, error) {
func (r *attrFactory) NewQualifier(objType *types.Type, qualID int64, val any, opt bool) (Qualifier, error) {
// Before creating a new qualifier check to see if this is a protobuf message field access.
// If so, use the precomputed GetFrom qualification method rather than the standard
// stringQualifier.
str, isStr := val.(string)
if isStr && objType != nil && objType.GetMessageType() != "" {
ft, found := r.provider.FindFieldType(objType.GetMessageType(), str)
if isStr && objType != nil && objType.Kind() == types.StructKind {
ft, found := r.provider.FindStructFieldType(objType.TypeName(), str)
if found && ft.IsSet != nil && ft.GetFrom != nil {
return &fieldQualifier{
id: qualID,
@ -225,8 +223,8 @@ type absoluteAttribute struct {
// (package) of the expression.
namespaceNames []string
qualifiers []Qualifier
adapter ref.TypeAdapter
provider ref.TypeProvider
adapter types.Adapter
provider types.Provider
fac AttributeFactory
}
@ -325,7 +323,7 @@ type conditionalAttribute struct {
expr Interpretable
truthy Attribute
falsy Attribute
adapter ref.TypeAdapter
adapter types.Adapter
fac AttributeFactory
}
@ -393,8 +391,8 @@ func (a *conditionalAttribute) String() string {
type maybeAttribute struct {
id int64
attrs []NamespacedAttribute
adapter ref.TypeAdapter
provider ref.TypeProvider
adapter types.Adapter
provider types.Provider
fac AttributeFactory
}
@ -511,7 +509,7 @@ type relativeAttribute struct {
id int64
operand Interpretable
qualifiers []Qualifier
adapter ref.TypeAdapter
adapter types.Adapter
fac AttributeFactory
}
@ -576,7 +574,7 @@ func (a *relativeAttribute) String() string {
return fmt.Sprintf("id: %v, operand: %v", a.id, a.operand)
}
func newQualifier(adapter ref.TypeAdapter, id int64, v any, opt bool) (Qualifier, error) {
func newQualifier(adapter types.Adapter, id int64, v any, opt bool) (Qualifier, error) {
var qual Qualifier
switch val := v.(type) {
case Attribute:
@ -657,7 +655,7 @@ func newQualifier(adapter ref.TypeAdapter, id int64, v any, opt bool) (Qualifier
qual = &doubleQualifier{
id: id, value: float64(val), celValue: val, adapter: adapter, optional: opt,
}
case types.Unknown:
case *types.Unknown:
qual = &unknownQualifier{id: id, value: val}
default:
if q, ok := v.(Qualifier); ok {
@ -689,7 +687,7 @@ type stringQualifier struct {
id int64
value string
celValue ref.Val
adapter ref.TypeAdapter
adapter types.Adapter
optional bool
}
@ -790,7 +788,7 @@ type intQualifier struct {
id int64
value int64
celValue ref.Val
adapter ref.TypeAdapter
adapter types.Adapter
optional bool
}
@ -917,7 +915,7 @@ type uintQualifier struct {
id int64
value uint64
celValue ref.Val
adapter ref.TypeAdapter
adapter types.Adapter
optional bool
}
@ -982,7 +980,7 @@ type boolQualifier struct {
id int64
value bool
celValue ref.Val
adapter ref.TypeAdapter
adapter types.Adapter
optional bool
}
@ -1035,8 +1033,8 @@ func (q *boolQualifier) Value() ref.Val {
type fieldQualifier struct {
id int64
Name string
FieldType *ref.FieldType
adapter ref.TypeAdapter
FieldType *types.FieldType
adapter types.Adapter
optional bool
}
@ -1094,7 +1092,7 @@ type doubleQualifier struct {
id int64
value float64
celValue ref.Val
adapter ref.TypeAdapter
adapter types.Adapter
optional bool
}
@ -1131,7 +1129,7 @@ func (q *doubleQualifier) Value() ref.Val {
// for any value subject to qualification. This is consistent with CEL's unknown handling elsewhere.
type unknownQualifier struct {
id int64
value types.Unknown
value *types.Unknown
}
// ID is an implementation of the Qualifier interface method.
@ -1225,10 +1223,10 @@ func attrQualifyIfPresent(fac AttributeFactory, vars Activation, obj any, qualAt
// refQualify attempts to convert the value to a CEL value and then uses reflection methods to try and
// apply the qualifier with the option to presence test field accesses before retrieving field values.
func refQualify(adapter ref.TypeAdapter, obj any, idx ref.Val, presenceTest, presenceOnly bool) (ref.Val, bool, error) {
func refQualify(adapter types.Adapter, obj any, idx ref.Val, presenceTest, presenceOnly bool) (ref.Val, bool, error) {
celVal := adapter.NativeToValue(obj)
switch v := celVal.(type) {
case types.Unknown:
case *types.Unknown:
return v, true, nil
case *types.Err:
return nil, false, v

View File

@ -75,15 +75,13 @@ func decDisableShortcircuits() InterpretableDecorator {
switch expr := i.(type) {
case *evalOr:
return &evalExhaustiveOr{
id: expr.id,
lhs: expr.lhs,
rhs: expr.rhs,
id: expr.id,
terms: expr.terms,
}, nil
case *evalAnd:
return &evalExhaustiveAnd{
id: expr.id,
lhs: expr.lhs,
rhs: expr.rhs,
id: expr.id,
terms: expr.terms,
}, nil
case *evalFold:
expr.exhaustive = true

View File

@ -17,7 +17,7 @@ package interpreter
import (
"fmt"
"github.com/google/cel-go/interpreter/functions"
"github.com/google/cel-go/common/functions"
)
// Dispatcher resolves function calls to their appropriate overload.

View File

@ -66,7 +66,11 @@ func (s *evalState) Value(exprID int64) (ref.Val, bool) {
// SetValue is an implementation of the EvalState interface method.
func (s *evalState) SetValue(exprID int64, val ref.Val) {
s.values[exprID] = val
if val == nil {
delete(s.values, exprID)
} else {
s.values[exprID] = val
}
}
// Reset implements the EvalState interface method.

View File

@ -25,7 +25,7 @@ import (
"github.com/google/cel-go/common/types/ref"
)
type typeVerifier func(int64, ...*types.TypeValue) (bool, error)
type typeVerifier func(int64, ...ref.Type) (bool, error)
// InterpolateFormattedString checks the syntax and cardinality of any string.format calls present in the expression and reports
// any errors at compile time.

View File

@ -7,16 +7,11 @@ package(
go_library(
name = "go_default_library",
srcs = [
srcs = [
"functions.go",
"standard.go",
],
importpath = "github.com/google/cel-go/interpreter/functions",
deps = [
"//common/operators:go_default_library",
"//common/overloads:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"//common/functions:go_default_library",
],
)

View File

@ -16,7 +16,7 @@
// interpreter and as declared within the checker#StandardDeclarations.
package functions
import "github.com/google/cel-go/common/types/ref"
import fn "github.com/google/cel-go/common/functions"
// Overload defines a named overload of a function, indicating an operand trait
// which must be present on the first argument to the overload as well as one
@ -26,37 +26,14 @@ import "github.com/google/cel-go/common/types/ref"
// and the specializations simplify the call contract for implementers of
// types with operator overloads. Any added complexity is assumed to be handled
// by the generic FunctionOp.
type Overload struct {
// Operator name as written in an expression or defined within
// operators.go.
Operator string
// Operand trait used to dispatch the call. The zero-value indicates a
// global function overload or that one of the Unary / Binary / Function
// definitions should be used to execute the call.
OperandTrait int
// Unary defines the overload with a UnaryOp implementation. May be nil.
Unary UnaryOp
// Binary defines the overload with a BinaryOp implementation. May be nil.
Binary BinaryOp
// Function defines the overload with a FunctionOp implementation. May be
// nil.
Function FunctionOp
// NonStrict specifies whether the Overload will tolerate arguments that
// are types.Err or types.Unknown.
NonStrict bool
}
type Overload = fn.Overload
// UnaryOp is a function that takes a single value and produces an output.
type UnaryOp func(value ref.Val) ref.Val
type UnaryOp = fn.UnaryOp
// BinaryOp is a function that takes two values and produces an output.
type BinaryOp func(lhs ref.Val, rhs ref.Val) ref.Val
type BinaryOp = fn.BinaryOp
// FunctionOp is a function with accepts zero or more arguments and produces
// a value or error as a result.
type FunctionOp func(values ...ref.Val) ref.Val
type FunctionOp = fn.FunctionOp

View File

@ -1,270 +0,0 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package functions
import (
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
)
// StandardOverloads returns the definitions of the built-in overloads.
func StandardOverloads() []*Overload {
return []*Overload{
// Logical not (!a)
{
Operator: operators.LogicalNot,
OperandTrait: traits.NegatorType,
Unary: func(value ref.Val) ref.Val {
if !types.IsBool(value) {
return types.ValOrErr(value, "no such overload")
}
return value.(traits.Negater).Negate()
}},
// Not strictly false: IsBool(a) ? a : true
{
Operator: operators.NotStrictlyFalse,
Unary: notStrictlyFalse},
// Deprecated: not strictly false, may be overridden in the environment.
{
Operator: operators.OldNotStrictlyFalse,
Unary: notStrictlyFalse},
// Less than operator
{Operator: operators.Less,
OperandTrait: traits.ComparerType,
Binary: func(lhs ref.Val, rhs ref.Val) ref.Val {
cmp := lhs.(traits.Comparer).Compare(rhs)
if cmp == types.IntNegOne {
return types.True
}
if cmp == types.IntOne || cmp == types.IntZero {
return types.False
}
return cmp
}},
// Less than or equal operator
{Operator: operators.LessEquals,
OperandTrait: traits.ComparerType,
Binary: func(lhs ref.Val, rhs ref.Val) ref.Val {
cmp := lhs.(traits.Comparer).Compare(rhs)
if cmp == types.IntNegOne || cmp == types.IntZero {
return types.True
}
if cmp == types.IntOne {
return types.False
}
return cmp
}},
// Greater than operator
{Operator: operators.Greater,
OperandTrait: traits.ComparerType,
Binary: func(lhs ref.Val, rhs ref.Val) ref.Val {
cmp := lhs.(traits.Comparer).Compare(rhs)
if cmp == types.IntOne {
return types.True
}
if cmp == types.IntNegOne || cmp == types.IntZero {
return types.False
}
return cmp
}},
// Greater than equal operators
{Operator: operators.GreaterEquals,
OperandTrait: traits.ComparerType,
Binary: func(lhs ref.Val, rhs ref.Val) ref.Val {
cmp := lhs.(traits.Comparer).Compare(rhs)
if cmp == types.IntOne || cmp == types.IntZero {
return types.True
}
if cmp == types.IntNegOne {
return types.False
}
return cmp
}},
// Add operator
{Operator: operators.Add,
OperandTrait: traits.AdderType,
Binary: func(lhs ref.Val, rhs ref.Val) ref.Val {
return lhs.(traits.Adder).Add(rhs)
}},
// Subtract operators
{Operator: operators.Subtract,
OperandTrait: traits.SubtractorType,
Binary: func(lhs ref.Val, rhs ref.Val) ref.Val {
return lhs.(traits.Subtractor).Subtract(rhs)
}},
// Multiply operator
{Operator: operators.Multiply,
OperandTrait: traits.MultiplierType,
Binary: func(lhs ref.Val, rhs ref.Val) ref.Val {
return lhs.(traits.Multiplier).Multiply(rhs)
}},
// Divide operator
{Operator: operators.Divide,
OperandTrait: traits.DividerType,
Binary: func(lhs ref.Val, rhs ref.Val) ref.Val {
return lhs.(traits.Divider).Divide(rhs)
}},
// Modulo operator
{Operator: operators.Modulo,
OperandTrait: traits.ModderType,
Binary: func(lhs ref.Val, rhs ref.Val) ref.Val {
return lhs.(traits.Modder).Modulo(rhs)
}},
// Negate operator
{Operator: operators.Negate,
OperandTrait: traits.NegatorType,
Unary: func(value ref.Val) ref.Val {
if types.IsBool(value) {
return types.ValOrErr(value, "no such overload")
}
return value.(traits.Negater).Negate()
}},
// Index operator
{Operator: operators.Index,
OperandTrait: traits.IndexerType,
Binary: func(lhs ref.Val, rhs ref.Val) ref.Val {
return lhs.(traits.Indexer).Get(rhs)
}},
// Size function
{Operator: overloads.Size,
OperandTrait: traits.SizerType,
Unary: func(value ref.Val) ref.Val {
return value.(traits.Sizer).Size()
}},
// In operator
{Operator: operators.In, Binary: inAggregate},
// Deprecated: in operator, may be overridden in the environment.
{Operator: operators.OldIn, Binary: inAggregate},
// Matches function
{Operator: overloads.Matches,
OperandTrait: traits.MatcherType,
Binary: func(lhs ref.Val, rhs ref.Val) ref.Val {
return lhs.(traits.Matcher).Match(rhs)
}},
// Type conversion functions
// TODO: verify type conversion safety of numeric values.
// Int conversions.
{Operator: overloads.TypeConvertInt,
Unary: func(value ref.Val) ref.Val {
return value.ConvertToType(types.IntType)
}},
// Uint conversions.
{Operator: overloads.TypeConvertUint,
Unary: func(value ref.Val) ref.Val {
return value.ConvertToType(types.UintType)
}},
// Double conversions.
{Operator: overloads.TypeConvertDouble,
Unary: func(value ref.Val) ref.Val {
return value.ConvertToType(types.DoubleType)
}},
// Bool conversions.
{Operator: overloads.TypeConvertBool,
Unary: func(value ref.Val) ref.Val {
return value.ConvertToType(types.BoolType)
}},
// Bytes conversions.
{Operator: overloads.TypeConvertBytes,
Unary: func(value ref.Val) ref.Val {
return value.ConvertToType(types.BytesType)
}},
// String conversions.
{Operator: overloads.TypeConvertString,
Unary: func(value ref.Val) ref.Val {
return value.ConvertToType(types.StringType)
}},
// Timestamp conversions.
{Operator: overloads.TypeConvertTimestamp,
Unary: func(value ref.Val) ref.Val {
return value.ConvertToType(types.TimestampType)
}},
// Duration conversions.
{Operator: overloads.TypeConvertDuration,
Unary: func(value ref.Val) ref.Val {
return value.ConvertToType(types.DurationType)
}},
// Type operations.
{Operator: overloads.TypeConvertType,
Unary: func(value ref.Val) ref.Val {
return value.ConvertToType(types.TypeType)
}},
// Dyn conversion (identity function).
{Operator: overloads.TypeConvertDyn,
Unary: func(value ref.Val) ref.Val {
return value
}},
{Operator: overloads.Iterator,
OperandTrait: traits.IterableType,
Unary: func(value ref.Val) ref.Val {
return value.(traits.Iterable).Iterator()
}},
{Operator: overloads.HasNext,
OperandTrait: traits.IteratorType,
Unary: func(value ref.Val) ref.Val {
return value.(traits.Iterator).HasNext()
}},
{Operator: overloads.Next,
OperandTrait: traits.IteratorType,
Unary: func(value ref.Val) ref.Val {
return value.(traits.Iterator).Next()
}},
}
}
func notStrictlyFalse(value ref.Val) ref.Val {
if types.IsBool(value) {
return value
}
return types.True
}
func inAggregate(lhs ref.Val, rhs ref.Val) ref.Val {
if rhs.Type().HasTrait(traits.ContainerType) {
return rhs.(traits.Container).Contains(lhs)
}
return types.ValOrErr(rhs, "no such overload")
}

View File

@ -17,12 +17,12 @@ package interpreter
import (
"fmt"
"github.com/google/cel-go/common/functions"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/interpreter/functions"
)
// Interpretable can accept a given Activation and produce a value along with
@ -52,7 +52,7 @@ type InterpretableAttribute interface {
Attr() Attribute
// Adapter returns the type adapter to be used for adapting resolved Attribute values.
Adapter() ref.TypeAdapter
Adapter() types.Adapter
// AddQualifier proxies the Attribute.AddQualifier method.
//
@ -202,9 +202,8 @@ func (cons *evalConst) Value() ref.Val {
}
type evalOr struct {
id int64
lhs Interpretable
rhs Interpretable
id int64
terms []Interpretable
}
// ID implements the Interpretable interface method.
@ -214,41 +213,39 @@ func (or *evalOr) ID() int64 {
// Eval implements the Interpretable interface method.
func (or *evalOr) Eval(ctx Activation) ref.Val {
// short-circuit lhs.
lVal := or.lhs.Eval(ctx)
lBool, lok := lVal.(types.Bool)
if lok && lBool == types.True {
return types.True
var err ref.Val = nil
var unk *types.Unknown
for _, term := range or.terms {
val := term.Eval(ctx)
boolVal, ok := val.(types.Bool)
// short-circuit on true.
if ok && boolVal == types.True {
return types.True
}
if !ok {
isUnk := false
unk, isUnk = types.MaybeMergeUnknowns(val, unk)
if !isUnk && err == nil {
if types.IsError(val) {
err = val
} else {
err = types.MaybeNoSuchOverloadErr(val)
}
}
}
}
// short-circuit on rhs.
rVal := or.rhs.Eval(ctx)
rBool, rok := rVal.(types.Bool)
if rok && rBool == types.True {
return types.True
if unk != nil {
return unk
}
// return if both sides are bool false.
if lok && rok {
return types.False
if err != nil {
return err
}
// TODO: return both values as a set if both are unknown or error.
// prefer left unknown to right unknown.
if types.IsUnknown(lVal) {
return lVal
}
if types.IsUnknown(rVal) {
return rVal
}
// If the left-hand side is non-boolean return it as the error.
if types.IsError(lVal) {
return lVal
}
return types.ValOrErr(rVal, "no such overload")
return types.False
}
type evalAnd struct {
id int64
lhs Interpretable
rhs Interpretable
id int64
terms []Interpretable
}
// ID implements the Interpretable interface method.
@ -258,35 +255,34 @@ func (and *evalAnd) ID() int64 {
// Eval implements the Interpretable interface method.
func (and *evalAnd) Eval(ctx Activation) ref.Val {
// short-circuit lhs.
lVal := and.lhs.Eval(ctx)
lBool, lok := lVal.(types.Bool)
if lok && lBool == types.False {
return types.False
var err ref.Val = nil
var unk *types.Unknown
for _, term := range and.terms {
val := term.Eval(ctx)
boolVal, ok := val.(types.Bool)
// short-circuit on false.
if ok && boolVal == types.False {
return types.False
}
if !ok {
isUnk := false
unk, isUnk = types.MaybeMergeUnknowns(val, unk)
if !isUnk && err == nil {
if types.IsError(val) {
err = val
} else {
err = types.MaybeNoSuchOverloadErr(val)
}
}
}
}
// short-circuit on rhs.
rVal := and.rhs.Eval(ctx)
rBool, rok := rVal.(types.Bool)
if rok && rBool == types.False {
return types.False
if unk != nil {
return unk
}
// return if both sides are bool true.
if lok && rok {
return types.True
if err != nil {
return err
}
// TODO: return both values as a set if both are unknown or error.
// prefer left unknown to right unknown.
if types.IsUnknown(lVal) {
return lVal
}
if types.IsUnknown(rVal) {
return rVal
}
// If the left-hand side is non-boolean return it as the error.
if types.IsError(lVal) {
return lVal
}
return types.ValOrErr(rVal, "no such overload")
return types.True
}
type evalEq struct {
@ -579,7 +575,7 @@ type evalList struct {
elems []Interpretable
optionals []bool
hasOptionals bool
adapter ref.TypeAdapter
adapter types.Adapter
}
// ID implements the Interpretable interface method.
@ -625,7 +621,7 @@ type evalMap struct {
vals []Interpretable
optionals []bool
hasOptionals bool
adapter ref.TypeAdapter
adapter types.Adapter
}
// ID implements the Interpretable interface method.
@ -689,7 +685,7 @@ type evalObj struct {
vals []Interpretable
optionals []bool
hasOptionals bool
provider ref.TypeProvider
provider types.Provider
}
// ID implements the Interpretable interface method.
@ -739,7 +735,7 @@ type evalFold struct {
cond Interpretable
step Interpretable
result Interpretable
adapter ref.TypeAdapter
adapter types.Adapter
exhaustive bool
interruptable bool
}
@ -865,18 +861,40 @@ type evalWatchAttr struct {
// AddQualifier creates a wrapper over the incoming qualifier which observes the qualification
// result.
func (e *evalWatchAttr) AddQualifier(q Qualifier) (Attribute, error) {
cq, isConst := q.(ConstantQualifier)
if isConst {
switch qual := q.(type) {
// By default, the qualifier is either a constant or an attribute
// There may be some custom cases where the attribute is neither.
case ConstantQualifier:
// Expose a method to test whether the qualifier matches the input pattern.
q = &evalWatchConstQual{
ConstantQualifier: cq,
ConstantQualifier: qual,
observer: e.observer,
adapter: e.InterpretableAttribute.Adapter(),
adapter: e.Adapter(),
}
} else {
q = &evalWatchQual{
Qualifier: q,
case *evalWatchAttr:
// Unwrap the evalWatchAttr since the observation will be applied during Qualify or
// QualifyIfPresent rather than Eval.
q = &evalWatchAttrQual{
Attribute: qual.InterpretableAttribute,
observer: e.observer,
adapter: e.InterpretableAttribute.Adapter(),
adapter: e.Adapter(),
}
case Attribute:
// Expose methods which intercept the qualification prior to being applied as a qualifier.
// Using this interface ensures that the qualifier is converted to a constant value one
// time during attribute pattern matching as the method embeds the Attribute interface
// needed to trip the conversion to a constant.
q = &evalWatchAttrQual{
Attribute: qual,
observer: e.observer,
adapter: e.Adapter(),
}
default:
// This is likely a custom qualifier type.
q = &evalWatchQual{
Qualifier: qual,
observer: e.observer,
adapter: e.Adapter(),
}
}
_, err := e.InterpretableAttribute.AddQualifier(q)
@ -895,7 +913,7 @@ func (e *evalWatchAttr) Eval(vars Activation) ref.Val {
type evalWatchConstQual struct {
ConstantQualifier
observer EvalObserver
adapter ref.TypeAdapter
adapter types.Adapter
}
// Qualify observes the qualification of a object via a constant boolean, int, string, or uint.
@ -934,11 +952,48 @@ func (e *evalWatchConstQual) QualifierValueEquals(value any) bool {
return ok && qve.QualifierValueEquals(value)
}
// evalWatchAttrQual observes the qualification of an object by a value computed at runtime.
type evalWatchAttrQual struct {
Attribute
observer EvalObserver
adapter ref.TypeAdapter
}
// Qualify observes the qualification of a object via a value computed at runtime.
func (e *evalWatchAttrQual) Qualify(vars Activation, obj any) (any, error) {
out, err := e.Attribute.Qualify(vars, obj)
var val ref.Val
if err != nil {
val = types.WrapErr(err)
} else {
val = e.adapter.NativeToValue(out)
}
e.observer(e.ID(), e.Attribute, val)
return out, err
}
// QualifyIfPresent conditionally qualifies the variable and only records a value if one is present.
func (e *evalWatchAttrQual) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) {
out, present, err := e.Attribute.QualifyIfPresent(vars, obj, presenceOnly)
var val ref.Val
if err != nil {
val = types.WrapErr(err)
} else if out != nil {
val = e.adapter.NativeToValue(out)
} else if presenceOnly {
val = types.Bool(present)
}
if present || presenceOnly {
e.observer(e.ID(), e.Attribute, val)
}
return out, present, err
}
// evalWatchQual observes the qualification of an object by a value computed at runtime.
type evalWatchQual struct {
Qualifier
observer EvalObserver
adapter ref.TypeAdapter
adapter types.Adapter
}
// Qualify observes the qualification of a object via a value computed at runtime.
@ -986,9 +1041,8 @@ func (e *evalWatchConst) Eval(vars Activation) ref.Val {
// evalExhaustiveOr is just like evalOr, but does not short-circuit argument evaluation.
type evalExhaustiveOr struct {
id int64
lhs Interpretable
rhs Interpretable
id int64
terms []Interpretable
}
// ID implements the Interpretable interface method.
@ -998,38 +1052,44 @@ func (or *evalExhaustiveOr) ID() int64 {
// Eval implements the Interpretable interface method.
func (or *evalExhaustiveOr) Eval(ctx Activation) ref.Val {
lVal := or.lhs.Eval(ctx)
rVal := or.rhs.Eval(ctx)
lBool, lok := lVal.(types.Bool)
if lok && lBool == types.True {
var err ref.Val = nil
var unk *types.Unknown
isTrue := false
for _, term := range or.terms {
val := term.Eval(ctx)
boolVal, ok := val.(types.Bool)
// flag the result as true
if ok && boolVal == types.True {
isTrue = true
}
if !ok && !isTrue {
isUnk := false
unk, isUnk = types.MaybeMergeUnknowns(val, unk)
if !isUnk && err == nil {
if types.IsError(val) {
err = val
} else {
err = types.MaybeNoSuchOverloadErr(val)
}
}
}
}
if isTrue {
return types.True
}
rBool, rok := rVal.(types.Bool)
if rok && rBool == types.True {
return types.True
if unk != nil {
return unk
}
if lok && rok {
return types.False
if err != nil {
return err
}
if types.IsUnknown(lVal) {
return lVal
}
if types.IsUnknown(rVal) {
return rVal
}
// TODO: Combine the errors into a set in the future.
// If the left-hand side is non-boolean return it as the error.
if types.IsError(lVal) {
return lVal
}
return types.MaybeNoSuchOverloadErr(rVal)
return types.False
}
// evalExhaustiveAnd is just like evalAnd, but does not short-circuit argument evaluation.
type evalExhaustiveAnd struct {
id int64
lhs Interpretable
rhs Interpretable
id int64
terms []Interpretable
}
// ID implements the Interpretable interface method.
@ -1039,38 +1099,45 @@ func (and *evalExhaustiveAnd) ID() int64 {
// Eval implements the Interpretable interface method.
func (and *evalExhaustiveAnd) Eval(ctx Activation) ref.Val {
lVal := and.lhs.Eval(ctx)
rVal := and.rhs.Eval(ctx)
lBool, lok := lVal.(types.Bool)
if lok && lBool == types.False {
var err ref.Val = nil
var unk *types.Unknown
isFalse := false
for _, term := range and.terms {
val := term.Eval(ctx)
boolVal, ok := val.(types.Bool)
// short-circuit on false.
if ok && boolVal == types.False {
isFalse = true
}
if !ok && !isFalse {
isUnk := false
unk, isUnk = types.MaybeMergeUnknowns(val, unk)
if !isUnk && err == nil {
if types.IsError(val) {
err = val
} else {
err = types.MaybeNoSuchOverloadErr(val)
}
}
}
}
if isFalse {
return types.False
}
rBool, rok := rVal.(types.Bool)
if rok && rBool == types.False {
return types.False
if unk != nil {
return unk
}
if lok && rok {
return types.True
if err != nil {
return err
}
if types.IsUnknown(lVal) {
return lVal
}
if types.IsUnknown(rVal) {
return rVal
}
// TODO: Combine the errors into a set in the future.
// If the left-hand side is non-boolean return it as the error.
if types.IsError(lVal) {
return lVal
}
return types.MaybeNoSuchOverloadErr(rVal)
return types.True
}
// evalExhaustiveConditional is like evalConditional, but does not short-circuit argument
// evaluation.
type evalExhaustiveConditional struct {
id int64
adapter ref.TypeAdapter
adapter types.Adapter
attr *conditionalAttribute
}
@ -1102,7 +1169,7 @@ func (cond *evalExhaustiveConditional) Eval(ctx Activation) ref.Val {
// evalAttr evaluates an Attribute value.
type evalAttr struct {
adapter ref.TypeAdapter
adapter types.Adapter
attr Attribute
optional bool
}
@ -1127,7 +1194,7 @@ func (a *evalAttr) Attr() Attribute {
}
// Adapter implements the InterpretableAttribute interface method.
func (a *evalAttr) Adapter() ref.TypeAdapter {
func (a *evalAttr) Adapter() types.Adapter {
return a.adapter
}

View File

@ -18,9 +18,10 @@
package interpreter
import (
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter/functions"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
@ -29,7 +30,7 @@ import (
type Interpreter interface {
// NewInterpretable creates an Interpretable from a checked expression and an
// optional list of InterpretableDecorator values.
NewInterpretable(checked *exprpb.CheckedExpr, decorators ...InterpretableDecorator) (Interpretable, error)
NewInterpretable(checked *ast.CheckedAST, decorators ...InterpretableDecorator) (Interpretable, error)
// NewUncheckedInterpretable returns an Interpretable from a parsed expression
// and an optional list of InterpretableDecorator values.
@ -154,8 +155,8 @@ func CompileRegexConstants(regexOptimizations ...*RegexOptimization) Interpretab
type exprInterpreter struct {
dispatcher Dispatcher
container *containers.Container
provider ref.TypeProvider
adapter ref.TypeAdapter
provider types.Provider
adapter types.Adapter
attrFactory AttributeFactory
}
@ -163,8 +164,8 @@ type exprInterpreter struct {
// throughout the Eval of all Interpretable instances generated from it.
func NewInterpreter(dispatcher Dispatcher,
container *containers.Container,
provider ref.TypeProvider,
adapter ref.TypeAdapter,
provider types.Provider,
adapter types.Adapter,
attrFactory AttributeFactory) Interpreter {
return &exprInterpreter{
dispatcher: dispatcher,
@ -174,20 +175,9 @@ func NewInterpreter(dispatcher Dispatcher,
attrFactory: attrFactory}
}
// NewStandardInterpreter builds a Dispatcher and TypeProvider with support for all of the CEL
// builtins defined in the language definition.
func NewStandardInterpreter(container *containers.Container,
provider ref.TypeProvider,
adapter ref.TypeAdapter,
resolver AttributeFactory) Interpreter {
dispatcher := NewDispatcher()
dispatcher.Add(functions.StandardOverloads()...)
return NewInterpreter(dispatcher, container, provider, adapter, resolver)
}
// NewIntepretable implements the Interpreter interface method.
func (i *exprInterpreter) NewInterpretable(
checked *exprpb.CheckedExpr,
checked *ast.CheckedAST,
decorators ...InterpretableDecorator) (Interpretable, error) {
p := newPlanner(
i.dispatcher,
@ -197,7 +187,7 @@ func (i *exprInterpreter) NewInterpretable(
i.container,
checked,
decorators...)
return p.Plan(checked.GetExpr())
return p.Plan(checked.Expr)
}
// NewUncheckedIntepretable implements the Interpreter interface method.

View File

@ -18,10 +18,12 @@ import (
"fmt"
"strings"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/functions"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter/functions"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
@ -37,11 +39,11 @@ type interpretablePlanner interface {
// functions, types, and namespaced identifiers at plan time rather than at runtime since
// it only needs to be done once and may be semi-expensive to compute.
func newPlanner(disp Dispatcher,
provider ref.TypeProvider,
adapter ref.TypeAdapter,
provider types.Provider,
adapter types.Adapter,
attrFactory AttributeFactory,
cont *containers.Container,
checked *exprpb.CheckedExpr,
checked *ast.CheckedAST,
decorators ...InterpretableDecorator) interpretablePlanner {
return &planner{
disp: disp,
@ -49,8 +51,8 @@ func newPlanner(disp Dispatcher,
adapter: adapter,
attrFactory: attrFactory,
container: cont,
refMap: checked.GetReferenceMap(),
typeMap: checked.GetTypeMap(),
refMap: checked.ReferenceMap,
typeMap: checked.TypeMap,
decorators: decorators,
}
}
@ -59,8 +61,8 @@ func newPlanner(disp Dispatcher,
// TypeAdapter, and Container to resolve functions and types at plan time. Namespaces present in
// Select expressions are resolved lazily at evaluation time.
func newUncheckedPlanner(disp Dispatcher,
provider ref.TypeProvider,
adapter ref.TypeAdapter,
provider types.Provider,
adapter types.Adapter,
attrFactory AttributeFactory,
cont *containers.Container,
decorators ...InterpretableDecorator) interpretablePlanner {
@ -70,8 +72,8 @@ func newUncheckedPlanner(disp Dispatcher,
adapter: adapter,
attrFactory: attrFactory,
container: cont,
refMap: make(map[int64]*exprpb.Reference),
typeMap: make(map[int64]*exprpb.Type),
refMap: make(map[int64]*ast.ReferenceInfo),
typeMap: make(map[int64]*types.Type),
decorators: decorators,
}
}
@ -79,12 +81,12 @@ func newUncheckedPlanner(disp Dispatcher,
// planner is an implementation of the interpretablePlanner interface.
type planner struct {
disp Dispatcher
provider ref.TypeProvider
adapter ref.TypeAdapter
provider types.Provider
adapter types.Adapter
attrFactory AttributeFactory
container *containers.Container
refMap map[int64]*exprpb.Reference
typeMap map[int64]*exprpb.Type
refMap map[int64]*ast.ReferenceInfo
typeMap map[int64]*types.Type
decorators []InterpretableDecorator
}
@ -143,22 +145,19 @@ func (p *planner) planIdent(expr *exprpb.Expr) (Interpretable, error) {
}, nil
}
func (p *planner) planCheckedIdent(id int64, identRef *exprpb.Reference) (Interpretable, error) {
func (p *planner) planCheckedIdent(id int64, identRef *ast.ReferenceInfo) (Interpretable, error) {
// Plan a constant reference if this is the case for this simple identifier.
if identRef.GetValue() != nil {
return p.Plan(&exprpb.Expr{Id: id,
ExprKind: &exprpb.Expr_ConstExpr{
ConstExpr: identRef.GetValue(),
}})
if identRef.Value != nil {
return NewConstValue(id, identRef.Value), nil
}
// Check to see whether the type map indicates this is a type name. All types should be
// registered with the provider.
cType := p.typeMap[id]
if cType.GetType() != nil {
cVal, found := p.provider.FindIdent(identRef.GetName())
if cType.Kind() == types.TypeKind {
cVal, found := p.provider.FindIdent(identRef.Name)
if !found {
return nil, fmt.Errorf("reference to undefined type: %s", identRef.GetName())
return nil, fmt.Errorf("reference to undefined type: %s", identRef.Name)
}
return NewConstValue(id, cVal), nil
}
@ -166,7 +165,7 @@ func (p *planner) planCheckedIdent(id int64, identRef *exprpb.Reference) (Interp
// Otherwise, return the attribute for the resolved identifier name.
return &evalAttr{
adapter: p.adapter,
attr: p.attrFactory.AbsoluteAttribute(id, identRef.GetName()),
attr: p.attrFactory.AbsoluteAttribute(id, identRef.Name),
}, nil
}
@ -429,18 +428,16 @@ func (p *planner) planCallNotEqual(expr *exprpb.Expr, args []Interpretable) (Int
// planCallLogicalAnd generates a logical and (&&) Interpretable.
func (p *planner) planCallLogicalAnd(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) {
return &evalAnd{
id: expr.GetId(),
lhs: args[0],
rhs: args[1],
id: expr.GetId(),
terms: args,
}, nil
}
// planCallLogicalOr generates a logical or (||) Interpretable.
func (p *planner) planCallLogicalOr(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) {
return &evalOr{
id: expr.GetId(),
lhs: args[0],
rhs: args[1],
id: expr.GetId(),
terms: args,
}, nil
}
@ -476,7 +473,7 @@ func (p *planner) planCallConditional(expr *exprpb.Expr, args []Interpretable) (
func (p *planner) planCallIndex(expr *exprpb.Expr, args []Interpretable, optional bool) (Interpretable, error) {
op := args[0]
ind := args[1]
opType := p.typeMap[expr.GetCallExpr().GetTarget().GetId()]
opType := p.typeMap[op.ID()]
// Establish the attribute reference.
var err error
@ -675,7 +672,7 @@ func (p *planner) constValue(c *exprpb.Constant) (ref.Val, error) {
// namespace resolution rules to it in a scan over possible matching types in the TypeProvider.
func (p *planner) resolveTypeName(typeName string) (string, bool) {
for _, qualifiedTypeName := range p.container.ResolveCandidateNames(typeName) {
if _, found := p.provider.FindType(qualifiedTypeName); found {
if _, found := p.provider.FindStructType(qualifiedTypeName); found {
return qualifiedTypeName, true
}
}
@ -702,8 +699,8 @@ func (p *planner) resolveFunction(expr *exprpb.Expr) (*exprpb.Expr, string, stri
// function name as the fnName value.
oRef, hasOverload := p.refMap[expr.GetId()]
if hasOverload {
if len(oRef.GetOverloadId()) == 1 {
return target, fnName, oRef.GetOverloadId()[0]
if len(oRef.OverloadIDs) == 1 {
return target, fnName, oRef.OverloadIDs[0]
}
// Note, this namespaced function name will not appear as a fully qualified name in ASTs
// built and stored before cel-go v0.5.0; however, this functionality did not work at all

View File

@ -341,6 +341,11 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
}
}
if macro, found := p.macroCalls[node.GetId()]; found {
// Ensure that intermediate values for the comprehension are cleared during pruning
compre := node.GetComprehensionExpr()
if compre != nil {
visit(macro, clearIterVarVisitor(compre.IterVar, p.state))
}
// prune the expression in terms of the macro call instead of the expanded form.
if newMacro, pruned := p.prune(macro); pruned {
p.macroCalls[node.GetId()] = newMacro
@ -488,6 +493,27 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
},
}, true
}
case *exprpb.Expr_ComprehensionExpr:
compre := node.GetComprehensionExpr()
// Only the range of the comprehension is pruned since the state tracking only records
// the last iteration of the comprehension and not each step in the evaluation which
// means that the any residuals computed in between might be inaccurate.
if newRange, pruned := p.maybePrune(compre.GetIterRange()); pruned {
return &exprpb.Expr{
Id: node.GetId(),
ExprKind: &exprpb.Expr_ComprehensionExpr{
ComprehensionExpr: &exprpb.Expr_Comprehension{
IterVar: compre.GetIterVar(),
IterRange: newRange,
AccuVar: compre.GetAccuVar(),
AccuInit: compre.GetAccuInit(),
LoopCondition: compre.GetLoopCondition(),
LoopStep: compre.GetLoopStep(),
Result: compre.GetResult(),
},
},
}, true
}
}
return node, false
}
@ -524,6 +550,17 @@ func getMaxID(expr *exprpb.Expr) int64 {
return maxID
}
func clearIterVarVisitor(varName string, state EvalState) astVisitor {
return astVisitor{
visitExpr: func(e *exprpb.Expr) {
ident := e.GetIdentExpr()
if ident != nil && ident.GetName() == varName {
state.SetValue(e.GetId(), nil)
}
},
}
}
func maxIDVisitor(maxID *int64) astVisitor {
return astVisitor{
visitExpr: func(e *exprpb.Expr) {
@ -543,7 +580,9 @@ func visit(expr *exprpb.Expr, visitor astVisitor) {
exprs := []*exprpb.Expr{expr}
for len(exprs) != 0 {
e := exprs[0]
visitor.visitExpr(e)
if visitor.visitExpr != nil {
visitor.visitExpr(e)
}
exprs = exprs[1:]
switch e.GetExprKind().(type) {
case *exprpb.Expr_SelectExpr:
@ -567,7 +606,9 @@ func visit(expr *exprpb.Expr, visitor astVisitor) {
exprs = append(exprs, list.GetElements()...)
case *exprpb.Expr_StructExpr:
for _, entry := range e.GetStructExpr().GetEntries() {
visitor.visitEntry(entry)
if visitor.visitEntry != nil {
visitor.visitEntry(entry)
}
if entry.GetMapKey() != nil {
exprs = append(exprs, entry.GetMapKey())
}

View File

@ -65,13 +65,21 @@ func CostObserver(tracker *CostTracker) EvalObserver {
// While the field names are identical, the boolean operation eval structs do not share an interface and so
// must be handled individually.
case *evalOr:
tracker.stack.drop(t.rhs.ID(), t.lhs.ID())
for _, term := range t.terms {
tracker.stack.drop(term.ID())
}
case *evalAnd:
tracker.stack.drop(t.rhs.ID(), t.lhs.ID())
for _, term := range t.terms {
tracker.stack.drop(term.ID())
}
case *evalExhaustiveOr:
tracker.stack.drop(t.rhs.ID(), t.lhs.ID())
for _, term := range t.terms {
tracker.stack.drop(term.ID())
}
case *evalExhaustiveAnd:
tracker.stack.drop(t.rhs.ID(), t.lhs.ID())
for _, term := range t.terms {
tracker.stack.drop(term.ID())
}
case *evalFold:
tracker.stack.drop(t.iterRange.ID())
case Qualifier:
@ -125,6 +133,7 @@ func PresenceTestHasCost(hasCost bool) CostTrackerOption {
func NewCostTracker(estimator ActualCostEstimator, opts ...CostTrackerOption) (*CostTracker, error) {
tracker := &CostTracker{
Estimator: estimator,
overloadTrackers: map[string]FunctionTracker{},
presenceTestHasCost: true,
}
for _, opt := range opts {
@ -136,9 +145,24 @@ func NewCostTracker(estimator ActualCostEstimator, opts ...CostTrackerOption) (*
return tracker, nil
}
// OverloadCostTracker binds an overload ID to a runtime FunctionTracker implementation.
//
// OverloadCostTracker instances augment or override ActualCostEstimator decisions, allowing for versioned and/or
// optional cost tracking changes.
func OverloadCostTracker(overloadID string, fnTracker FunctionTracker) CostTrackerOption {
return func(tracker *CostTracker) error {
tracker.overloadTrackers[overloadID] = fnTracker
return nil
}
}
// FunctionTracker computes the actual cost of evaluating the functions with the given arguments and result.
type FunctionTracker func(args []ref.Val, result ref.Val) *uint64
// CostTracker represents the information needed for tracking runtime cost.
type CostTracker struct {
Estimator ActualCostEstimator
overloadTrackers map[string]FunctionTracker
Limit *uint64
presenceTestHasCost bool
@ -151,10 +175,19 @@ func (c *CostTracker) ActualCost() uint64 {
return c.cost
}
func (c *CostTracker) costCall(call InterpretableCall, argValues []ref.Val, result ref.Val) uint64 {
func (c *CostTracker) costCall(call InterpretableCall, args []ref.Val, result ref.Val) uint64 {
var cost uint64
if len(c.overloadTrackers) != 0 {
if tracker, found := c.overloadTrackers[call.OverloadID()]; found {
callCost := tracker(args, result)
if callCost != nil {
cost += *callCost
return cost
}
}
}
if c.Estimator != nil {
callCost := c.Estimator.CallCost(call.Function(), call.OverloadID(), argValues, result)
callCost := c.Estimator.CallCost(call.Function(), call.OverloadID(), args, result)
if callCost != nil {
cost += *callCost
return cost
@ -165,11 +198,11 @@ func (c *CostTracker) costCall(call InterpretableCall, argValues []ref.Val, resu
switch call.OverloadID() {
// O(n) functions
case overloads.StartsWithString, overloads.EndsWithString, overloads.StringToBytes, overloads.BytesToString, overloads.ExtQuoteString, overloads.ExtFormatString:
cost += uint64(math.Ceil(float64(c.actualSize(argValues[0])) * common.StringTraversalCostFactor))
cost += uint64(math.Ceil(float64(c.actualSize(args[0])) * common.StringTraversalCostFactor))
case overloads.InList:
// If a list is composed entirely of constant values this is O(1), but we don't account for that here.
// We just assume all list containment checks are O(n).
cost += c.actualSize(argValues[1])
cost += c.actualSize(args[1])
// O(min(m, n)) functions
case overloads.LessString, overloads.GreaterString, overloads.LessEqualsString, overloads.GreaterEqualsString,
overloads.LessBytes, overloads.GreaterBytes, overloads.LessEqualsBytes, overloads.GreaterEqualsBytes,
@ -177,8 +210,8 @@ func (c *CostTracker) costCall(call InterpretableCall, argValues []ref.Val, resu
// When we check the equality of 2 scalar values (e.g. 2 integers, 2 floating-point numbers, 2 booleans etc.),
// the CostTracker.actualSize() function by definition returns 1 for each operand, resulting in an overall cost
// of 1.
lhsSize := c.actualSize(argValues[0])
rhsSize := c.actualSize(argValues[1])
lhsSize := c.actualSize(args[0])
rhsSize := c.actualSize(args[1])
minSize := lhsSize
if rhsSize < minSize {
minSize = rhsSize
@ -187,23 +220,23 @@ func (c *CostTracker) costCall(call InterpretableCall, argValues []ref.Val, resu
// O(m+n) functions
case overloads.AddString, overloads.AddBytes:
// In the worst case scenario, we would need to reallocate a new backing store and copy both operands over.
cost += uint64(math.Ceil(float64(c.actualSize(argValues[0])+c.actualSize(argValues[1])) * common.StringTraversalCostFactor))
cost += uint64(math.Ceil(float64(c.actualSize(args[0])+c.actualSize(args[1])) * common.StringTraversalCostFactor))
// O(nm) functions
case overloads.MatchesString:
// https://swtch.com/~rsc/regexp/regexp1.html applies to RE2 implementation supported by CEL
// Add one to string length for purposes of cost calculation to prevent product of string and regex to be 0
// in case where string is empty but regex is still expensive.
strCost := uint64(math.Ceil((1.0 + float64(c.actualSize(argValues[0]))) * common.StringTraversalCostFactor))
strCost := uint64(math.Ceil((1.0 + float64(c.actualSize(args[0]))) * common.StringTraversalCostFactor))
// We don't know how many expressions are in the regex, just the string length (a huge
// improvement here would be to somehow get a count the number of expressions in the regex or
// how many states are in the regex state machine and use that to measure regex cost).
// For now, we're making a guess that each expression in a regex is typically at least 4 chars
// in length.
regexCost := uint64(math.Ceil(float64(c.actualSize(argValues[1])) * common.RegexStringLengthCostFactor))
regexCost := uint64(math.Ceil(float64(c.actualSize(args[1])) * common.RegexStringLengthCostFactor))
cost += strCost * regexCost
case overloads.ContainsString:
strCost := uint64(math.Ceil(float64(c.actualSize(argValues[0])) * common.StringTraversalCostFactor))
substrCost := uint64(math.Ceil(float64(c.actualSize(argValues[1])) * common.StringTraversalCostFactor))
strCost := uint64(math.Ceil(float64(c.actualSize(args[0])) * common.StringTraversalCostFactor))
substrCost := uint64(math.Ceil(float64(c.actualSize(args[1])) * common.StringTraversalCostFactor))
cost += strCost * substrCost
default:

View File

@ -22,9 +22,22 @@ import (
// parseErrors is a specialization of Errors.
type parseErrors struct {
*common.Errors
errs *common.Errors
}
// errorCount indicates the number of errors reported.
func (e *parseErrors) errorCount() int {
return len(e.errs.GetErrors())
}
func (e *parseErrors) internalError(message string) {
e.errs.ReportErrorAtID(0, common.NoLocation, message)
}
func (e *parseErrors) syntaxError(l common.Location, message string) {
e.ReportError(l, fmt.Sprintf("Syntax error: %s", message))
e.errs.ReportErrorAtID(0, l, fmt.Sprintf("Syntax error: %s", message))
}
func (e *parseErrors) reportErrorAtID(id int64, l common.Location, message string, args ...any) {
e.errs.ReportErrorAtID(id, l, message, args...)
}

View File

@ -193,15 +193,15 @@ func (p *parserHelper) newExpr(ctx any) *exprpb.Expr {
func (p *parserHelper) id(ctx any) int64 {
var location common.Location
switch ctx.(type) {
switch c := ctx.(type) {
case antlr.ParserRuleContext:
token := (ctx.(antlr.ParserRuleContext)).GetStart()
token := c.GetStart()
location = p.source.NewLocation(token.GetLine(), token.GetColumn())
case antlr.Token:
token := ctx.(antlr.Token)
token := c
location = p.source.NewLocation(token.GetLine(), token.GetColumn())
case common.Location:
location = ctx.(common.Location)
location = c
default:
// This should only happen if the ctx is nil
return -1
@ -297,67 +297,83 @@ func (p *parserHelper) addMacroCall(exprID int64, function string, target *exprp
}
}
// balancer performs tree balancing on operators whose arguments are of equal precedence.
// logicManager compacts logical trees into a more efficient structure which is semantically
// equivalent with how the logic graph is constructed by the ANTLR parser.
//
// The purpose of the balancer is to ensure a compact serialization format for the logical &&, ||
// The purpose of the logicManager is to ensure a compact serialization format for the logical &&, ||
// operators which have a tendency to create long DAGs which are skewed in one direction. Since the
// operators are commutative re-ordering the terms *must not* affect the evaluation result.
//
// Re-balancing the terms is a safe, if somewhat controversial choice. A better solution would be
// to make these functions variadic and update both the checker and interpreter to understand this;
// however, this is a more complex change.
//
// TODO: Consider replacing tree-balancing with variadic logical &&, || within the parser, checker,
// and interpreter.
type balancer struct {
helper *parserHelper
function string
terms []*exprpb.Expr
ops []int64
// The logic manager will either render the terms to N-chained && / || operators as a single logical
// call with N-terms, or will rebalance the tree. Rebalancing the terms is a safe, if somewhat
// controversial choice as it alters the traditional order of execution assumptions present in most
// expressions.
type logicManager struct {
helper *parserHelper
function string
terms []*exprpb.Expr
ops []int64
variadicASTs bool
}
// newBalancer creates a balancer instance bound to a specific function and its first term.
func newBalancer(h *parserHelper, function string, term *exprpb.Expr) *balancer {
return &balancer{
helper: h,
function: function,
terms: []*exprpb.Expr{term},
ops: []int64{},
// newVariadicLogicManager creates a logic manager instance bound to a specific function and its first term.
func newVariadicLogicManager(h *parserHelper, function string, term *exprpb.Expr) *logicManager {
return &logicManager{
helper: h,
function: function,
terms: []*exprpb.Expr{term},
ops: []int64{},
variadicASTs: true,
}
}
// newBalancingLogicManager creates a logic manager instance bound to a specific function and its first term.
func newBalancingLogicManager(h *parserHelper, function string, term *exprpb.Expr) *logicManager {
return &logicManager{
helper: h,
function: function,
terms: []*exprpb.Expr{term},
ops: []int64{},
variadicASTs: false,
}
}
// addTerm adds an operation identifier and term to the set of terms to be balanced.
func (b *balancer) addTerm(op int64, term *exprpb.Expr) {
b.terms = append(b.terms, term)
b.ops = append(b.ops, op)
func (l *logicManager) addTerm(op int64, term *exprpb.Expr) {
l.terms = append(l.terms, term)
l.ops = append(l.ops, op)
}
// balance creates a balanced tree from the sub-terms and returns the final Expr value.
func (b *balancer) balance() *exprpb.Expr {
if len(b.terms) == 1 {
return b.terms[0]
// toExpr renders the logic graph into an Expr value, either balancing a tree of logical
// operations or creating a variadic representation of the logical operator.
func (l *logicManager) toExpr() *exprpb.Expr {
if len(l.terms) == 1 {
return l.terms[0]
}
return b.balancedTree(0, len(b.ops)-1)
if l.variadicASTs {
return l.helper.newGlobalCall(l.ops[0], l.function, l.terms...)
}
return l.balancedTree(0, len(l.ops)-1)
}
// balancedTree recursively balances the terms provided to a commutative operator.
func (b *balancer) balancedTree(lo, hi int) *exprpb.Expr {
func (l *logicManager) balancedTree(lo, hi int) *exprpb.Expr {
mid := (lo + hi + 1) / 2
var left *exprpb.Expr
if mid == lo {
left = b.terms[mid]
left = l.terms[mid]
} else {
left = b.balancedTree(lo, mid-1)
left = l.balancedTree(lo, mid-1)
}
var right *exprpb.Expr
if mid == hi {
right = b.terms[mid+1]
right = l.terms[mid+1]
} else {
right = b.balancedTree(mid+1, hi)
right = l.balancedTree(mid+1, hi)
}
return b.helper.newGlobalCall(b.ops[mid], b.function, left, right)
return l.helper.newGlobalCall(l.ops[mid], l.function, left, right)
}
type exprHelper struct {
@ -370,7 +386,7 @@ func (e *exprHelper) nextMacroID() int64 {
}
// Copy implements the ExprHelper interface method by producing a copy of the input Expr value
// with a fresh set of numeric identifiers the Expr and all its descendents.
// with a fresh set of numeric identifiers the Expr and all its descendants.
func (e *exprHelper) Copy(expr *exprpb.Expr) *exprpb.Expr {
copy := e.parserHelper.newExpr(e.parserHelper.getLocation(expr.GetId()))
switch expr.GetExprKind().(type) {
@ -558,11 +574,22 @@ func (e *exprHelper) Select(operand *exprpb.Expr, field string) *exprpb.Expr {
// OffsetLocation implements the ExprHelper interface method.
func (e *exprHelper) OffsetLocation(exprID int64) common.Location {
offset := e.parserHelper.positions[exprID]
location, _ := e.parserHelper.source.OffsetLocation(offset)
offset, found := e.parserHelper.positions[exprID]
if !found {
return common.NoLocation
}
location, found := e.parserHelper.source.OffsetLocation(offset)
if !found {
return common.NoLocation
}
return location
}
// NewError associates an error message with a given expression id, populating the source offset location of the error if possible.
func (e *exprHelper) NewError(exprID int64, message string) *common.Error {
return common.NewError(exprID, message, e.OffsetLocation(exprID))
}
var (
// Thread-safe pool of ExprHelper values to minimize alloc overhead of ExprHelper creations.
exprHelperPool = &sync.Pool{

View File

@ -232,6 +232,9 @@ type ExprHelper interface {
// OffsetLocation returns the Location of the expression identifier.
OffsetLocation(exprID int64) common.Location
// NewError associates an error message with a given expression id.
NewError(exprID int64, message string) *common.Error
}
var (
@ -324,7 +327,7 @@ func MakeExistsOne(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*ex
func MakeMap(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
v, found := extractIdent(args[0])
if !found {
return nil, &common.Error{Message: "argument is not an identifier"}
return nil, eh.NewError(args[0].GetId(), "argument is not an identifier")
}
var fn *exprpb.Expr
@ -355,7 +358,7 @@ func MakeMap(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.E
func MakeFilter(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
v, found := extractIdent(args[0])
if !found {
return nil, &common.Error{Message: "argument is not an identifier"}
return nil, eh.NewError(args[0].GetId(), "argument is not an identifier")
}
filter := args[1]
@ -372,17 +375,13 @@ func MakeHas(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.E
if s, ok := args[0].ExprKind.(*exprpb.Expr_SelectExpr); ok {
return eh.PresenceTest(s.SelectExpr.GetOperand(), s.SelectExpr.GetField()), nil
}
return nil, &common.Error{Message: "invalid argument to has() macro"}
return nil, eh.NewError(args[0].GetId(), "invalid argument to has() macro")
}
func makeQuantifier(kind quantifierKind, eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
v, found := extractIdent(args[0])
if !found {
location := eh.OffsetLocation(args[0].GetId())
return nil, &common.Error{
Message: "argument must be a simple name",
Location: location,
}
return nil, eh.NewError(args[0].GetId(), "argument must be a simple name")
}
var init *exprpb.Expr
@ -411,7 +410,7 @@ func makeQuantifier(kind quantifierKind, eh ExprHelper, target *exprpb.Expr, arg
eh.GlobalCall(operators.Add, eh.AccuIdent(), oneExpr), eh.AccuIdent())
result = eh.GlobalCall(operators.Equals, eh.AccuIdent(), oneExpr)
default:
return nil, &common.Error{Message: fmt.Sprintf("unrecognized quantifier '%v'", kind)}
return nil, eh.NewError(args[0].GetId(), fmt.Sprintf("unrecognized quantifier '%v'", kind))
}
return eh.Fold(v, target, AccumulatorName, init, condition, step, result), nil
}

View File

@ -25,6 +25,7 @@ type options struct {
macros map[string]Macro
populateMacroCalls bool
enableOptionalSyntax bool
enableVariadicOperatorASTs bool
}
// Option configures the behavior of the parser.
@ -125,3 +126,15 @@ func EnableOptionalSyntax(optionalSyntax bool) Option {
return nil
}
}
// EnableVariadicOperatorASTs enables a compact representation of chained like-kind commutative
// operators. e.g. `a || b || c || d` -> `call(op='||', args=[a, b, c, d])`
//
// The benefit of enabling variadic operators ASTs is a more compact representation deeply nested
// logic graphs.
func EnableVariadicOperatorASTs(varArgASTs bool) Option {
return func(opts *options) error {
opts.enableVariadicOperatorASTs = varArgASTs
return nil
}
}

View File

@ -89,8 +89,9 @@ func mustNewParser(opts ...Option) *Parser {
// Parse parses the expression represented by source and returns the result.
func (p *Parser) Parse(source common.Source) (*exprpb.ParsedExpr, *common.Errors) {
errs := common.NewErrors(source)
impl := parser{
errors: &parseErrors{common.NewErrors(source)},
errors: &parseErrors{errs},
helper: newParserHelper(source),
macros: p.macros,
maxRecursionDepth: p.maxRecursionDepth,
@ -99,6 +100,7 @@ func (p *Parser) Parse(source common.Source) (*exprpb.ParsedExpr, *common.Errors
errorRecoveryLookaheadTokenLimit: p.errorRecoveryTokenLookaheadLimit,
populateMacroCalls: p.populateMacroCalls,
enableOptionalSyntax: p.enableOptionalSyntax,
enableVariadicOperatorASTs: p.enableVariadicOperatorASTs,
}
buf, ok := source.(runes.Buffer)
if !ok {
@ -115,7 +117,7 @@ func (p *Parser) Parse(source common.Source) (*exprpb.ParsedExpr, *common.Errors
return &exprpb.ParsedExpr{
Expr: e,
SourceInfo: impl.helper.getSourceInfo(),
}, impl.errors.Errors
}, errs
}
// reservedIds are not legal to use as variables. We exclude them post-parse, as they *are* valid
@ -295,6 +297,7 @@ type parser struct {
errorRecoveryLookaheadTokenLimit int
populateMacroCalls bool
enableOptionalSyntax bool
enableVariadicOperatorASTs bool
}
var (
@ -357,9 +360,9 @@ func (p *parser) parse(expr runes.Buffer, desc string) *exprpb.Expr {
if val := recover(); val != nil {
switch err := val.(type) {
case *lookaheadLimitError:
p.errors.ReportError(common.NoLocation, err.Error())
p.errors.internalError(err.Error())
case *recursionError:
p.errors.ReportError(common.NoLocation, err.Error())
p.errors.internalError(err.Error())
case *tooManyErrors:
// do nothing
case *recoveryLimitError:
@ -449,7 +452,7 @@ func (p *parser) Visit(tree antlr.ParseTree) any {
// Report at least one error if the parser reaches an unknown parse element.
// Typically, this happens if the parser has already encountered a syntax error elsewhere.
if len(p.errors.GetErrors()) == 0 {
if p.errors.errorCount() == 0 {
txt := "<<nil>>"
if t != nil {
txt = fmt.Sprintf("<<%T>>", t)
@ -480,7 +483,7 @@ func (p *parser) VisitExpr(ctx *gen.ExprContext) any {
// Visit a parse tree produced by CELParser#conditionalOr.
func (p *parser) VisitConditionalOr(ctx *gen.ConditionalOrContext) any {
result := p.Visit(ctx.GetE()).(*exprpb.Expr)
b := newBalancer(p.helper, operators.LogicalOr, result)
l := p.newLogicManager(operators.LogicalOr, result)
rest := ctx.GetE1()
for i, op := range ctx.GetOps() {
if i >= len(rest) {
@ -488,15 +491,15 @@ func (p *parser) VisitConditionalOr(ctx *gen.ConditionalOrContext) any {
}
next := p.Visit(rest[i]).(*exprpb.Expr)
opID := p.helper.id(op)
b.addTerm(opID, next)
l.addTerm(opID, next)
}
return b.balance()
return l.toExpr()
}
// Visit a parse tree produced by CELParser#conditionalAnd.
func (p *parser) VisitConditionalAnd(ctx *gen.ConditionalAndContext) any {
result := p.Visit(ctx.GetE()).(*exprpb.Expr)
b := newBalancer(p.helper, operators.LogicalAnd, result)
l := p.newLogicManager(operators.LogicalAnd, result)
rest := ctx.GetE1()
for i, op := range ctx.GetOps() {
if i >= len(rest) {
@ -504,9 +507,9 @@ func (p *parser) VisitConditionalAnd(ctx *gen.ConditionalAndContext) any {
}
next := p.Visit(rest[i]).(*exprpb.Expr)
opID := p.helper.id(op)
b.addTerm(opID, next)
l.addTerm(opID, next)
}
return b.balance()
return l.toExpr()
}
// Visit a parse tree produced by CELParser#relation.
@ -867,18 +870,24 @@ func (p *parser) unquote(ctx any, value string, isBytes bool) string {
return text
}
func (p *parser) newLogicManager(function string, term *exprpb.Expr) *logicManager {
if p.enableVariadicOperatorASTs {
return newVariadicLogicManager(p.helper, function, term)
}
return newBalancingLogicManager(p.helper, function, term)
}
func (p *parser) reportError(ctx any, format string, args ...any) *exprpb.Expr {
var location common.Location
switch ctx.(type) {
err := p.helper.newExpr(ctx)
switch c := ctx.(type) {
case common.Location:
location = ctx.(common.Location)
location = c
case antlr.Token, antlr.ParserRuleContext:
err := p.helper.newExpr(ctx)
location = p.helper.getLocation(err.GetId())
}
err := p.helper.newExpr(ctx)
// Provide arguments to the report error.
p.errors.ReportError(location, format, args...)
p.errors.reportErrorAtID(err.GetId(), location, format, args...)
return err
}