rebase: bump the golang-dependencies group with 1 update

Bumps the golang-dependencies group with 1 update: [golang.org/x/crypto](https://github.com/golang/crypto).


Updates `golang.org/x/crypto` from 0.16.0 to 0.17.0
- [Commits](https://github.com/golang/crypto/compare/v0.16.0...v0.17.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-dependencies
...

Signed-off-by: dependabot[bot] <support@github.com>
This commit is contained in:
dependabot[bot]
2023-12-18 20:31:00 +00:00
committed by mergify[bot]
parent 1ad79314f9
commit e5d9b68d36
398 changed files with 33924 additions and 10753 deletions

View File

@ -11,9 +11,11 @@ go_library(
"cost.go",
"env.go",
"errors.go",
"format.go",
"mapping.go",
"options.go",
"printer.go",
"scopes.go",
"standard.go",
"types.go",
],
@ -22,10 +24,13 @@ go_library(
deps = [
"//checker/decls:go_default_library",
"//common:go_default_library",
"//common/ast:go_default_library",
"//common/containers:go_default_library",
"//common/debug:go_default_library",
"//common/decls:go_default_library",
"//common/operators:go_default_library",
"//common/overloads:go_default_library",
"//common/stdlib:go_default_library",
"//common/types:go_default_library",
"//common/types/pb:go_default_library",
"//common/types/ref:go_default_library",
@ -44,6 +49,7 @@ go_test(
"checker_test.go",
"cost_test.go",
"env_test.go",
"format_test.go",
],
embed = [
":go_default_library",

View File

@ -18,15 +18,13 @@ package checker
import (
"fmt"
"reflect"
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/decls"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/types/ref"
"google.golang.org/protobuf/proto"
"github.com/google/cel-go/common/types"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
@ -37,8 +35,8 @@ type checker struct {
mappings *mapping
freeTypeVarCounter int
sourceInfo *exprpb.SourceInfo
types map[int64]*exprpb.Type
references map[int64]*exprpb.Reference
types map[int64]*types.Type
references map[int64]*ast.ReferenceInfo
}
// Check performs type checking, giving a typed AST.
@ -47,40 +45,38 @@ type checker struct {
// descriptions of protocol buffers, and a registry for errors.
// Returns a CheckedExpr proto, which might not be usable if
// there are errors in the error registry.
func Check(parsedExpr *exprpb.ParsedExpr,
source common.Source,
env *Env) (*exprpb.CheckedExpr, *common.Errors) {
func Check(parsedExpr *exprpb.ParsedExpr, source common.Source, env *Env) (*ast.CheckedAST, *common.Errors) {
errs := common.NewErrors(source)
c := checker{
env: env,
errors: &typeErrors{common.NewErrors(source)},
errors: &typeErrors{errs: errs},
mappings: newMapping(),
freeTypeVarCounter: 0,
sourceInfo: parsedExpr.GetSourceInfo(),
types: make(map[int64]*exprpb.Type),
references: make(map[int64]*exprpb.Reference),
types: make(map[int64]*types.Type),
references: make(map[int64]*ast.ReferenceInfo),
}
c.check(parsedExpr.GetExpr())
// Walk over the final type map substituting any type parameters either by their bound value or
// by DYN.
m := make(map[int64]*exprpb.Type)
for k, v := range c.types {
m[k] = substitute(c.mappings, v, true)
m := make(map[int64]*types.Type)
for id, t := range c.types {
m[id] = substitute(c.mappings, t, true)
}
return &exprpb.CheckedExpr{
return &ast.CheckedAST{
Expr: parsedExpr.GetExpr(),
SourceInfo: parsedExpr.GetSourceInfo(),
TypeMap: m,
ReferenceMap: c.references,
}, c.errors.Errors
}, errs
}
func (c *checker) check(e *exprpb.Expr) {
if e == nil {
return
}
switch e.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
literal := e.GetConstExpr()
@ -113,53 +109,51 @@ func (c *checker) check(e *exprpb.Expr) {
case *exprpb.Expr_ComprehensionExpr:
c.checkComprehension(e)
default:
c.errors.ReportError(
c.location(e), "Unrecognized ast type: %v", reflect.TypeOf(e))
c.errors.unexpectedASTType(e.GetId(), c.location(e), e)
}
}
func (c *checker) checkInt64Literal(e *exprpb.Expr) {
c.setType(e, decls.Int)
c.setType(e, types.IntType)
}
func (c *checker) checkUint64Literal(e *exprpb.Expr) {
c.setType(e, decls.Uint)
c.setType(e, types.UintType)
}
func (c *checker) checkStringLiteral(e *exprpb.Expr) {
c.setType(e, decls.String)
c.setType(e, types.StringType)
}
func (c *checker) checkBytesLiteral(e *exprpb.Expr) {
c.setType(e, decls.Bytes)
c.setType(e, types.BytesType)
}
func (c *checker) checkDoubleLiteral(e *exprpb.Expr) {
c.setType(e, decls.Double)
c.setType(e, types.DoubleType)
}
func (c *checker) checkBoolLiteral(e *exprpb.Expr) {
c.setType(e, decls.Bool)
c.setType(e, types.BoolType)
}
func (c *checker) checkNullLiteral(e *exprpb.Expr) {
c.setType(e, decls.Null)
c.setType(e, types.NullType)
}
func (c *checker) checkIdent(e *exprpb.Expr) {
identExpr := e.GetIdentExpr()
// Check to see if the identifier is declared.
if ident := c.env.LookupIdent(identExpr.GetName()); ident != nil {
c.setType(e, ident.GetIdent().GetType())
c.setReference(e, newIdentReference(ident.GetName(), ident.GetIdent().GetValue()))
c.setType(e, ident.Type())
c.setReference(e, ast.NewIdentReference(ident.Name(), ident.Value()))
// Overwrite the identifier with its fully qualified name.
identExpr.Name = ident.GetName()
identExpr.Name = ident.Name()
return
}
c.setType(e, decls.Error)
c.errors.undeclaredReference(
c.location(e), c.env.container.Name(), identExpr.GetName())
c.setType(e, types.ErrorType)
c.errors.undeclaredReference(e.GetId(), c.location(e), c.env.container.Name(), identExpr.GetName())
}
func (c *checker) checkSelect(e *exprpb.Expr) {
@ -174,9 +168,9 @@ func (c *checker) checkSelect(e *exprpb.Expr) {
// Rewrite the node to be a variable reference to the resolved fully-qualified
// variable name.
c.setType(e, ident.GetIdent().GetType())
c.setReference(e, newIdentReference(ident.GetName(), ident.GetIdent().GetValue()))
identName := ident.GetName()
c.setType(e, ident.Type())
c.setReference(e, ast.NewIdentReference(ident.Name(), ident.Value()))
identName := ident.Name()
e.ExprKind = &exprpb.Expr_IdentExpr{
IdentExpr: &exprpb.Expr_Ident{
Name: identName,
@ -188,7 +182,7 @@ func (c *checker) checkSelect(e *exprpb.Expr) {
resultType := c.checkSelectField(e, sel.GetOperand(), sel.GetField(), false)
if sel.TestOnly {
resultType = decls.Bool
resultType = types.BoolType
}
c.setType(e, substitute(c.mappings, resultType, false))
}
@ -200,16 +194,17 @@ func (c *checker) checkOptSelect(e *exprpb.Expr) {
field := call.GetArgs()[1]
fieldName, isString := maybeUnwrapString(field)
if !isString {
c.errors.ReportError(c.location(field), "unsupported optional field selection: %v", field)
c.errors.notAnOptionalFieldSelection(field.GetId(), c.location(field), field)
return
}
// Perform type-checking using the field selection logic.
resultType := c.checkSelectField(e, operand, fieldName, true)
c.setType(e, substitute(c.mappings, resultType, false))
c.setReference(e, ast.NewFunctionReference("select_optional_field"))
}
func (c *checker) checkSelectField(e, operand *exprpb.Expr, field string, optional bool) *exprpb.Type {
func (c *checker) checkSelectField(e, operand *exprpb.Expr, field string, optional bool) *types.Type {
// Interpret as field selection, first traversing down the operand.
c.check(operand)
operandType := substitute(c.mappings, c.getType(operand), false)
@ -218,38 +213,37 @@ func (c *checker) checkSelectField(e, operand *exprpb.Expr, field string, option
targetType, isOpt := maybeUnwrapOptional(operandType)
// Assume error type by default as most types do not support field selection.
resultType := decls.Error
switch kindOf(targetType) {
case kindMap:
resultType := types.ErrorType
switch targetType.Kind() {
case types.MapKind:
// Maps yield their value type as the selection result type.
mapType := targetType.GetMapType()
resultType = mapType.GetValueType()
case kindObject:
resultType = targetType.Parameters()[1]
case types.StructKind:
// Objects yield their field type declaration as the selection result type, but only if
// the field is defined.
messageType := targetType
if fieldType, found := c.lookupFieldType(c.location(e), messageType.GetMessageType(), field); found {
resultType = fieldType.Type
if fieldType, found := c.lookupFieldType(e.GetId(), messageType.TypeName(), field); found {
resultType = fieldType
}
case kindTypeParam:
case types.TypeParamKind:
// Set the operand type to DYN to prevent assignment to a potentially incorrect type
// at a later point in type-checking. The isAssignable call will update the type
// substitutions for the type param under the covers.
c.isAssignable(decls.Dyn, targetType)
c.isAssignable(types.DynType, targetType)
// Also, set the result type to DYN.
resultType = decls.Dyn
resultType = types.DynType
default:
// Dynamic / error values are treated as DYN type. Errors are handled this way as well
// in order to allow forward progress on the check.
if !isDynOrError(targetType) {
c.errors.typeDoesNotSupportFieldSelection(c.location(e), targetType)
c.errors.typeDoesNotSupportFieldSelection(e.GetId(), c.location(e), targetType)
}
resultType = decls.Dyn
resultType = types.DynType
}
// If the target type was optional coming in, then the result must be optional going out.
if isOpt || optional {
return decls.NewOptionalType(resultType)
return types.NewOptionalType(resultType)
}
return resultType
}
@ -277,15 +271,14 @@ func (c *checker) checkCall(e *exprpb.Expr) {
// Check for the existence of the function.
fn := c.env.LookupFunction(fnName)
if fn == nil {
c.errors.undeclaredReference(
c.location(e), c.env.container.Name(), fnName)
c.setType(e, decls.Error)
c.errors.undeclaredReference(e.GetId(), c.location(e), c.env.container.Name(), fnName)
c.setType(e, types.ErrorType)
return
}
// Overwrite the function name with its fully qualified resolved name.
call.Function = fn.GetName()
call.Function = fn.Name()
// Check to see whether the overload resolves.
c.resolveOverloadOrError(c.location(e), e, fn, nil, args)
c.resolveOverloadOrError(e, fn, nil, args)
return
}
@ -303,8 +296,8 @@ func (c *checker) checkCall(e *exprpb.Expr) {
// be an inaccurate representation of the desired evaluation behavior.
// Overwrite with fully-qualified resolved function name sans receiver target.
call.Target = nil
call.Function = fn.GetName()
c.resolveOverloadOrError(c.location(e), e, fn, nil, args)
call.Function = fn.Name()
c.resolveOverloadOrError(e, fn, nil, args)
return
}
}
@ -314,22 +307,21 @@ func (c *checker) checkCall(e *exprpb.Expr) {
fn := c.env.LookupFunction(fnName)
// Function found, attempt overload resolution.
if fn != nil {
c.resolveOverloadOrError(c.location(e), e, fn, target, args)
c.resolveOverloadOrError(e, fn, target, args)
return
}
// Function name not declared, record error.
c.errors.undeclaredReference(c.location(e), c.env.container.Name(), fnName)
c.setType(e, types.ErrorType)
c.errors.undeclaredReference(e.GetId(), c.location(e), c.env.container.Name(), fnName)
}
func (c *checker) resolveOverloadOrError(
loc common.Location,
e *exprpb.Expr,
fn *exprpb.Decl, target *exprpb.Expr, args []*exprpb.Expr) {
e *exprpb.Expr, fn *decls.FunctionDecl, target *exprpb.Expr, args []*exprpb.Expr) {
// Attempt to resolve the overload.
resolution := c.resolveOverload(loc, fn, target, args)
resolution := c.resolveOverload(e, fn, target, args)
// No such overload, error noted in the resolveOverload call, type recorded here.
if resolution == nil {
c.setType(e, decls.Error)
c.setType(e, types.ErrorType)
return
}
// Overload found.
@ -338,10 +330,9 @@ func (c *checker) resolveOverloadOrError(
}
func (c *checker) resolveOverload(
loc common.Location,
fn *exprpb.Decl, target *exprpb.Expr, args []*exprpb.Expr) *overloadResolution {
call *exprpb.Expr, fn *decls.FunctionDecl, target *exprpb.Expr, args []*exprpb.Expr) *overloadResolution {
var argTypes []*exprpb.Type
var argTypes []*types.Type
if target != nil {
argTypes = append(argTypes, c.getType(target))
}
@ -349,55 +340,75 @@ func (c *checker) resolveOverload(
argTypes = append(argTypes, c.getType(arg))
}
var resultType *exprpb.Type
var checkedRef *exprpb.Reference
for _, overload := range fn.GetFunction().GetOverloads() {
var resultType *types.Type
var checkedRef *ast.ReferenceInfo
for _, overload := range fn.OverloadDecls() {
// Determine whether the overload is currently considered.
if c.env.isOverloadDisabled(overload.GetOverloadId()) {
if c.env.isOverloadDisabled(overload.ID()) {
continue
}
// Ensure the call style for the overload matches.
if (target == nil && overload.GetIsInstanceFunction()) ||
(target != nil && !overload.GetIsInstanceFunction()) {
if (target == nil && overload.IsMemberFunction()) ||
(target != nil && !overload.IsMemberFunction()) {
// not a compatible call style.
continue
}
overloadType := decls.NewFunctionType(overload.ResultType, overload.Params...)
if len(overload.GetTypeParams()) > 0 {
// Alternative type-checking behavior when the logical operators are compacted into
// variadic AST representations.
if fn.Name() == operators.LogicalAnd || fn.Name() == operators.LogicalOr {
checkedRef = ast.NewFunctionReference(overload.ID())
for i, argType := range argTypes {
if !c.isAssignable(argType, types.BoolType) {
c.errors.typeMismatch(
args[i].GetId(),
c.locationByID(args[i].GetId()),
types.BoolType,
argType)
resultType = types.ErrorType
}
}
if isError(resultType) {
return nil
}
return newResolution(checkedRef, types.BoolType)
}
overloadType := newFunctionType(overload.ResultType(), overload.ArgTypes()...)
typeParams := overload.TypeParams()
if len(typeParams) != 0 {
// Instantiate overload's type with fresh type variables.
substitutions := newMapping()
for _, typePar := range overload.GetTypeParams() {
substitutions.add(decls.NewTypeParamType(typePar), c.newTypeVar())
for _, typePar := range typeParams {
substitutions.add(types.NewTypeParamType(typePar), c.newTypeVar())
}
overloadType = substitute(substitutions, overloadType, false)
}
candidateArgTypes := overloadType.GetFunction().GetArgTypes()
candidateArgTypes := overloadType.Parameters()[1:]
if c.isAssignableList(argTypes, candidateArgTypes) {
if checkedRef == nil {
checkedRef = newFunctionReference(overload.GetOverloadId())
checkedRef = ast.NewFunctionReference(overload.ID())
} else {
checkedRef.OverloadId = append(checkedRef.GetOverloadId(), overload.GetOverloadId())
checkedRef.AddOverload(overload.ID())
}
// First matching overload, determines result type.
fnResultType := substitute(c.mappings, overloadType.GetFunction().GetResultType(), false)
fnResultType := substitute(c.mappings, overloadType.Parameters()[0], false)
if resultType == nil {
resultType = fnResultType
} else if !isDyn(resultType) && !proto.Equal(fnResultType, resultType) {
resultType = decls.Dyn
} else if !isDyn(resultType) && !fnResultType.IsExactType(resultType) {
resultType = types.DynType
}
}
}
if resultType == nil {
for i, arg := range argTypes {
argTypes[i] = substitute(c.mappings, arg, true)
for i, argType := range argTypes {
argTypes[i] = substitute(c.mappings, argType, true)
}
c.errors.noMatchingOverload(loc, fn.GetName(), argTypes, target != nil)
resultType = decls.Error
c.errors.noMatchingOverload(call.GetId(), c.location(call), fn.Name(), argTypes, target != nil)
return nil
}
@ -406,7 +417,7 @@ func (c *checker) resolveOverload(
func (c *checker) checkCreateList(e *exprpb.Expr) {
create := e.GetListExpr()
var elemsType *exprpb.Type
var elemsType *types.Type
optionalIndices := create.GetOptionalIndices()
optionals := make(map[int32]bool, len(optionalIndices))
for _, optInd := range optionalIndices {
@ -419,16 +430,16 @@ func (c *checker) checkCreateList(e *exprpb.Expr) {
var isOptional bool
elemType, isOptional = maybeUnwrapOptional(elemType)
if !isOptional && !isDyn(elemType) {
c.errors.typeMismatch(c.location(e), decls.NewOptionalType(elemType), elemType)
c.errors.typeMismatch(e.GetId(), c.location(e), types.NewOptionalType(elemType), elemType)
}
}
elemsType = c.joinTypes(c.location(e), elemsType, elemType)
elemsType = c.joinTypes(e, elemsType, elemType)
}
if elemsType == nil {
// If the list is empty, assign free type var to elem type.
elemsType = c.newTypeVar()
}
c.setType(e, decls.NewListType(elemsType))
c.setType(e, types.NewListType(elemsType))
}
func (c *checker) checkCreateStruct(e *exprpb.Expr) {
@ -442,12 +453,12 @@ func (c *checker) checkCreateStruct(e *exprpb.Expr) {
func (c *checker) checkCreateMap(e *exprpb.Expr) {
mapVal := e.GetStructExpr()
var mapKeyType *exprpb.Type
var mapValueType *exprpb.Type
var mapKeyType *types.Type
var mapValueType *types.Type
for _, ent := range mapVal.GetEntries() {
key := ent.GetMapKey()
c.check(key)
mapKeyType = c.joinTypes(c.location(key), mapKeyType, c.getType(key))
mapKeyType = c.joinTypes(key, mapKeyType, c.getType(key))
val := ent.GetValue()
c.check(val)
@ -456,50 +467,54 @@ func (c *checker) checkCreateMap(e *exprpb.Expr) {
var isOptional bool
valType, isOptional = maybeUnwrapOptional(valType)
if !isOptional && !isDyn(valType) {
c.errors.typeMismatch(c.location(val), decls.NewOptionalType(valType), valType)
c.errors.typeMismatch(val.GetId(), c.location(val), types.NewOptionalType(valType), valType)
}
}
mapValueType = c.joinTypes(c.location(val), mapValueType, valType)
mapValueType = c.joinTypes(val, mapValueType, valType)
}
if mapKeyType == nil {
// If the map is empty, assign free type variables to typeKey and value type.
mapKeyType = c.newTypeVar()
mapValueType = c.newTypeVar()
}
c.setType(e, decls.NewMapType(mapKeyType, mapValueType))
c.setType(e, types.NewMapType(mapKeyType, mapValueType))
}
func (c *checker) checkCreateMessage(e *exprpb.Expr) {
msgVal := e.GetStructExpr()
// Determine the type of the message.
messageType := decls.Error
decl := c.env.LookupIdent(msgVal.GetMessageName())
if decl == nil {
resultType := types.ErrorType
ident := c.env.LookupIdent(msgVal.GetMessageName())
if ident == nil {
c.errors.undeclaredReference(
c.location(e), c.env.container.Name(), msgVal.GetMessageName())
e.GetId(), c.location(e), c.env.container.Name(), msgVal.GetMessageName())
c.setType(e, types.ErrorType)
return
}
// Ensure the type name is fully qualified in the AST.
msgVal.MessageName = decl.GetName()
c.setReference(e, newIdentReference(decl.GetName(), nil))
ident := decl.GetIdent()
identKind := kindOf(ident.GetType())
if identKind != kindError {
if identKind != kindType {
c.errors.notAType(c.location(e), ident.GetType())
typeName := ident.Name()
msgVal.MessageName = typeName
c.setReference(e, ast.NewIdentReference(ident.Name(), nil))
identKind := ident.Type().Kind()
if identKind != types.ErrorKind {
if identKind != types.TypeKind {
c.errors.notAType(e.GetId(), c.location(e), ident.Type().DeclaredTypeName())
} else {
messageType = ident.GetType().GetType()
if kindOf(messageType) != kindObject {
c.errors.notAMessageType(c.location(e), messageType)
messageType = decls.Error
resultType = ident.Type().Parameters()[0]
// Backwards compatibility test between well-known types and message types
// In this context, the type is being instantiated by its protobuf name which
// is not ideal or recommended, but some users expect this to work.
if isWellKnownType(resultType) {
typeName = getWellKnownTypeName(resultType)
} else if resultType.Kind() == types.StructKind {
typeName = resultType.DeclaredTypeName()
} else {
c.errors.notAMessageType(e.GetId(), c.location(e), resultType.DeclaredTypeName())
resultType = types.ErrorType
}
}
}
if isObjectWellKnownType(messageType) {
c.setType(e, getObjectWellKnownType(messageType))
} else {
c.setType(e, messageType)
}
c.setType(e, resultType)
// Check the field initializers.
for _, ent := range msgVal.GetEntries() {
@ -507,10 +522,10 @@ func (c *checker) checkCreateMessage(e *exprpb.Expr) {
value := ent.GetValue()
c.check(value)
fieldType := decls.Error
ft, found := c.lookupFieldType(c.locationByID(ent.GetId()), messageType.GetMessageType(), field)
fieldType := types.ErrorType
ft, found := c.lookupFieldType(ent.GetId(), typeName, field)
if found {
fieldType = ft.Type
fieldType = ft
}
valType := c.getType(value)
@ -518,11 +533,11 @@ func (c *checker) checkCreateMessage(e *exprpb.Expr) {
var isOptional bool
valType, isOptional = maybeUnwrapOptional(valType)
if !isOptional && !isDyn(valType) {
c.errors.typeMismatch(c.location(value), decls.NewOptionalType(valType), valType)
c.errors.typeMismatch(value.GetId(), c.location(value), types.NewOptionalType(valType), valType)
}
}
if !c.isAssignable(fieldType, valType) {
c.errors.fieldTypeMismatch(c.locationByID(ent.Id), field, fieldType, valType)
c.errors.fieldTypeMismatch(ent.GetId(), c.locationByID(ent.GetId()), field, fieldType, valType)
}
}
}
@ -533,36 +548,36 @@ func (c *checker) checkComprehension(e *exprpb.Expr) {
c.check(comp.GetAccuInit())
accuType := c.getType(comp.GetAccuInit())
rangeType := substitute(c.mappings, c.getType(comp.GetIterRange()), false)
var varType *exprpb.Type
var varType *types.Type
switch kindOf(rangeType) {
case kindList:
varType = rangeType.GetListType().GetElemType()
case kindMap:
switch rangeType.Kind() {
case types.ListKind:
varType = rangeType.Parameters()[0]
case types.MapKind:
// Ranges over the keys.
varType = rangeType.GetMapType().GetKeyType()
case kindDyn, kindError, kindTypeParam:
varType = rangeType.Parameters()[0]
case types.DynKind, types.ErrorKind, types.TypeParamKind:
// Set the range type to DYN to prevent assignment to a potentially incorrect type
// at a later point in type-checking. The isAssignable call will update the type
// substitutions for the type param under the covers.
c.isAssignable(decls.Dyn, rangeType)
c.isAssignable(types.DynType, rangeType)
// Set the range iteration variable to type DYN as well.
varType = decls.Dyn
varType = types.DynType
default:
c.errors.notAComprehensionRange(c.location(comp.GetIterRange()), rangeType)
varType = decls.Error
c.errors.notAComprehensionRange(comp.GetIterRange().GetId(), c.location(comp.GetIterRange()), rangeType)
varType = types.ErrorType
}
// Create a scope for the comprehension since it has a local accumulation variable.
// This scope will contain the accumulation variable used to compute the result.
c.env = c.env.enterScope()
c.env.Add(decls.NewVar(comp.GetAccuVar(), accuType))
c.env.AddIdents(decls.NewVariable(comp.GetAccuVar(), accuType))
// Create a block scope for the loop.
c.env = c.env.enterScope()
c.env.Add(decls.NewVar(comp.GetIterVar(), varType))
c.env.AddIdents(decls.NewVariable(comp.GetIterVar(), varType))
// Check the variable references in the condition and step.
c.check(comp.GetLoopCondition())
c.assertType(comp.GetLoopCondition(), decls.Bool)
c.assertType(comp.GetLoopCondition(), types.BoolType)
c.check(comp.GetLoopStep())
c.assertType(comp.GetLoopStep(), accuType)
// Exit the loop's block scope before checking the result.
@ -574,9 +589,7 @@ func (c *checker) checkComprehension(e *exprpb.Expr) {
}
// Checks compatibility of joined types, and returns the most general common type.
func (c *checker) joinTypes(loc common.Location,
previous *exprpb.Type,
current *exprpb.Type) *exprpb.Type {
func (c *checker) joinTypes(e *exprpb.Expr, previous, current *types.Type) *types.Type {
if previous == nil {
return current
}
@ -584,23 +597,23 @@ func (c *checker) joinTypes(loc common.Location,
return mostGeneral(previous, current)
}
if c.dynAggregateLiteralElementTypesEnabled() {
return decls.Dyn
return types.DynType
}
c.errors.typeMismatch(loc, previous, current)
return decls.Error
c.errors.typeMismatch(e.GetId(), c.location(e), previous, current)
return types.ErrorType
}
func (c *checker) dynAggregateLiteralElementTypesEnabled() bool {
return c.env.aggLitElemType == dynElementType
}
func (c *checker) newTypeVar() *exprpb.Type {
func (c *checker) newTypeVar() *types.Type {
id := c.freeTypeVarCounter
c.freeTypeVarCounter++
return decls.NewTypeParamType(fmt.Sprintf("_var%d", id))
return types.NewTypeParamType(fmt.Sprintf("_var%d", id))
}
func (c *checker) isAssignable(t1 *exprpb.Type, t2 *exprpb.Type) bool {
func (c *checker) isAssignable(t1, t2 *types.Type) bool {
subs := isAssignable(c.mappings, t1, t2)
if subs != nil {
c.mappings = subs
@ -610,7 +623,7 @@ func (c *checker) isAssignable(t1 *exprpb.Type, t2 *exprpb.Type) bool {
return false
}
func (c *checker) isAssignableList(l1 []*exprpb.Type, l2 []*exprpb.Type) bool {
func (c *checker) isAssignableList(l1, l2 []*types.Type) bool {
subs := isAssignableList(c.mappings, l1, l2)
if subs != nil {
c.mappings = subs
@ -620,57 +633,52 @@ func (c *checker) isAssignableList(l1 []*exprpb.Type, l2 []*exprpb.Type) bool {
return false
}
func (c *checker) lookupFieldType(l common.Location, messageType string, fieldName string) (*ref.FieldType, bool) {
if _, found := c.env.provider.FindType(messageType); !found {
// This should not happen, anyway, report an error.
c.errors.unexpectedFailedResolution(l, messageType)
return nil, false
func maybeUnwrapString(e *exprpb.Expr) (string, bool) {
switch e.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
literal := e.GetConstExpr()
switch literal.GetConstantKind().(type) {
case *exprpb.Constant_StringValue:
return literal.GetStringValue(), true
}
}
if ft, found := c.env.provider.FindFieldType(messageType, fieldName); found {
return ft, found
}
c.errors.undefinedField(l, fieldName)
return nil, false
return "", false
}
func (c *checker) setType(e *exprpb.Expr, t *exprpb.Type) {
if old, found := c.types[e.GetId()]; found && !proto.Equal(old, t) {
c.errors.ReportError(c.location(e),
"(Incompatible) Type already exists for expression: %v(%d) old:%v, new:%v", e, e.GetId(), old, t)
func (c *checker) setType(e *exprpb.Expr, t *types.Type) {
if old, found := c.types[e.GetId()]; found && !old.IsExactType(t) {
c.errors.incompatibleType(e.GetId(), c.location(e), e, old, t)
return
}
c.types[e.GetId()] = t
}
func (c *checker) getType(e *exprpb.Expr) *exprpb.Type {
func (c *checker) getType(e *exprpb.Expr) *types.Type {
return c.types[e.GetId()]
}
func (c *checker) setReference(e *exprpb.Expr, r *exprpb.Reference) {
if old, found := c.references[e.GetId()]; found && !proto.Equal(old, r) {
c.errors.ReportError(c.location(e),
"Reference already exists for expression: %v(%d) old:%v, new:%v", e, e.GetId(), old, r)
func (c *checker) setReference(e *exprpb.Expr, r *ast.ReferenceInfo) {
if old, found := c.references[e.GetId()]; found && !old.Equals(r) {
c.errors.referenceRedefinition(e.GetId(), c.location(e), e, old, r)
return
}
c.references[e.GetId()] = r
}
func (c *checker) assertType(e *exprpb.Expr, t *exprpb.Type) {
func (c *checker) assertType(e *exprpb.Expr, t *types.Type) {
if !c.isAssignable(t, c.getType(e)) {
c.errors.typeMismatch(c.location(e), t, c.getType(e))
c.errors.typeMismatch(e.GetId(), c.location(e), t, c.getType(e))
}
}
type overloadResolution struct {
Reference *exprpb.Reference
Type *exprpb.Type
Type *types.Type
Reference *ast.ReferenceInfo
}
func newResolution(checkedRef *exprpb.Reference, t *exprpb.Type) *overloadResolution {
func newResolution(r *ast.ReferenceInfo, t *types.Type) *overloadResolution {
return &overloadResolution{
Reference: checkedRef,
Reference: r,
Type: t,
}
}
@ -697,10 +705,56 @@ func (c *checker) locationByID(id int64) common.Location {
return common.NoLocation
}
func newIdentReference(name string, value *exprpb.Constant) *exprpb.Reference {
return &exprpb.Reference{Name: name, Value: value}
func (c *checker) lookupFieldType(exprID int64, structType, fieldName string) (*types.Type, bool) {
if _, found := c.env.provider.FindStructType(structType); !found {
// This should not happen, anyway, report an error.
c.errors.unexpectedFailedResolution(exprID, c.locationByID(exprID), structType)
return nil, false
}
if ft, found := c.env.provider.FindStructFieldType(structType, fieldName); found {
return ft.Type, found
}
c.errors.undefinedField(exprID, c.locationByID(exprID), fieldName)
return nil, false
}
func newFunctionReference(overloads ...string) *exprpb.Reference {
return &exprpb.Reference{OverloadId: overloads}
func isWellKnownType(t *types.Type) bool {
switch t.Kind() {
case types.AnyKind, types.TimestampKind, types.DurationKind, types.DynKind, types.NullTypeKind:
return true
case types.BoolKind, types.BytesKind, types.DoubleKind, types.IntKind, types.StringKind, types.UintKind:
return t.IsAssignableType(types.NullType)
case types.ListKind:
return t.Parameters()[0] == types.DynType
case types.MapKind:
return t.Parameters()[0] == types.StringType && t.Parameters()[1] == types.DynType
}
return false
}
func getWellKnownTypeName(t *types.Type) string {
if name, found := wellKnownTypes[t.Kind()]; found {
return name
}
return ""
}
var (
wellKnownTypes = map[types.Kind]string{
types.AnyKind: "google.protobuf.Any",
types.BoolKind: "google.protobuf.BoolValue",
types.BytesKind: "google.protobuf.BytesValue",
types.DoubleKind: "google.protobuf.DoubleValue",
types.DurationKind: "google.protobuf.Duration",
types.DynKind: "google.protobuf.Value",
types.IntKind: "google.protobuf.Int64Value",
types.ListKind: "google.protobuf.ListValue",
types.NullTypeKind: "google.protobuf.NullValue",
types.MapKind: "google.protobuf.Struct",
types.StringKind: "google.protobuf.StringValue",
types.TimestampKind: "google.protobuf.Timestamp",
types.UintKind: "google.protobuf.UInt64Value",
}
)

View File

@ -18,7 +18,9 @@ import (
"math"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
@ -54,7 +56,7 @@ type AstNode interface {
// The first path element is a variable. All subsequent path elements are one of: field name, '@items', '@keys', '@values'.
Path() []string
// Type returns the deduced type of the AstNode.
Type() *exprpb.Type
Type() *types.Type
// Expr returns the expression of the AstNode.
Expr() *exprpb.Expr
// ComputedSize returns a size estimate of the AstNode derived from information available in the CEL expression.
@ -66,7 +68,7 @@ type AstNode interface {
type astNode struct {
path []string
t *exprpb.Type
t *types.Type
expr *exprpb.Expr
derivedSize *SizeEstimate
}
@ -75,7 +77,7 @@ func (e astNode) Path() []string {
return e.path
}
func (e astNode) Type() *exprpb.Type {
func (e astNode) Type() *types.Type {
return e.t
}
@ -228,7 +230,7 @@ func addUint64NoOverflow(x, y uint64) uint64 {
// multiplyUint64NoOverflow multiplies non-negative ints. If the result is exceeds math.MaxUint64, math.MaxUint64
// is returned.
func multiplyUint64NoOverflow(x, y uint64) uint64 {
if x > 0 && y > 0 && x > math.MaxUint64/y {
if y != 0 && x > math.MaxUint64/y {
return math.MaxUint64
}
return x * y
@ -240,7 +242,11 @@ func multiplyByCostFactor(x uint64, y float64) uint64 {
if xFloat > 0 && y > 0 && xFloat > math.MaxUint64/y {
return math.MaxUint64
}
return uint64(math.Ceil(xFloat * y))
ceil := math.Ceil(xFloat * y)
if ceil >= doubleTwoTo64 {
return math.MaxUint64
}
return uint64(ceil)
}
var (
@ -258,9 +264,10 @@ type coster struct {
// iterRanges tracks the iterRange of each iterVar.
iterRanges iterRangeScopes
// computedSizes tracks the computed sizes of call results.
computedSizes map[int64]SizeEstimate
checkedExpr *exprpb.CheckedExpr
estimator CostEstimator
computedSizes map[int64]SizeEstimate
checkedAST *ast.CheckedAST
estimator CostEstimator
overloadEstimators map[string]FunctionEstimator
// presenceTestCost will either be a zero or one based on whether has() macros count against cost computations.
presenceTestCost CostEstimate
}
@ -289,6 +296,7 @@ func (vs iterRangeScopes) peek(varName string) (int64, bool) {
type CostOption func(*coster) error
// PresenceTestHasCost determines whether presence testing has a cost of one or zero.
//
// Defaults to presence test has a cost of one.
func PresenceTestHasCost(hasCost bool) CostOption {
return func(c *coster) error {
@ -301,15 +309,30 @@ func PresenceTestHasCost(hasCost bool) CostOption {
}
}
// FunctionEstimator provides a CallEstimate given the target and arguments for a specific function, overload pair.
type FunctionEstimator func(estimator CostEstimator, target *AstNode, args []AstNode) *CallEstimate
// OverloadCostEstimate binds a FunctionCoster to a specific function overload ID.
//
// When a OverloadCostEstimate is provided, it will override the cost calculation of the CostEstimator provided to
// the Cost() call.
func OverloadCostEstimate(overloadID string, functionCoster FunctionEstimator) CostOption {
return func(c *coster) error {
c.overloadEstimators[overloadID] = functionCoster
return nil
}
}
// Cost estimates the cost of the parsed and type checked CEL expression.
func Cost(checker *exprpb.CheckedExpr, estimator CostEstimator, opts ...CostOption) (CostEstimate, error) {
func Cost(checker *ast.CheckedAST, estimator CostEstimator, opts ...CostOption) (CostEstimate, error) {
c := &coster{
checkedExpr: checker,
estimator: estimator,
exprPath: map[int64][]string{},
iterRanges: map[string][]int64{},
computedSizes: map[int64]SizeEstimate{},
presenceTestCost: CostEstimate{Min: 1, Max: 1},
checkedAST: checker,
estimator: estimator,
overloadEstimators: map[string]FunctionEstimator{},
exprPath: map[int64][]string{},
iterRanges: map[string][]int64{},
computedSizes: map[int64]SizeEstimate{},
presenceTestCost: CostEstimate{Min: 1, Max: 1},
}
for _, opt := range opts {
err := opt(c)
@ -317,7 +340,7 @@ func Cost(checker *exprpb.CheckedExpr, estimator CostEstimator, opts ...CostOpti
return CostEstimate{}, err
}
}
return c.cost(checker.GetExpr()), nil
return c.cost(checker.Expr), nil
}
func (c *coster) cost(e *exprpb.Expr) CostEstimate {
@ -351,10 +374,10 @@ func (c *coster) costIdent(e *exprpb.Expr) CostEstimate {
// build and track the field path
if iterRange, ok := c.iterRanges.peek(identExpr.GetName()); ok {
switch c.checkedExpr.TypeMap[iterRange].GetTypeKind().(type) {
case *exprpb.Type_ListType_:
switch c.checkedAST.TypeMap[iterRange].Kind() {
case types.ListKind:
c.addPath(e, append(c.exprPath[iterRange], "@items"))
case *exprpb.Type_MapType_:
case types.MapKind:
c.addPath(e, append(c.exprPath[iterRange], "@keys"))
}
} else {
@ -378,8 +401,8 @@ func (c *coster) costSelect(e *exprpb.Expr) CostEstimate {
}
sum = sum.Add(c.cost(sel.GetOperand()))
targetType := c.getType(sel.GetOperand())
switch kindOf(targetType) {
case kindMap, kindObject, kindTypeParam:
switch targetType.Kind() {
case types.MapKind, types.StructKind, types.TypeParamKind:
sum = sum.Add(selectAndIdentCost)
}
@ -403,8 +426,8 @@ func (c *coster) costCall(e *exprpb.Expr) CostEstimate {
argTypes[i] = c.newAstNode(arg)
}
ref := c.checkedExpr.ReferenceMap[e.GetId()]
if ref == nil || len(ref.GetOverloadId()) == 0 {
ref := c.checkedAST.ReferenceMap[e.GetId()]
if ref == nil || len(ref.OverloadIDs) == 0 {
return CostEstimate{}
}
var targetType AstNode
@ -417,7 +440,7 @@ func (c *coster) costCall(e *exprpb.Expr) CostEstimate {
// Pick a cost estimate range that covers all the overload cost estimation ranges
fnCost := CostEstimate{Min: uint64(math.MaxUint64), Max: 0}
var resultSize *SizeEstimate
for _, overload := range ref.GetOverloadId() {
for _, overload := range ref.OverloadIDs {
overloadCost := c.functionCost(call.GetFunction(), overload, &targetType, argTypes, argCosts)
fnCost = fnCost.Union(overloadCost.CostEstimate)
if overloadCost.ResultSize != nil {
@ -530,7 +553,14 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args
}
return sum
}
if len(c.overloadEstimators) != 0 {
if estimator, found := c.overloadEstimators[overloadID]; found {
if est := estimator(c.estimator, target, args); est != nil {
callEst := *est
return CallEstimate{CostEstimate: callEst.Add(argCostSum()), ResultSize: est.ResultSize}
}
}
}
if est := c.estimator.EstimateCallCost(function, overloadID, target, args); est != nil {
callEst := *est
return CallEstimate{CostEstimate: callEst.Add(argCostSum()), ResultSize: est.ResultSize}
@ -641,8 +671,8 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args
return CallEstimate{CostEstimate: CostEstimate{Min: 1, Max: 1}.Add(argCostSum())}
}
func (c *coster) getType(e *exprpb.Expr) *exprpb.Type {
return c.checkedExpr.TypeMap[e.GetId()]
func (c *coster) getType(e *exprpb.Expr) *types.Type {
return c.checkedAST.TypeMap[e.GetId()]
}
func (c *coster) getPath(e *exprpb.Expr) []string {
@ -663,22 +693,24 @@ func (c *coster) newAstNode(e *exprpb.Expr) *astNode {
if size, ok := c.computedSizes[e.GetId()]; ok {
derivedSize = &size
}
return &astNode{path: path, t: c.getType(e), expr: e, derivedSize: derivedSize}
return &astNode{
path: path,
t: c.getType(e),
expr: e,
derivedSize: derivedSize}
}
// isScalar returns true if the given type is known to be of a constant size at
// compile time. isScalar will return false for strings (they are variable-width)
// in addition to protobuf.Any and protobuf.Value (their size is not knowable at compile time).
func isScalar(t *exprpb.Type) bool {
switch kindOf(t) {
case kindPrimitive:
if t.GetPrimitive() != exprpb.Type_STRING && t.GetPrimitive() != exprpb.Type_BYTES {
return true
}
case kindWellKnown:
if t.GetWellKnown() == exprpb.Type_DURATION || t.GetWellKnown() == exprpb.Type_TIMESTAMP {
return true
}
func isScalar(t *types.Type) bool {
switch t.Kind() {
case types.BoolKind, types.DoubleKind, types.DurationKind, types.IntKind, types.TimestampKind, types.UintKind:
return true
}
return false
}
var (
doubleTwoTo64 = math.Ldexp(1.0, 64)
)

View File

@ -9,7 +9,6 @@ go_library(
name = "go_default_library",
srcs = [
"decls.go",
"scopes.go",
],
importpath = "github.com/google/cel-go/checker/decls",
deps = [

View File

@ -18,17 +18,11 @@ import (
"fmt"
"strings"
"google.golang.org/protobuf/proto"
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/decls"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/pb"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
type aggregateLiteralElementType int
@ -76,15 +70,15 @@ var (
// which can be used to assist with type-checking.
type Env struct {
container *containers.Container
provider ref.TypeProvider
declarations *decls.Scopes
provider types.Provider
declarations *Scopes
aggLitElemType aggregateLiteralElementType
filteredOverloadIDs map[string]struct{}
}
// NewEnv returns a new *Env with the given parameters.
func NewEnv(container *containers.Container, provider ref.TypeProvider, opts ...Option) (*Env, error) {
declarations := decls.NewScopes()
func NewEnv(container *containers.Container, provider types.Provider, opts ...Option) (*Env, error) {
declarations := newScopes()
declarations.Push()
envOptions := &options{}
@ -113,24 +107,31 @@ func NewEnv(container *containers.Container, provider ref.TypeProvider, opts ...
}, nil
}
// Add adds new Decl protos to the Env.
// Returns an error for identifier redeclarations.
func (e *Env) Add(decls ...*exprpb.Decl) error {
// AddIdents configures the checker with a list of variable declarations.
//
// If there are overlapping declarations, the method will error.
func (e *Env) AddIdents(declarations ...*decls.VariableDecl) error {
errMsgs := make([]errorMsg, 0)
for _, decl := range decls {
switch decl.DeclKind.(type) {
case *exprpb.Decl_Ident:
errMsgs = append(errMsgs, e.addIdent(sanitizeIdent(decl)))
case *exprpb.Decl_Function:
errMsgs = append(errMsgs, e.setFunction(sanitizeFunction(decl))...)
}
for _, d := range declarations {
errMsgs = append(errMsgs, e.addIdent(d))
}
return formatError(errMsgs)
}
// AddFunctions configures the checker with a list of function declarations.
//
// If there are overlapping declarations, the method will error.
func (e *Env) AddFunctions(declarations ...*decls.FunctionDecl) error {
errMsgs := make([]errorMsg, 0)
for _, d := range declarations {
errMsgs = append(errMsgs, e.setFunction(d)...)
}
return formatError(errMsgs)
}
// LookupIdent returns a Decl proto for typeName as an identifier in the Env.
// Returns nil if no such identifier is found in the Env.
func (e *Env) LookupIdent(name string) *exprpb.Decl {
func (e *Env) LookupIdent(name string) *decls.VariableDecl {
for _, candidate := range e.container.ResolveCandidateNames(name) {
if ident := e.declarations.FindIdent(candidate); ident != nil {
return ident
@ -139,8 +140,8 @@ func (e *Env) LookupIdent(name string) *exprpb.Decl {
// Next try to import the name as a reference to a message type. If found,
// the declaration is added to the outest (global) scope of the
// environment, so next time we can access it faster.
if t, found := e.provider.FindType(candidate); found {
decl := decls.NewVar(candidate, t)
if t, found := e.provider.FindStructType(candidate); found {
decl := decls.NewVariable(candidate, t)
e.declarations.AddIdent(decl)
return decl
}
@ -148,11 +149,7 @@ func (e *Env) LookupIdent(name string) *exprpb.Decl {
// Next try to import this as an enum value by splitting the name in a type prefix and
// the enum inside.
if enumValue := e.provider.EnumValue(candidate); enumValue.Type() != types.ErrType {
decl := decls.NewIdent(candidate,
decls.Int,
&exprpb.Constant{
ConstantKind: &exprpb.Constant_Int64Value{
Int64Value: int64(enumValue.(types.Int))}})
decl := decls.NewConstant(candidate, types.IntType, enumValue)
e.declarations.AddIdent(decl)
return decl
}
@ -162,7 +159,7 @@ func (e *Env) LookupIdent(name string) *exprpb.Decl {
// LookupFunction returns a Decl proto for typeName as a function in env.
// Returns nil if no such function is found in env.
func (e *Env) LookupFunction(name string) *exprpb.Decl {
func (e *Env) LookupFunction(name string) *decls.FunctionDecl {
for _, candidate := range e.container.ResolveCandidateNames(name) {
if fn := e.declarations.FindFunction(candidate); fn != nil {
return fn
@ -171,88 +168,46 @@ func (e *Env) LookupFunction(name string) *exprpb.Decl {
return nil
}
// addOverload adds overload to function declaration f.
// Returns one or more errorMsg values if the overload overlaps with an existing overload or macro.
func (e *Env) addOverload(f *exprpb.Decl, overload *exprpb.Decl_FunctionDecl_Overload) []errorMsg {
errMsgs := make([]errorMsg, 0)
function := f.GetFunction()
emptyMappings := newMapping()
overloadFunction := decls.NewFunctionType(overload.GetResultType(),
overload.GetParams()...)
overloadErased := substitute(emptyMappings, overloadFunction, true)
for _, existing := range function.GetOverloads() {
existingFunction := decls.NewFunctionType(existing.GetResultType(), existing.GetParams()...)
existingErased := substitute(emptyMappings, existingFunction, true)
overlap := isAssignable(emptyMappings, overloadErased, existingErased) != nil ||
isAssignable(emptyMappings, existingErased, overloadErased) != nil
if overlap &&
overload.GetIsInstanceFunction() == existing.GetIsInstanceFunction() {
errMsgs = append(errMsgs,
overlappingOverloadError(f.Name,
overload.GetOverloadId(), overloadFunction,
existing.GetOverloadId(), existingFunction))
}
}
for _, macro := range parser.AllMacros {
if macro.Function() == f.Name &&
macro.IsReceiverStyle() == overload.GetIsInstanceFunction() &&
macro.ArgCount() == len(overload.GetParams()) {
errMsgs = append(errMsgs, overlappingMacroError(f.Name, macro.ArgCount()))
}
}
if len(errMsgs) > 0 {
return errMsgs
}
function.Overloads = append(function.GetOverloads(), overload)
return errMsgs
}
// setFunction adds the function Decl to the Env.
// Adds a function decl if one doesn't already exist, then adds all overloads from the Decl.
// If overload overlaps with an existing overload, adds to the errors in the Env instead.
func (e *Env) setFunction(decl *exprpb.Decl) []errorMsg {
errorMsgs := make([]errorMsg, 0)
overloads := decl.GetFunction().GetOverloads()
current := e.declarations.FindFunction(decl.Name)
if current == nil {
//Add the function declaration without overloads and check the overloads below.
current = decls.NewFunction(decl.Name)
} else {
existingOverloads := map[string]*exprpb.Decl_FunctionDecl_Overload{}
for _, overload := range current.GetFunction().GetOverloads() {
existingOverloads[overload.GetOverloadId()] = overload
func (e *Env) setFunction(fn *decls.FunctionDecl) []errorMsg {
errMsgs := make([]errorMsg, 0)
current := e.declarations.FindFunction(fn.Name())
if current != nil {
var err error
current, err = current.Merge(fn)
if err != nil {
return append(errMsgs, errorMsg(err.Error()))
}
newOverloads := []*exprpb.Decl_FunctionDecl_Overload{}
for _, overload := range overloads {
existing, found := existingOverloads[overload.GetOverloadId()]
if !found || !overloadsEqual(existing, overload) {
newOverloads = append(newOverloads, overload)
} else {
current = fn
}
for _, overload := range current.OverloadDecls() {
for _, macro := range parser.AllMacros {
if macro.Function() == current.Name() &&
macro.IsReceiverStyle() == overload.IsMemberFunction() &&
macro.ArgCount() == len(overload.ArgTypes()) {
errMsgs = append(errMsgs, overlappingMacroError(current.Name(), macro.ArgCount()))
}
}
overloads = newOverloads
if len(newOverloads) == 0 {
return errorMsgs
if len(errMsgs) > 0 {
return errMsgs
}
// Copy on write since we don't know where this original definition came from.
current = proto.Clone(current).(*exprpb.Decl)
}
e.declarations.SetFunction(current)
for _, overload := range overloads {
errorMsgs = append(errorMsgs, e.addOverload(current, overload)...)
}
return errorMsgs
return errMsgs
}
// addIdent adds the Decl to the declarations in the Env.
// Returns a non-empty errorMsg if the identifier is already declared in the scope.
func (e *Env) addIdent(decl *exprpb.Decl) errorMsg {
current := e.declarations.FindIdentInScope(decl.Name)
func (e *Env) addIdent(decl *decls.VariableDecl) errorMsg {
current := e.declarations.FindIdentInScope(decl.Name())
if current != nil {
if proto.Equal(current, decl) {
if current.DeclarationIsEquivalent(decl) {
return ""
}
return overlappingIdentifierError(decl.Name)
return overlappingIdentifierError(decl.Name())
}
e.declarations.AddIdent(decl)
return ""
@ -264,111 +219,9 @@ func (e *Env) isOverloadDisabled(overloadID string) bool {
return found
}
// overloadsEqual returns whether two overloads have identical signatures.
//
// type parameter names are ignored as they may be specified in any order and have no bearing on overload
// equivalence
func overloadsEqual(o1, o2 *exprpb.Decl_FunctionDecl_Overload) bool {
return o1.GetOverloadId() == o2.GetOverloadId() &&
o1.GetIsInstanceFunction() == o2.GetIsInstanceFunction() &&
paramsEqual(o1.GetParams(), o2.GetParams()) &&
proto.Equal(o1.GetResultType(), o2.GetResultType())
}
// paramsEqual returns whether two lists have equal length and all types are equal
func paramsEqual(p1, p2 []*exprpb.Type) bool {
if len(p1) != len(p2) {
return false
}
for i, a := range p1 {
b := p2[i]
if !proto.Equal(a, b) {
return false
}
}
return true
}
// sanitizeFunction replaces well-known types referenced by message name with their equivalent
// CEL built-in type instances.
func sanitizeFunction(decl *exprpb.Decl) *exprpb.Decl {
fn := decl.GetFunction()
// Determine whether the declaration requires replacements from proto-based message type
// references to well-known CEL type references.
var needsSanitizing bool
for _, o := range fn.GetOverloads() {
if isObjectWellKnownType(o.GetResultType()) {
needsSanitizing = true
break
}
for _, p := range o.GetParams() {
if isObjectWellKnownType(p) {
needsSanitizing = true
break
}
}
}
// Early return if the declaration requires no modification.
if !needsSanitizing {
return decl
}
// Sanitize all of the overloads if any overload requires an update to its type references.
overloads := make([]*exprpb.Decl_FunctionDecl_Overload, len(fn.GetOverloads()))
for i, o := range fn.GetOverloads() {
rt := o.GetResultType()
if isObjectWellKnownType(rt) {
rt = getObjectWellKnownType(rt)
}
params := make([]*exprpb.Type, len(o.GetParams()))
copy(params, o.GetParams())
for j, p := range params {
if isObjectWellKnownType(p) {
params[j] = getObjectWellKnownType(p)
}
}
// If sanitized, replace the overload definition.
if o.IsInstanceFunction {
overloads[i] =
decls.NewInstanceOverload(o.GetOverloadId(), params, rt)
} else {
overloads[i] =
decls.NewOverload(o.GetOverloadId(), params, rt)
}
}
return decls.NewFunction(decl.GetName(), overloads...)
}
// sanitizeIdent replaces the identifier's well-known types referenced by message name with
// references to CEL built-in type instances.
func sanitizeIdent(decl *exprpb.Decl) *exprpb.Decl {
id := decl.GetIdent()
t := id.GetType()
if !isObjectWellKnownType(t) {
return decl
}
return decls.NewIdent(decl.GetName(), getObjectWellKnownType(t), id.GetValue())
}
// isObjectWellKnownType returns true if the input type is an OBJECT type with a message name
// that corresponds the message name of a built-in CEL type.
func isObjectWellKnownType(t *exprpb.Type) bool {
if kindOf(t) != kindObject {
return false
}
_, found := pb.CheckedWellKnowns[t.GetMessageType()]
return found
}
// getObjectWellKnownType returns the built-in CEL type declaration for input type's message name.
func getObjectWellKnownType(t *exprpb.Type) *exprpb.Type {
return pb.CheckedWellKnowns[t.GetMessageType()]
}
// validatedDeclarations returns a reference to the validated variable and function declaration scope stack.
// must be copied before use.
func (e *Env) validatedDeclarations() *decls.Scopes {
func (e *Env) validatedDeclarations() *Scopes {
return e.declarations
}
@ -402,19 +255,6 @@ func overlappingIdentifierError(name string) errorMsg {
return errorMsg(fmt.Sprintf("overlapping identifier for name '%s'", name))
}
func overlappingOverloadError(name string,
overloadID1 string, f1 *exprpb.Type,
overloadID2 string, f2 *exprpb.Type) errorMsg {
return errorMsg(fmt.Sprintf(
"overlapping overload for name '%s' (type '%s' with overloadId: '%s' "+
"cannot be distinguished from '%s' with overloadId: '%s')",
name,
FormatCheckedType(f1),
overloadID1,
FormatCheckedType(f2),
overloadID2))
}
func overlappingMacroError(name string, argCount int) errorMsg {
return errorMsg(fmt.Sprintf(
"overlapping macro for name '%s' with %d args", name, argCount))

View File

@ -15,82 +15,78 @@
package checker
import (
"reflect"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// typeErrors is a specialization of Errors.
type typeErrors struct {
*common.Errors
errs *common.Errors
}
func (e *typeErrors) undeclaredReference(l common.Location, container string, name string) {
e.ReportError(l, "undeclared reference to '%s' (in container '%s')", name, container)
func (e *typeErrors) fieldTypeMismatch(id int64, l common.Location, name string, field, value *types.Type) {
e.errs.ReportErrorAtID(id, l, "expected type of field '%s' is '%s' but provided type is '%s'",
name, FormatCELType(field), FormatCELType(value))
}
func (e *typeErrors) typeDoesNotSupportFieldSelection(l common.Location, t *exprpb.Type) {
e.ReportError(l, "type '%s' does not support field selection", t)
func (e *typeErrors) incompatibleType(id int64, l common.Location, ex *exprpb.Expr, prev, next *types.Type) {
e.errs.ReportErrorAtID(id, l,
"incompatible type already exists for expression: %v(%d) old:%v, new:%v", ex, ex.GetId(), prev, next)
}
func (e *typeErrors) undefinedField(l common.Location, field string) {
e.ReportError(l, "undefined field '%s'", field)
func (e *typeErrors) noMatchingOverload(id int64, l common.Location, name string, args []*types.Type, isInstance bool) {
signature := formatFunctionDeclType(nil, args, isInstance)
e.errs.ReportErrorAtID(id, l, "found no matching overload for '%s' applied to '%s'", name, signature)
}
func (e *typeErrors) noMatchingOverload(l common.Location, name string, args []*exprpb.Type, isInstance bool) {
signature := formatFunction(nil, args, isInstance)
e.ReportError(l, "found no matching overload for '%s' applied to '%s'", name, signature)
func (e *typeErrors) notAComprehensionRange(id int64, l common.Location, t *types.Type) {
e.errs.ReportErrorAtID(id, l, "expression of type '%s' cannot be range of a comprehension (must be list, map, or dynamic)",
FormatCELType(t))
}
func (e *typeErrors) notAType(l common.Location, t *exprpb.Type) {
e.ReportError(l, "'%s(%v)' is not a type", FormatCheckedType(t), t)
func (e *typeErrors) notAnOptionalFieldSelection(id int64, l common.Location, field *exprpb.Expr) {
e.errs.ReportErrorAtID(id, l, "unsupported optional field selection: %v", field)
}
func (e *typeErrors) notAMessageType(l common.Location, t *exprpb.Type) {
e.ReportError(l, "'%s' is not a message type", FormatCheckedType(t))
func (e *typeErrors) notAType(id int64, l common.Location, typeName string) {
e.errs.ReportErrorAtID(id, l, "'%s' is not a type", typeName)
}
func (e *typeErrors) fieldTypeMismatch(l common.Location, name string, field *exprpb.Type, value *exprpb.Type) {
e.ReportError(l, "expected type of field '%s' is '%s' but provided type is '%s'",
name, FormatCheckedType(field), FormatCheckedType(value))
func (e *typeErrors) notAMessageType(id int64, l common.Location, typeName string) {
e.errs.ReportErrorAtID(id, l, "'%s' is not a message type", typeName)
}
func (e *typeErrors) unexpectedFailedResolution(l common.Location, typeName string) {
e.ReportError(l, "[internal] unexpected failed resolution of '%s'", typeName)
func (e *typeErrors) referenceRedefinition(id int64, l common.Location, ex *exprpb.Expr, prev, next *ast.ReferenceInfo) {
e.errs.ReportErrorAtID(id, l,
"reference already exists for expression: %v(%d) old:%v, new:%v", ex, ex.GetId(), prev, next)
}
func (e *typeErrors) notAComprehensionRange(l common.Location, t *exprpb.Type) {
e.ReportError(l, "expression of type '%s' cannot be range of a comprehension (must be list, map, or dynamic)",
FormatCheckedType(t))
func (e *typeErrors) typeDoesNotSupportFieldSelection(id int64, l common.Location, t *types.Type) {
e.errs.ReportErrorAtID(id, l, "type '%s' does not support field selection", FormatCELType(t))
}
func (e *typeErrors) typeMismatch(l common.Location, expected *exprpb.Type, actual *exprpb.Type) {
e.ReportError(l, "expected type '%s' but found '%s'",
FormatCheckedType(expected), FormatCheckedType(actual))
func (e *typeErrors) typeMismatch(id int64, l common.Location, expected, actual *types.Type) {
e.errs.ReportErrorAtID(id, l, "expected type '%s' but found '%s'",
FormatCELType(expected), FormatCELType(actual))
}
func formatFunction(resultType *exprpb.Type, argTypes []*exprpb.Type, isInstance bool) string {
result := ""
if isInstance {
target := argTypes[0]
argTypes = argTypes[1:]
result += FormatCheckedType(target)
result += "."
}
result += "("
for i, arg := range argTypes {
if i > 0 {
result += ", "
}
result += FormatCheckedType(arg)
}
result += ")"
if resultType != nil {
result += " -> "
result += FormatCheckedType(resultType)
}
return result
func (e *typeErrors) undefinedField(id int64, l common.Location, field string) {
e.errs.ReportErrorAtID(id, l, "undefined field '%s'", field)
}
func (e *typeErrors) undeclaredReference(id int64, l common.Location, container string, name string) {
e.errs.ReportErrorAtID(id, l, "undeclared reference to '%s' (in container '%s')", name, container)
}
func (e *typeErrors) unexpectedFailedResolution(id int64, l common.Location, typeName string) {
e.errs.ReportErrorAtID(id, l, "unexpected failed resolution of '%s'", typeName)
}
func (e *typeErrors) unexpectedASTType(id int64, l common.Location, ex *exprpb.Expr) {
e.errs.ReportErrorAtID(id, l, "unrecognized ast type: %v", reflect.TypeOf(ex))
}

216
vendor/github.com/google/cel-go/checker/format.go generated vendored Normal file
View File

@ -0,0 +1,216 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package checker
import (
"fmt"
"strings"
chkdecls "github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/types"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
const (
kindUnknown = iota + 1
kindError
kindFunction
kindDyn
kindPrimitive
kindWellKnown
kindWrapper
kindNull
kindAbstract
kindType
kindList
kindMap
kindObject
kindTypeParam
)
// FormatCheckedType converts a type message into a string representation.
func FormatCheckedType(t *exprpb.Type) string {
switch kindOf(t) {
case kindDyn:
return "dyn"
case kindFunction:
return formatFunctionExprType(t.GetFunction().GetResultType(),
t.GetFunction().GetArgTypes(),
false)
case kindList:
return fmt.Sprintf("list(%s)", FormatCheckedType(t.GetListType().GetElemType()))
case kindObject:
return t.GetMessageType()
case kindMap:
return fmt.Sprintf("map(%s, %s)",
FormatCheckedType(t.GetMapType().GetKeyType()),
FormatCheckedType(t.GetMapType().GetValueType()))
case kindNull:
return "null"
case kindPrimitive:
switch t.GetPrimitive() {
case exprpb.Type_UINT64:
return "uint"
case exprpb.Type_INT64:
return "int"
}
return strings.Trim(strings.ToLower(t.GetPrimitive().String()), " ")
case kindType:
if t.GetType() == nil || t.GetType().GetTypeKind() == nil {
return "type"
}
return fmt.Sprintf("type(%s)", FormatCheckedType(t.GetType()))
case kindWellKnown:
switch t.GetWellKnown() {
case exprpb.Type_ANY:
return "any"
case exprpb.Type_DURATION:
return "duration"
case exprpb.Type_TIMESTAMP:
return "timestamp"
}
case kindWrapper:
return fmt.Sprintf("wrapper(%s)",
FormatCheckedType(chkdecls.NewPrimitiveType(t.GetWrapper())))
case kindError:
return "!error!"
case kindTypeParam:
return t.GetTypeParam()
case kindAbstract:
at := t.GetAbstractType()
params := at.GetParameterTypes()
paramStrs := make([]string, len(params))
for i, p := range params {
paramStrs[i] = FormatCheckedType(p)
}
return fmt.Sprintf("%s(%s)", at.GetName(), strings.Join(paramStrs, ", "))
}
return t.String()
}
type formatter func(any) string
// FormatCELType formats a types.Type value to a string representation.
//
// The type formatting is identical to FormatCheckedType.
func FormatCELType(t any) string {
dt := t.(*types.Type)
switch dt.Kind() {
case types.AnyKind:
return "any"
case types.DurationKind:
return "duration"
case types.ErrorKind:
return "!error!"
case types.NullTypeKind:
return "null"
case types.TimestampKind:
return "timestamp"
case types.TypeParamKind:
return dt.TypeName()
case types.OpaqueKind:
if dt.TypeName() == "function" {
// There is no explicit function type in the new types representation, so information like
// whether the function is a member function is absent.
return formatFunctionDeclType(dt.Parameters()[0], dt.Parameters()[1:], false)
}
case types.UnspecifiedKind:
return ""
}
if len(dt.Parameters()) == 0 {
return dt.DeclaredTypeName()
}
paramTypeNames := make([]string, 0, len(dt.Parameters()))
for _, p := range dt.Parameters() {
paramTypeNames = append(paramTypeNames, FormatCELType(p))
}
return fmt.Sprintf("%s(%s)", dt.TypeName(), strings.Join(paramTypeNames, ", "))
}
func formatExprType(t any) string {
if t == nil {
return ""
}
return FormatCheckedType(t.(*exprpb.Type))
}
func formatFunctionExprType(resultType *exprpb.Type, argTypes []*exprpb.Type, isInstance bool) string {
return formatFunctionInternal[*exprpb.Type](resultType, argTypes, isInstance, formatExprType)
}
func formatFunctionDeclType(resultType *types.Type, argTypes []*types.Type, isInstance bool) string {
return formatFunctionInternal[*types.Type](resultType, argTypes, isInstance, FormatCELType)
}
func formatFunctionInternal[T any](resultType T, argTypes []T, isInstance bool, format formatter) string {
result := ""
if isInstance {
target := argTypes[0]
argTypes = argTypes[1:]
result += format(target)
result += "."
}
result += "("
for i, arg := range argTypes {
if i > 0 {
result += ", "
}
result += format(arg)
}
result += ")"
rt := format(resultType)
if rt != "" {
result += " -> "
result += rt
}
return result
}
// kindOf returns the kind of the type as defined in the checked.proto.
func kindOf(t *exprpb.Type) int {
if t == nil || t.TypeKind == nil {
return kindUnknown
}
switch t.GetTypeKind().(type) {
case *exprpb.Type_Error:
return kindError
case *exprpb.Type_Function:
return kindFunction
case *exprpb.Type_Dyn:
return kindDyn
case *exprpb.Type_Primitive:
return kindPrimitive
case *exprpb.Type_WellKnown:
return kindWellKnown
case *exprpb.Type_Wrapper:
return kindWrapper
case *exprpb.Type_Null:
return kindNull
case *exprpb.Type_Type:
return kindType
case *exprpb.Type_ListType_:
return kindList
case *exprpb.Type_MapType_:
return kindMap
case *exprpb.Type_MessageType:
return kindObject
case *exprpb.Type_TypeParam:
return kindTypeParam
case *exprpb.Type_AbstractType_:
return kindAbstract
}
return kindUnknown
}

View File

@ -15,25 +15,25 @@
package checker
import (
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"github.com/google/cel-go/common/types"
)
type mapping struct {
mapping map[string]*exprpb.Type
mapping map[string]*types.Type
}
func newMapping() *mapping {
return &mapping{
mapping: make(map[string]*exprpb.Type),
mapping: make(map[string]*types.Type),
}
}
func (m *mapping) add(from *exprpb.Type, to *exprpb.Type) {
m.mapping[typeKey(from)] = to
func (m *mapping) add(from, to *types.Type) {
m.mapping[FormatCELType(from)] = to
}
func (m *mapping) find(from *exprpb.Type) (*exprpb.Type, bool) {
if r, found := m.mapping[typeKey(from)]; found {
func (m *mapping) find(from *types.Type) (*types.Type, bool) {
if r, found := m.mapping[FormatCELType(from)]; found {
return r, found
}
return nil, false

View File

@ -14,12 +14,10 @@
package checker
import "github.com/google/cel-go/checker/decls"
type options struct {
crossTypeNumericComparisons bool
homogeneousAggregateLiterals bool
validatedDeclarations *decls.Scopes
validatedDeclarations *Scopes
}
// Option is a functional option for configuring the type-checker
@ -34,15 +32,6 @@ func CrossTypeNumericComparisons(enabled bool) Option {
}
}
// HomogeneousAggregateLiterals toggles support for constructing lists and maps whose elements all
// have the same type.
func HomogeneousAggregateLiterals(enabled bool) Option {
return func(opts *options) error {
opts.homogeneousAggregateLiterals = enabled
return nil
}
}
// ValidatedDeclarations provides a references to validated declarations which will be copied
// into new checker instances.
func ValidatedDeclarations(env *Env) Option {

View File

@ -15,6 +15,8 @@
package checker
import (
"sort"
"github.com/google/cel-go/common/debug"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
@ -47,6 +49,7 @@ func (a *semanticAdorner) GetMetadata(elem any) string {
if len(ref.GetOverloadId()) == 0 {
result += "^" + ref.Name
} else {
sort.Strings(ref.GetOverloadId())
for i, overload := range ref.GetOverloadId() {
if i == 0 {
result += "^"

View File

@ -12,9 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package decls
package checker
import exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
import (
"github.com/google/cel-go/common/decls"
)
// Scopes represents nested Decl sets where the Scopes value contains a Groups containing all
// identifiers in scope and an optional parent representing outer scopes.
@ -25,9 +27,9 @@ type Scopes struct {
scopes *Group
}
// NewScopes creates a new, empty Scopes.
// newScopes creates a new, empty Scopes.
// Some operations can't be safely performed until a Group is added with Push.
func NewScopes() *Scopes {
func newScopes() *Scopes {
return &Scopes{
scopes: newGroup(),
}
@ -35,7 +37,7 @@ func NewScopes() *Scopes {
// Copy creates a copy of the current Scopes values, including a copy of its parent if non-nil.
func (s *Scopes) Copy() *Scopes {
cpy := NewScopes()
cpy := newScopes()
if s == nil {
return cpy
}
@ -66,14 +68,14 @@ func (s *Scopes) Pop() *Scopes {
// AddIdent adds the ident Decl in the current scope.
// Note: If the name collides with an existing identifier in the scope, the Decl is overwritten.
func (s *Scopes) AddIdent(decl *exprpb.Decl) {
s.scopes.idents[decl.Name] = decl
func (s *Scopes) AddIdent(decl *decls.VariableDecl) {
s.scopes.idents[decl.Name()] = decl
}
// FindIdent finds the first ident Decl with a matching name in Scopes, or nil if one cannot be
// found.
// Note: The search is performed from innermost to outermost.
func (s *Scopes) FindIdent(name string) *exprpb.Decl {
func (s *Scopes) FindIdent(name string) *decls.VariableDecl {
if ident, found := s.scopes.idents[name]; found {
return ident
}
@ -86,7 +88,7 @@ func (s *Scopes) FindIdent(name string) *exprpb.Decl {
// FindIdentInScope finds the first ident Decl with a matching name in the current Scopes value, or
// nil if one does not exist.
// Note: The search is only performed on the current scope and does not search outer scopes.
func (s *Scopes) FindIdentInScope(name string) *exprpb.Decl {
func (s *Scopes) FindIdentInScope(name string) *decls.VariableDecl {
if ident, found := s.scopes.idents[name]; found {
return ident
}
@ -95,14 +97,14 @@ func (s *Scopes) FindIdentInScope(name string) *exprpb.Decl {
// SetFunction adds the function Decl to the current scope.
// Note: Any previous entry for a function in the current scope with the same name is overwritten.
func (s *Scopes) SetFunction(fn *exprpb.Decl) {
s.scopes.functions[fn.Name] = fn
func (s *Scopes) SetFunction(fn *decls.FunctionDecl) {
s.scopes.functions[fn.Name()] = fn
}
// FindFunction finds the first function Decl with a matching name in Scopes.
// The search is performed from innermost to outermost.
// Returns nil if no such function in Scopes.
func (s *Scopes) FindFunction(name string) *exprpb.Decl {
func (s *Scopes) FindFunction(name string) *decls.FunctionDecl {
if fn, found := s.scopes.functions[name]; found {
return fn
}
@ -116,16 +118,16 @@ func (s *Scopes) FindFunction(name string) *exprpb.Decl {
// Contains separate namespaces for identifier and function Decls.
// (Should be named "Scope" perhaps?)
type Group struct {
idents map[string]*exprpb.Decl
functions map[string]*exprpb.Decl
idents map[string]*decls.VariableDecl
functions map[string]*decls.FunctionDecl
}
// copy creates a new Group instance with a shallow copy of the variables and functions.
// If callers need to mutate the exprpb.Decl definitions for a Function, they should copy-on-write.
func (g *Group) copy() *Group {
cpy := &Group{
idents: make(map[string]*exprpb.Decl, len(g.idents)),
functions: make(map[string]*exprpb.Decl, len(g.functions)),
idents: make(map[string]*decls.VariableDecl, len(g.idents)),
functions: make(map[string]*decls.FunctionDecl, len(g.functions)),
}
for n, id := range g.idents {
cpy.idents[n] = id
@ -139,7 +141,7 @@ func (g *Group) copy() *Group {
// newGroup creates a new Group with empty maps for identifiers and functions.
func newGroup() *Group {
return &Group{
idents: make(map[string]*exprpb.Decl),
functions: make(map[string]*exprpb.Decl),
idents: make(map[string]*decls.VariableDecl),
functions: make(map[string]*decls.FunctionDecl),
}
}

View File

@ -15,480 +15,21 @@
package checker
import (
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/stdlib"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
var (
standardDeclarations []*exprpb.Decl
)
func init() {
// Some shortcuts we use when building declarations.
paramA := decls.NewTypeParamType("A")
typeParamAList := []string{"A"}
listOfA := decls.NewListType(paramA)
paramB := decls.NewTypeParamType("B")
typeParamABList := []string{"A", "B"}
mapOfAB := decls.NewMapType(paramA, paramB)
var idents []*exprpb.Decl
for _, t := range []*exprpb.Type{
decls.Int, decls.Uint, decls.Bool,
decls.Double, decls.Bytes, decls.String} {
idents = append(idents,
decls.NewVar(FormatCheckedType(t), decls.NewTypeType(t)))
}
idents = append(idents,
decls.NewVar("list", decls.NewTypeType(listOfA)),
decls.NewVar("map", decls.NewTypeType(mapOfAB)),
decls.NewVar("null_type", decls.NewTypeType(decls.Null)),
decls.NewVar("type", decls.NewTypeType(decls.NewTypeType(nil))))
standardDeclarations = append(standardDeclarations, idents...)
standardDeclarations = append(standardDeclarations, []*exprpb.Decl{
// Booleans
decls.NewFunction(operators.Conditional,
decls.NewParameterizedOverload(overloads.Conditional,
[]*exprpb.Type{decls.Bool, paramA, paramA}, paramA,
typeParamAList)),
decls.NewFunction(operators.LogicalAnd,
decls.NewOverload(overloads.LogicalAnd,
[]*exprpb.Type{decls.Bool, decls.Bool}, decls.Bool)),
decls.NewFunction(operators.LogicalOr,
decls.NewOverload(overloads.LogicalOr,
[]*exprpb.Type{decls.Bool, decls.Bool}, decls.Bool)),
decls.NewFunction(operators.LogicalNot,
decls.NewOverload(overloads.LogicalNot,
[]*exprpb.Type{decls.Bool}, decls.Bool)),
decls.NewFunction(operators.NotStrictlyFalse,
decls.NewOverload(overloads.NotStrictlyFalse,
[]*exprpb.Type{decls.Bool}, decls.Bool)),
decls.NewFunction(operators.Equals,
decls.NewParameterizedOverload(overloads.Equals,
[]*exprpb.Type{paramA, paramA}, decls.Bool,
typeParamAList)),
decls.NewFunction(operators.NotEquals,
decls.NewParameterizedOverload(overloads.NotEquals,
[]*exprpb.Type{paramA, paramA}, decls.Bool,
typeParamAList)),
// Algebra.
decls.NewFunction(operators.Subtract,
decls.NewOverload(overloads.SubtractInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Int),
decls.NewOverload(overloads.SubtractUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Uint),
decls.NewOverload(overloads.SubtractDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Double),
decls.NewOverload(overloads.SubtractTimestampTimestamp,
[]*exprpb.Type{decls.Timestamp, decls.Timestamp}, decls.Duration),
decls.NewOverload(overloads.SubtractTimestampDuration,
[]*exprpb.Type{decls.Timestamp, decls.Duration}, decls.Timestamp),
decls.NewOverload(overloads.SubtractDurationDuration,
[]*exprpb.Type{decls.Duration, decls.Duration}, decls.Duration)),
decls.NewFunction(operators.Multiply,
decls.NewOverload(overloads.MultiplyInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Int),
decls.NewOverload(overloads.MultiplyUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Uint),
decls.NewOverload(overloads.MultiplyDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Double)),
decls.NewFunction(operators.Divide,
decls.NewOverload(overloads.DivideInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Int),
decls.NewOverload(overloads.DivideUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Uint),
decls.NewOverload(overloads.DivideDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Double)),
decls.NewFunction(operators.Modulo,
decls.NewOverload(overloads.ModuloInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Int),
decls.NewOverload(overloads.ModuloUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Uint)),
decls.NewFunction(operators.Add,
decls.NewOverload(overloads.AddInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Int),
decls.NewOverload(overloads.AddUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Uint),
decls.NewOverload(overloads.AddDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Double),
decls.NewOverload(overloads.AddString,
[]*exprpb.Type{decls.String, decls.String}, decls.String),
decls.NewOverload(overloads.AddBytes,
[]*exprpb.Type{decls.Bytes, decls.Bytes}, decls.Bytes),
decls.NewParameterizedOverload(overloads.AddList,
[]*exprpb.Type{listOfA, listOfA}, listOfA,
typeParamAList),
decls.NewOverload(overloads.AddTimestampDuration,
[]*exprpb.Type{decls.Timestamp, decls.Duration}, decls.Timestamp),
decls.NewOverload(overloads.AddDurationTimestamp,
[]*exprpb.Type{decls.Duration, decls.Timestamp}, decls.Timestamp),
decls.NewOverload(overloads.AddDurationDuration,
[]*exprpb.Type{decls.Duration, decls.Duration}, decls.Duration)),
decls.NewFunction(operators.Negate,
decls.NewOverload(overloads.NegateInt64,
[]*exprpb.Type{decls.Int}, decls.Int),
decls.NewOverload(overloads.NegateDouble,
[]*exprpb.Type{decls.Double}, decls.Double)),
// Index.
decls.NewFunction(operators.Index,
decls.NewParameterizedOverload(overloads.IndexList,
[]*exprpb.Type{listOfA, decls.Int}, paramA,
typeParamAList),
decls.NewParameterizedOverload(overloads.IndexMap,
[]*exprpb.Type{mapOfAB, paramA}, paramB,
typeParamABList)),
// Collections.
decls.NewFunction(overloads.Size,
decls.NewInstanceOverload(overloads.SizeStringInst,
[]*exprpb.Type{decls.String}, decls.Int),
decls.NewInstanceOverload(overloads.SizeBytesInst,
[]*exprpb.Type{decls.Bytes}, decls.Int),
decls.NewParameterizedInstanceOverload(overloads.SizeListInst,
[]*exprpb.Type{listOfA}, decls.Int, typeParamAList),
decls.NewParameterizedInstanceOverload(overloads.SizeMapInst,
[]*exprpb.Type{mapOfAB}, decls.Int, typeParamABList),
decls.NewOverload(overloads.SizeString,
[]*exprpb.Type{decls.String}, decls.Int),
decls.NewOverload(overloads.SizeBytes,
[]*exprpb.Type{decls.Bytes}, decls.Int),
decls.NewParameterizedOverload(overloads.SizeList,
[]*exprpb.Type{listOfA}, decls.Int, typeParamAList),
decls.NewParameterizedOverload(overloads.SizeMap,
[]*exprpb.Type{mapOfAB}, decls.Int, typeParamABList)),
decls.NewFunction(operators.In,
decls.NewParameterizedOverload(overloads.InList,
[]*exprpb.Type{paramA, listOfA}, decls.Bool,
typeParamAList),
decls.NewParameterizedOverload(overloads.InMap,
[]*exprpb.Type{paramA, mapOfAB}, decls.Bool,
typeParamABList)),
// Deprecated 'in()' function.
decls.NewFunction(overloads.DeprecatedIn,
decls.NewParameterizedOverload(overloads.InList,
[]*exprpb.Type{paramA, listOfA}, decls.Bool,
typeParamAList),
decls.NewParameterizedOverload(overloads.InMap,
[]*exprpb.Type{paramA, mapOfAB}, decls.Bool,
typeParamABList)),
// Conversions to type.
decls.NewFunction(overloads.TypeConvertType,
decls.NewParameterizedOverload(overloads.TypeConvertType,
[]*exprpb.Type{paramA}, decls.NewTypeType(paramA), typeParamAList)),
// Conversions to int.
decls.NewFunction(overloads.TypeConvertInt,
decls.NewOverload(overloads.IntToInt, []*exprpb.Type{decls.Int}, decls.Int),
decls.NewOverload(overloads.UintToInt, []*exprpb.Type{decls.Uint}, decls.Int),
decls.NewOverload(overloads.DoubleToInt, []*exprpb.Type{decls.Double}, decls.Int),
decls.NewOverload(overloads.StringToInt, []*exprpb.Type{decls.String}, decls.Int),
decls.NewOverload(overloads.TimestampToInt, []*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewOverload(overloads.DurationToInt, []*exprpb.Type{decls.Duration}, decls.Int)),
// Conversions to uint.
decls.NewFunction(overloads.TypeConvertUint,
decls.NewOverload(overloads.UintToUint, []*exprpb.Type{decls.Uint}, decls.Uint),
decls.NewOverload(overloads.IntToUint, []*exprpb.Type{decls.Int}, decls.Uint),
decls.NewOverload(overloads.DoubleToUint, []*exprpb.Type{decls.Double}, decls.Uint),
decls.NewOverload(overloads.StringToUint, []*exprpb.Type{decls.String}, decls.Uint)),
// Conversions to double.
decls.NewFunction(overloads.TypeConvertDouble,
decls.NewOverload(overloads.DoubleToDouble, []*exprpb.Type{decls.Double}, decls.Double),
decls.NewOverload(overloads.IntToDouble, []*exprpb.Type{decls.Int}, decls.Double),
decls.NewOverload(overloads.UintToDouble, []*exprpb.Type{decls.Uint}, decls.Double),
decls.NewOverload(overloads.StringToDouble, []*exprpb.Type{decls.String}, decls.Double)),
// Conversions to bool.
decls.NewFunction(overloads.TypeConvertBool,
decls.NewOverload(overloads.BoolToBool, []*exprpb.Type{decls.Bool}, decls.Bool),
decls.NewOverload(overloads.StringToBool, []*exprpb.Type{decls.String}, decls.Bool)),
// Conversions to string.
decls.NewFunction(overloads.TypeConvertString,
decls.NewOverload(overloads.StringToString, []*exprpb.Type{decls.String}, decls.String),
decls.NewOverload(overloads.BoolToString, []*exprpb.Type{decls.Bool}, decls.String),
decls.NewOverload(overloads.IntToString, []*exprpb.Type{decls.Int}, decls.String),
decls.NewOverload(overloads.UintToString, []*exprpb.Type{decls.Uint}, decls.String),
decls.NewOverload(overloads.DoubleToString, []*exprpb.Type{decls.Double}, decls.String),
decls.NewOverload(overloads.BytesToString, []*exprpb.Type{decls.Bytes}, decls.String),
decls.NewOverload(overloads.TimestampToString, []*exprpb.Type{decls.Timestamp}, decls.String),
decls.NewOverload(overloads.DurationToString, []*exprpb.Type{decls.Duration}, decls.String)),
// Conversions to bytes.
decls.NewFunction(overloads.TypeConvertBytes,
decls.NewOverload(overloads.BytesToBytes, []*exprpb.Type{decls.Bytes}, decls.Bytes),
decls.NewOverload(overloads.StringToBytes, []*exprpb.Type{decls.String}, decls.Bytes)),
// Conversions to timestamps.
decls.NewFunction(overloads.TypeConvertTimestamp,
decls.NewOverload(overloads.TimestampToTimestamp,
[]*exprpb.Type{decls.Timestamp}, decls.Timestamp),
decls.NewOverload(overloads.StringToTimestamp,
[]*exprpb.Type{decls.String}, decls.Timestamp),
decls.NewOverload(overloads.IntToTimestamp,
[]*exprpb.Type{decls.Int}, decls.Timestamp)),
// Conversions to durations.
decls.NewFunction(overloads.TypeConvertDuration,
decls.NewOverload(overloads.DurationToDuration,
[]*exprpb.Type{decls.Duration}, decls.Duration),
decls.NewOverload(overloads.StringToDuration,
[]*exprpb.Type{decls.String}, decls.Duration),
decls.NewOverload(overloads.IntToDuration,
[]*exprpb.Type{decls.Int}, decls.Duration)),
// Conversions to Dyn.
decls.NewFunction(overloads.TypeConvertDyn,
decls.NewParameterizedOverload(overloads.ToDyn,
[]*exprpb.Type{paramA}, decls.Dyn,
typeParamAList)),
// String functions.
decls.NewFunction(overloads.Contains,
decls.NewInstanceOverload(overloads.ContainsString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool)),
decls.NewFunction(overloads.EndsWith,
decls.NewInstanceOverload(overloads.EndsWithString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool)),
decls.NewFunction(overloads.Matches,
decls.NewOverload(overloads.Matches,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool),
decls.NewInstanceOverload(overloads.MatchesString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool)),
decls.NewFunction(overloads.StartsWith,
decls.NewInstanceOverload(overloads.StartsWithString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool)),
// Date/time functions.
decls.NewFunction(overloads.TimeGetFullYear,
decls.NewInstanceOverload(overloads.TimestampToYear,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToYearWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int)),
decls.NewFunction(overloads.TimeGetMonth,
decls.NewInstanceOverload(overloads.TimestampToMonth,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToMonthWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int)),
decls.NewFunction(overloads.TimeGetDayOfYear,
decls.NewInstanceOverload(overloads.TimestampToDayOfYear,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToDayOfYearWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int)),
decls.NewFunction(overloads.TimeGetDayOfMonth,
decls.NewInstanceOverload(overloads.TimestampToDayOfMonthZeroBased,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToDayOfMonthZeroBasedWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int)),
decls.NewFunction(overloads.TimeGetDate,
decls.NewInstanceOverload(overloads.TimestampToDayOfMonthOneBased,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToDayOfMonthOneBasedWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int)),
decls.NewFunction(overloads.TimeGetDayOfWeek,
decls.NewInstanceOverload(overloads.TimestampToDayOfWeek,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToDayOfWeekWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int)),
decls.NewFunction(overloads.TimeGetHours,
decls.NewInstanceOverload(overloads.TimestampToHours,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToHoursWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int),
decls.NewInstanceOverload(overloads.DurationToHours,
[]*exprpb.Type{decls.Duration}, decls.Int)),
decls.NewFunction(overloads.TimeGetMinutes,
decls.NewInstanceOverload(overloads.TimestampToMinutes,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToMinutesWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int),
decls.NewInstanceOverload(overloads.DurationToMinutes,
[]*exprpb.Type{decls.Duration}, decls.Int)),
decls.NewFunction(overloads.TimeGetSeconds,
decls.NewInstanceOverload(overloads.TimestampToSeconds,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToSecondsWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int),
decls.NewInstanceOverload(overloads.DurationToSeconds,
[]*exprpb.Type{decls.Duration}, decls.Int)),
decls.NewFunction(overloads.TimeGetMilliseconds,
decls.NewInstanceOverload(overloads.TimestampToMilliseconds,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToMillisecondsWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int),
decls.NewInstanceOverload(overloads.DurationToMilliseconds,
[]*exprpb.Type{decls.Duration}, decls.Int)),
// Relations.
decls.NewFunction(operators.Less,
decls.NewOverload(overloads.LessBool,
[]*exprpb.Type{decls.Bool, decls.Bool}, decls.Bool),
decls.NewOverload(overloads.LessInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Bool),
decls.NewOverload(overloads.LessInt64Double,
[]*exprpb.Type{decls.Int, decls.Double}, decls.Bool),
decls.NewOverload(overloads.LessInt64Uint64,
[]*exprpb.Type{decls.Int, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.LessUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.LessUint64Double,
[]*exprpb.Type{decls.Uint, decls.Double}, decls.Bool),
decls.NewOverload(overloads.LessUint64Int64,
[]*exprpb.Type{decls.Uint, decls.Int}, decls.Bool),
decls.NewOverload(overloads.LessDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Bool),
decls.NewOverload(overloads.LessDoubleInt64,
[]*exprpb.Type{decls.Double, decls.Int}, decls.Bool),
decls.NewOverload(overloads.LessDoubleUint64,
[]*exprpb.Type{decls.Double, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.LessString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool),
decls.NewOverload(overloads.LessBytes,
[]*exprpb.Type{decls.Bytes, decls.Bytes}, decls.Bool),
decls.NewOverload(overloads.LessTimestamp,
[]*exprpb.Type{decls.Timestamp, decls.Timestamp}, decls.Bool),
decls.NewOverload(overloads.LessDuration,
[]*exprpb.Type{decls.Duration, decls.Duration}, decls.Bool)),
decls.NewFunction(operators.LessEquals,
decls.NewOverload(overloads.LessEqualsBool,
[]*exprpb.Type{decls.Bool, decls.Bool}, decls.Bool),
decls.NewOverload(overloads.LessEqualsInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Bool),
decls.NewOverload(overloads.LessEqualsInt64Double,
[]*exprpb.Type{decls.Int, decls.Double}, decls.Bool),
decls.NewOverload(overloads.LessEqualsInt64Uint64,
[]*exprpb.Type{decls.Int, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.LessEqualsUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.LessEqualsUint64Double,
[]*exprpb.Type{decls.Uint, decls.Double}, decls.Bool),
decls.NewOverload(overloads.LessEqualsUint64Int64,
[]*exprpb.Type{decls.Uint, decls.Int}, decls.Bool),
decls.NewOverload(overloads.LessEqualsDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Bool),
decls.NewOverload(overloads.LessEqualsDoubleInt64,
[]*exprpb.Type{decls.Double, decls.Int}, decls.Bool),
decls.NewOverload(overloads.LessEqualsDoubleUint64,
[]*exprpb.Type{decls.Double, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.LessEqualsString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool),
decls.NewOverload(overloads.LessEqualsBytes,
[]*exprpb.Type{decls.Bytes, decls.Bytes}, decls.Bool),
decls.NewOverload(overloads.LessEqualsTimestamp,
[]*exprpb.Type{decls.Timestamp, decls.Timestamp}, decls.Bool),
decls.NewOverload(overloads.LessEqualsDuration,
[]*exprpb.Type{decls.Duration, decls.Duration}, decls.Bool)),
decls.NewFunction(operators.Greater,
decls.NewOverload(overloads.GreaterBool,
[]*exprpb.Type{decls.Bool, decls.Bool}, decls.Bool),
decls.NewOverload(overloads.GreaterInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Bool),
decls.NewOverload(overloads.GreaterInt64Double,
[]*exprpb.Type{decls.Int, decls.Double}, decls.Bool),
decls.NewOverload(overloads.GreaterInt64Uint64,
[]*exprpb.Type{decls.Int, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.GreaterUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.GreaterUint64Double,
[]*exprpb.Type{decls.Uint, decls.Double}, decls.Bool),
decls.NewOverload(overloads.GreaterUint64Int64,
[]*exprpb.Type{decls.Uint, decls.Int}, decls.Bool),
decls.NewOverload(overloads.GreaterDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Bool),
decls.NewOverload(overloads.GreaterDoubleInt64,
[]*exprpb.Type{decls.Double, decls.Int}, decls.Bool),
decls.NewOverload(overloads.GreaterDoubleUint64,
[]*exprpb.Type{decls.Double, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.GreaterString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool),
decls.NewOverload(overloads.GreaterBytes,
[]*exprpb.Type{decls.Bytes, decls.Bytes}, decls.Bool),
decls.NewOverload(overloads.GreaterTimestamp,
[]*exprpb.Type{decls.Timestamp, decls.Timestamp}, decls.Bool),
decls.NewOverload(overloads.GreaterDuration,
[]*exprpb.Type{decls.Duration, decls.Duration}, decls.Bool)),
decls.NewFunction(operators.GreaterEquals,
decls.NewOverload(overloads.GreaterEqualsBool,
[]*exprpb.Type{decls.Bool, decls.Bool}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsInt64Double,
[]*exprpb.Type{decls.Int, decls.Double}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsInt64Uint64,
[]*exprpb.Type{decls.Int, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsUint64Double,
[]*exprpb.Type{decls.Uint, decls.Double}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsUint64Int64,
[]*exprpb.Type{decls.Uint, decls.Int}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsDoubleInt64,
[]*exprpb.Type{decls.Double, decls.Int}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsDoubleUint64,
[]*exprpb.Type{decls.Double, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsBytes,
[]*exprpb.Type{decls.Bytes, decls.Bytes}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsTimestamp,
[]*exprpb.Type{decls.Timestamp, decls.Timestamp}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsDuration,
[]*exprpb.Type{decls.Duration, decls.Duration}, decls.Bool)),
}...)
// StandardFunctions returns the Decls for all functions in the evaluator.
//
// Deprecated: prefer stdlib.FunctionExprDecls()
func StandardFunctions() []*exprpb.Decl {
return stdlib.FunctionExprDecls()
}
// StandardDeclarations returns the Decls for all functions and constants in the evaluator.
func StandardDeclarations() []*exprpb.Decl {
return standardDeclarations
// StandardTypes returns the set of type identifiers for standard library types.
//
// Deprecated: prefer stdlib.TypeExprDecls()
func StandardTypes() []*exprpb.Decl {
return stdlib.TypeExprDecls()
}

View File

@ -15,154 +15,54 @@
package checker
import (
"fmt"
"strings"
"github.com/google/cel-go/checker/decls"
"google.golang.org/protobuf/proto"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"github.com/google/cel-go/common/types"
)
const (
kindUnknown = iota + 1
kindError
kindFunction
kindDyn
kindPrimitive
kindWellKnown
kindWrapper
kindNull
kindAbstract
kindType
kindList
kindMap
kindObject
kindTypeParam
)
// FormatCheckedType converts a type message into a string representation.
func FormatCheckedType(t *exprpb.Type) string {
switch kindOf(t) {
case kindDyn:
return "dyn"
case kindFunction:
return formatFunction(t.GetFunction().GetResultType(),
t.GetFunction().GetArgTypes(),
false)
case kindList:
return fmt.Sprintf("list(%s)", FormatCheckedType(t.GetListType().GetElemType()))
case kindObject:
return t.GetMessageType()
case kindMap:
return fmt.Sprintf("map(%s, %s)",
FormatCheckedType(t.GetMapType().GetKeyType()),
FormatCheckedType(t.GetMapType().GetValueType()))
case kindNull:
return "null"
case kindPrimitive:
switch t.GetPrimitive() {
case exprpb.Type_UINT64:
return "uint"
case exprpb.Type_INT64:
return "int"
}
return strings.Trim(strings.ToLower(t.GetPrimitive().String()), " ")
case kindType:
if t.GetType() == nil {
return "type"
}
return fmt.Sprintf("type(%s)", FormatCheckedType(t.GetType()))
case kindWellKnown:
switch t.GetWellKnown() {
case exprpb.Type_ANY:
return "any"
case exprpb.Type_DURATION:
return "duration"
case exprpb.Type_TIMESTAMP:
return "timestamp"
}
case kindWrapper:
return fmt.Sprintf("wrapper(%s)",
FormatCheckedType(decls.NewPrimitiveType(t.GetWrapper())))
case kindError:
return "!error!"
case kindTypeParam:
return t.GetTypeParam()
case kindAbstract:
at := t.GetAbstractType()
params := at.GetParameterTypes()
paramStrs := make([]string, len(params))
for i, p := range params {
paramStrs[i] = FormatCheckedType(p)
}
return fmt.Sprintf("%s(%s)", at.GetName(), strings.Join(paramStrs, ", "))
}
return t.String()
}
// isDyn returns true if the input t is either type DYN or a well-known ANY message.
func isDyn(t *exprpb.Type) bool {
func isDyn(t *types.Type) bool {
// Note: object type values that are well-known and map to a DYN value in practice
// are sanitized prior to being added to the environment.
switch kindOf(t) {
case kindDyn:
switch t.Kind() {
case types.DynKind, types.AnyKind:
return true
case kindWellKnown:
return t.GetWellKnown() == exprpb.Type_ANY
default:
return false
}
}
// isDynOrError returns true if the input is either an Error, DYN, or well-known ANY message.
func isDynOrError(t *exprpb.Type) bool {
func isDynOrError(t *types.Type) bool {
return isError(t) || isDyn(t)
}
func isError(t *exprpb.Type) bool {
return kindOf(t) == kindError
func isError(t *types.Type) bool {
return t.Kind() == types.ErrorKind
}
func isOptional(t *exprpb.Type) bool {
if kindOf(t) == kindAbstract {
at := t.GetAbstractType()
return at.GetName() == "optional"
func isOptional(t *types.Type) bool {
if t.Kind() == types.OpaqueKind {
return t.TypeName() == "optional"
}
return false
}
func maybeUnwrapOptional(t *exprpb.Type) (*exprpb.Type, bool) {
func maybeUnwrapOptional(t *types.Type) (*types.Type, bool) {
if isOptional(t) {
at := t.GetAbstractType()
return at.GetParameterTypes()[0], true
return t.Parameters()[0], true
}
return t, false
}
func maybeUnwrapString(e *exprpb.Expr) (string, bool) {
switch e.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
literal := e.GetConstExpr()
switch literal.GetConstantKind().(type) {
case *exprpb.Constant_StringValue:
return literal.GetStringValue(), true
}
}
return "", false
}
// isEqualOrLessSpecific checks whether one type is equal or less specific than the other one.
// A type is less specific if it matches the other type using the DYN type.
func isEqualOrLessSpecific(t1 *exprpb.Type, t2 *exprpb.Type) bool {
kind1, kind2 := kindOf(t1), kindOf(t2)
func isEqualOrLessSpecific(t1, t2 *types.Type) bool {
kind1, kind2 := t1.Kind(), t2.Kind()
// The first type is less specific.
if isDyn(t1) || kind1 == kindTypeParam {
if isDyn(t1) || kind1 == types.TypeParamKind {
return true
}
// The first type is not less specific.
if isDyn(t2) || kind2 == kindTypeParam {
if isDyn(t2) || kind2 == types.TypeParamKind {
return false
}
// Types must be of the same kind to be equal.
@ -173,38 +73,34 @@ func isEqualOrLessSpecific(t1 *exprpb.Type, t2 *exprpb.Type) bool {
// With limited exceptions for ANY and JSON values, the types must agree and be equivalent in
// order to return true.
switch kind1 {
case kindAbstract:
a1 := t1.GetAbstractType()
a2 := t2.GetAbstractType()
if a1.GetName() != a2.GetName() ||
len(a1.GetParameterTypes()) != len(a2.GetParameterTypes()) {
case types.OpaqueKind:
if t1.TypeName() != t2.TypeName() ||
len(t1.Parameters()) != len(t2.Parameters()) {
return false
}
for i, p1 := range a1.GetParameterTypes() {
if !isEqualOrLessSpecific(p1, a2.GetParameterTypes()[i]) {
for i, p1 := range t1.Parameters() {
if !isEqualOrLessSpecific(p1, t2.Parameters()[i]) {
return false
}
}
return true
case kindList:
return isEqualOrLessSpecific(t1.GetListType().GetElemType(), t2.GetListType().GetElemType())
case kindMap:
m1 := t1.GetMapType()
m2 := t2.GetMapType()
return isEqualOrLessSpecific(m1.GetKeyType(), m2.GetKeyType()) &&
isEqualOrLessSpecific(m1.GetValueType(), m2.GetValueType())
case kindType:
case types.ListKind:
return isEqualOrLessSpecific(t1.Parameters()[0], t2.Parameters()[0])
case types.MapKind:
return isEqualOrLessSpecific(t1.Parameters()[0], t2.Parameters()[0]) &&
isEqualOrLessSpecific(t1.Parameters()[1], t2.Parameters()[1])
case types.TypeKind:
return true
default:
return proto.Equal(t1, t2)
return t1.IsExactType(t2)
}
}
// / internalIsAssignable returns true if t1 is assignable to t2.
func internalIsAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) bool {
func internalIsAssignable(m *mapping, t1, t2 *types.Type) bool {
// Process type parameters.
kind1, kind2 := kindOf(t1), kindOf(t2)
if kind2 == kindTypeParam {
kind1, kind2 := t1.Kind(), t2.Kind()
if kind2 == types.TypeParamKind {
// If t2 is a valid type substitution for t1, return true.
valid, t2HasSub := isValidTypeSubstitution(m, t1, t2)
if valid {
@ -217,7 +113,7 @@ func internalIsAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) bool {
}
// Otherwise, fall through to check whether t1 is a possible substitution for t2.
}
if kind1 == kindTypeParam {
if kind1 == types.TypeParamKind {
// Return whether t1 is a valid substitution for t2. If not, do no additional checks as the
// possible type substitutions have been searched in both directions.
valid, _ := isValidTypeSubstitution(m, t2, t1)
@ -228,40 +124,25 @@ func internalIsAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) bool {
if isDynOrError(t1) || isDynOrError(t2) {
return true
}
// Preserve the nullness checks of the legacy type-checker.
if kind1 == types.NullTypeKind {
return internalIsAssignableNull(t2)
}
if kind2 == types.NullTypeKind {
return internalIsAssignableNull(t1)
}
// Test for when the types do not need to agree, but are more specific than dyn.
switch kind1 {
case kindNull:
return internalIsAssignableNull(t2)
case kindPrimitive:
return internalIsAssignablePrimitive(t1.GetPrimitive(), t2)
case kindWrapper:
return internalIsAssignable(m, decls.NewPrimitiveType(t1.GetWrapper()), t2)
default:
if kind1 != kind2 {
return false
}
}
// Test for when the types must agree.
switch kind1 {
// ERROR, TYPE_PARAM, and DYN handled above.
case kindAbstract:
return internalIsAssignableAbstractType(m, t1.GetAbstractType(), t2.GetAbstractType())
case kindFunction:
return internalIsAssignableFunction(m, t1.GetFunction(), t2.GetFunction())
case kindList:
return internalIsAssignable(m, t1.GetListType().GetElemType(), t2.GetListType().GetElemType())
case kindMap:
return internalIsAssignableMap(m, t1.GetMapType(), t2.GetMapType())
case kindObject:
return t1.GetMessageType() == t2.GetMessageType()
case kindType:
// A type is a type is a type, any additional parameterization of the
// type cannot affect method resolution or assignability.
return true
case kindWellKnown:
return t1.GetWellKnown() == t2.GetWellKnown()
case types.BoolKind, types.BytesKind, types.DoubleKind, types.IntKind, types.StringKind, types.UintKind,
types.AnyKind, types.DurationKind, types.TimestampKind,
types.StructKind:
return t1.IsAssignableType(t2)
case types.TypeKind:
return kind2 == types.TypeKind
case types.OpaqueKind, types.ListKind, types.MapKind:
return t1.Kind() == t2.Kind() && t1.TypeName() == t2.TypeName() &&
internalIsAssignableList(m, t1.Parameters(), t2.Parameters())
default:
return false
}
@ -274,16 +155,16 @@ func internalIsAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) bool {
// - t2 has a type substitution (t2sub) equal to t1
// - t2 has a type substitution (t2sub) assignable to t1
// - t2 does not occur within t1.
func isValidTypeSubstitution(m *mapping, t1, t2 *exprpb.Type) (valid, hasSub bool) {
func isValidTypeSubstitution(m *mapping, t1, t2 *types.Type) (valid, hasSub bool) {
// Early return if the t1 and t2 are the same instance.
kind1, kind2 := kindOf(t1), kindOf(t2)
if kind1 == kind2 && (t1 == t2 || proto.Equal(t1, t2)) {
kind1, kind2 := t1.Kind(), t2.Kind()
if kind1 == kind2 && t1.IsExactType(t2) {
return true, true
}
if t2Sub, found := m.find(t2); found {
// Early return if t1 and t2Sub are the same instance as otherwise the mapping
// might mark a type as being a subtitution for itself.
if kind1 == kindOf(t2Sub) && (t1 == t2Sub || proto.Equal(t1, t2Sub)) {
if kind1 == t2Sub.Kind() && t1.IsExactType(t2Sub) {
return true, true
}
// If the types are compatible, pick the more general type and return true
@ -305,28 +186,10 @@ func isValidTypeSubstitution(m *mapping, t1, t2 *exprpb.Type) (valid, hasSub boo
return false, false
}
// internalIsAssignableAbstractType returns true if the abstract type names agree and all type
// parameters are assignable.
func internalIsAssignableAbstractType(m *mapping, a1 *exprpb.Type_AbstractType, a2 *exprpb.Type_AbstractType) bool {
return a1.GetName() == a2.GetName() &&
internalIsAssignableList(m, a1.GetParameterTypes(), a2.GetParameterTypes())
}
// internalIsAssignableFunction returns true if the function return type and arg types are
// assignable.
func internalIsAssignableFunction(m *mapping, f1 *exprpb.Type_FunctionType, f2 *exprpb.Type_FunctionType) bool {
f1ArgTypes := flattenFunctionTypes(f1)
f2ArgTypes := flattenFunctionTypes(f2)
if internalIsAssignableList(m, f1ArgTypes, f2ArgTypes) {
return true
}
return false
}
// internalIsAssignableList returns true if the element types at each index in the list are
// assignable from l1[i] to l2[i]. The list lengths must also agree for the lists to be
// assignable.
func internalIsAssignableList(m *mapping, l1 []*exprpb.Type, l2 []*exprpb.Type) bool {
func internalIsAssignableList(m *mapping, l1, l2 []*types.Type) bool {
if len(l1) != len(l2) {
return false
}
@ -338,41 +201,22 @@ func internalIsAssignableList(m *mapping, l1 []*exprpb.Type, l2 []*exprpb.Type)
return true
}
// internalIsAssignableMap returns true if map m1 may be assigned to map m2.
func internalIsAssignableMap(m *mapping, m1 *exprpb.Type_MapType, m2 *exprpb.Type_MapType) bool {
if internalIsAssignableList(m,
[]*exprpb.Type{m1.GetKeyType(), m1.GetValueType()},
[]*exprpb.Type{m2.GetKeyType(), m2.GetValueType()}) {
// internalIsAssignableNull returns true if the type is nullable.
func internalIsAssignableNull(t *types.Type) bool {
return isLegacyNullable(t) || t.IsAssignableType(types.NullType)
}
// isLegacyNullable preserves the null-ness compatibility of the original type-checker implementation.
func isLegacyNullable(t *types.Type) bool {
switch t.Kind() {
case types.OpaqueKind, types.StructKind, types.AnyKind, types.DurationKind, types.TimestampKind:
return true
}
return false
}
// internalIsAssignableNull returns true if the type is nullable.
func internalIsAssignableNull(t *exprpb.Type) bool {
switch kindOf(t) {
case kindAbstract, kindObject, kindNull, kindWellKnown, kindWrapper:
return true
default:
return false
}
}
// internalIsAssignablePrimitive returns true if the target type is the same or if it is a wrapper
// for the primitive type.
func internalIsAssignablePrimitive(p exprpb.Type_PrimitiveType, target *exprpb.Type) bool {
switch kindOf(target) {
case kindPrimitive:
return p == target.GetPrimitive()
case kindWrapper:
return p == target.GetWrapper()
default:
return false
}
}
// isAssignable returns an updated type substitution mapping if t1 is assignable to t2.
func isAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) *mapping {
func isAssignable(m *mapping, t1, t2 *types.Type) *mapping {
mCopy := m.copy()
if internalIsAssignable(mCopy, t1, t2) {
return mCopy
@ -381,7 +225,7 @@ func isAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) *mapping {
}
// isAssignableList returns an updated type substitution mapping if l1 is assignable to l2.
func isAssignableList(m *mapping, l1 []*exprpb.Type, l2 []*exprpb.Type) *mapping {
func isAssignableList(m *mapping, l1, l2 []*types.Type) *mapping {
mCopy := m.copy()
if internalIsAssignableList(mCopy, l1, l2) {
return mCopy
@ -389,44 +233,8 @@ func isAssignableList(m *mapping, l1 []*exprpb.Type, l2 []*exprpb.Type) *mapping
return nil
}
// kindOf returns the kind of the type as defined in the checked.proto.
func kindOf(t *exprpb.Type) int {
if t == nil || t.TypeKind == nil {
return kindUnknown
}
switch t.GetTypeKind().(type) {
case *exprpb.Type_Error:
return kindError
case *exprpb.Type_Function:
return kindFunction
case *exprpb.Type_Dyn:
return kindDyn
case *exprpb.Type_Primitive:
return kindPrimitive
case *exprpb.Type_WellKnown:
return kindWellKnown
case *exprpb.Type_Wrapper:
return kindWrapper
case *exprpb.Type_Null:
return kindNull
case *exprpb.Type_Type:
return kindType
case *exprpb.Type_ListType_:
return kindList
case *exprpb.Type_MapType_:
return kindMap
case *exprpb.Type_MessageType:
return kindObject
case *exprpb.Type_TypeParam:
return kindTypeParam
case *exprpb.Type_AbstractType_:
return kindAbstract
}
return kindUnknown
}
// mostGeneral returns the more general of two types which are known to unify.
func mostGeneral(t1 *exprpb.Type, t2 *exprpb.Type) *exprpb.Type {
func mostGeneral(t1, t2 *types.Type) *types.Type {
if isEqualOrLessSpecific(t1, t2) {
return t1
}
@ -436,32 +244,25 @@ func mostGeneral(t1 *exprpb.Type, t2 *exprpb.Type) *exprpb.Type {
// notReferencedIn checks whether the type doesn't appear directly or transitively within the other
// type. This is a standard requirement for type unification, commonly referred to as the "occurs
// check".
func notReferencedIn(m *mapping, t *exprpb.Type, withinType *exprpb.Type) bool {
if proto.Equal(t, withinType) {
func notReferencedIn(m *mapping, t, withinType *types.Type) bool {
if t.IsExactType(withinType) {
return false
}
withinKind := kindOf(withinType)
withinKind := withinType.Kind()
switch withinKind {
case kindTypeParam:
case types.TypeParamKind:
wtSub, found := m.find(withinType)
if !found {
return true
}
return notReferencedIn(m, t, wtSub)
case kindAbstract:
for _, pt := range withinType.GetAbstractType().GetParameterTypes() {
case types.OpaqueKind, types.ListKind, types.MapKind:
for _, pt := range withinType.Parameters() {
if !notReferencedIn(m, t, pt) {
return false
}
}
return true
case kindList:
return notReferencedIn(m, t, withinType.GetListType().GetElemType())
case kindMap:
mt := withinType.GetMapType()
return notReferencedIn(m, t, mt.GetKeyType()) && notReferencedIn(m, t, mt.GetValueType())
case kindWrapper:
return notReferencedIn(m, t, decls.NewPrimitiveType(withinType.GetWrapper()))
default:
return true
}
@ -469,39 +270,25 @@ func notReferencedIn(m *mapping, t *exprpb.Type, withinType *exprpb.Type) bool {
// substitute replaces all direct and indirect occurrences of bound type parameters. Unbound type
// parameters are replaced by DYN if typeParamToDyn is true.
func substitute(m *mapping, t *exprpb.Type, typeParamToDyn bool) *exprpb.Type {
func substitute(m *mapping, t *types.Type, typeParamToDyn bool) *types.Type {
if tSub, found := m.find(t); found {
return substitute(m, tSub, typeParamToDyn)
}
kind := kindOf(t)
if typeParamToDyn && kind == kindTypeParam {
return decls.Dyn
kind := t.Kind()
if typeParamToDyn && kind == types.TypeParamKind {
return types.DynType
}
switch kind {
case kindAbstract:
at := t.GetAbstractType()
params := make([]*exprpb.Type, len(at.GetParameterTypes()))
for i, p := range at.GetParameterTypes() {
params[i] = substitute(m, p, typeParamToDyn)
}
return decls.NewAbstractType(at.GetName(), params...)
case kindFunction:
fn := t.GetFunction()
rt := substitute(m, fn.ResultType, typeParamToDyn)
args := make([]*exprpb.Type, len(fn.GetArgTypes()))
for i, a := range fn.ArgTypes {
args[i] = substitute(m, a, typeParamToDyn)
}
return decls.NewFunctionType(rt, args...)
case kindList:
return decls.NewListType(substitute(m, t.GetListType().GetElemType(), typeParamToDyn))
case kindMap:
mt := t.GetMapType()
return decls.NewMapType(substitute(m, mt.GetKeyType(), typeParamToDyn),
substitute(m, mt.GetValueType(), typeParamToDyn))
case kindType:
if t.GetType() != nil {
return decls.NewTypeType(substitute(m, t.GetType(), typeParamToDyn))
case types.OpaqueKind:
return types.NewOpaqueType(t.TypeName(), substituteParams(m, t.Parameters(), typeParamToDyn)...)
case types.ListKind:
return types.NewListType(substitute(m, t.Parameters()[0], typeParamToDyn))
case types.MapKind:
return types.NewMapType(substitute(m, t.Parameters()[0], typeParamToDyn),
substitute(m, t.Parameters()[1], typeParamToDyn))
case types.TypeKind:
if len(t.Parameters()) > 0 {
return types.NewTypeTypeWithParam(substitute(m, t.Parameters()[0], typeParamToDyn))
}
return t
default:
@ -509,21 +296,14 @@ func substitute(m *mapping, t *exprpb.Type, typeParamToDyn bool) *exprpb.Type {
}
}
func typeKey(t *exprpb.Type) string {
return FormatCheckedType(t)
func substituteParams(m *mapping, typeParams []*types.Type, typeParamToDyn bool) []*types.Type {
subParams := make([]*types.Type, len(typeParams))
for i, tp := range typeParams {
subParams[i] = substitute(m, tp, typeParamToDyn)
}
return subParams
}
// flattenFunctionTypes takes a function with arg types T1, T2, ..., TN and result type TR
// and returns a slice containing {T1, T2, ..., TN, TR}.
func flattenFunctionTypes(f *exprpb.Type_FunctionType) []*exprpb.Type {
argTypes := f.GetArgTypes()
if len(argTypes) == 0 {
return []*exprpb.Type{f.GetResultType()}
}
flattend := make([]*exprpb.Type, len(argTypes)+1, len(argTypes)+1)
for i, at := range argTypes {
flattend[i] = at
}
flattend[len(argTypes)] = f.GetResultType()
return flattend
func newFunctionType(resultType *types.Type, argTypes ...*types.Type) *types.Type {
return types.NewOpaqueType("function", append([]*types.Type{resultType}, argTypes...)...)
}