ceph-csi/vendor/github.com/google/cel-go/checker/checker.go

697 lines
22 KiB
Go
Raw Normal View History

// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package checker defines functions to type-check a parsed expression
// against a set of identifier and function declarations.
package checker
import (
"fmt"
"reflect"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/decls"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)
// checker carries the state of one type-checking pass started by Check.
type checker struct {
// The AST under check. Embedded so its accessors can be called directly on the
// checker (this file uses Expr, TypeMap, SetType, ReferenceMap, SetReference and
// SourceInfo — presumably all provided by this embedded value).
*ast.AST
// Factory used to build replacement nodes when an identifier, call or struct is
// rewritten with its resolved name (NewIdent, NewCall, NewStruct).
ast.ExprFactory
// Environment used to look up identifiers and functions; swapped for a nested
// scope while a comprehension is being checked.
env *Env
// Sink for type-checking errors.
errors *typeErrors
// Type-parameter substitutions accumulated by successful assignability checks.
mappings *mapping
// Next numeric suffix used by newTypeVar when naming a fresh type variable.
freeTypeVarCounter int
}
// Check type-checks a parsed expression and produces a checked AST.
//
// parsed is the AST to check, source is used to build the error registry, and env
// supplies the variable, function and type declarations consulted during checking.
//
// The returned AST carries the inferred type and reference maps. When the returned
// error registry is non-empty the AST may not be usable.
func Check(parsed *ast.AST, source common.Source, env *Env) (*ast.AST, *common.Errors) {
	issues := common.NewErrors(source)
	chk := &checker{
		AST: ast.NewCheckedAST(
			parsed,
			make(map[int64]*types.Type),
			make(map[int64]*ast.ReferenceInfo),
		),
		ExprFactory: ast.NewExprFactory(),
		env:         env,
		errors:      &typeErrors{errs: issues},
		mappings:    newMapping(),
		// freeTypeVarCounter starts at its zero value.
	}
	chk.check(chk.Expr())

	// Finalize the type map: replace every remaining type parameter with the type it
	// was bound to, or with DYN.
	for exprID, inferred := range chk.TypeMap() {
		chk.SetType(exprID, substitute(chk.mappings, inferred, true))
	}
	return chk.AST, issues
}
// check dispatches on the expression kind and type-checks e with the matching
// routine. A nil expression is ignored; an unrecognized kind is reported as an
// unexpected AST type.
func (c *checker) check(e ast.Expr) {
	if e == nil {
		return
	}
	switch e.Kind() {
	case ast.IdentKind:
		c.checkIdent(e)
	case ast.SelectKind:
		c.checkSelect(e)
	case ast.CallKind:
		c.checkCall(e)
	case ast.ListKind:
		c.checkCreateList(e)
	case ast.MapKind:
		c.checkCreateMap(e)
	case ast.StructKind:
		c.checkCreateStruct(e)
	case ast.ComprehensionKind:
		c.checkComprehension(e)
	case ast.LiteralKind:
		// A literal's type is the type of its constant value, limited to the
		// primitive types listed below.
		litType := ref.Val(e.AsLiteral()).Type()
		switch litType {
		case types.BoolType, types.BytesType, types.DoubleType, types.IntType,
			types.NullType, types.StringType, types.UintType:
			c.setType(e, litType.(*types.Type))
		default:
			c.errors.unexpectedASTType(e.ID(), c.location(e), "literal", litType.TypeName())
		}
	default:
		c.errors.unexpectedASTType(e.ID(), c.location(e), "unspecified", reflect.TypeOf(e).Name())
	}
}
// checkIdent resolves an identifier against the environment. A declared identifier
// gets its declared type and a reference, and the node is rewritten to the resolved
// name; an undeclared one is typed as an error and reported.
func (c *checker) checkIdent(e ast.Expr) {
	name := e.AsIdent()
	decl := c.env.LookupIdent(name)
	if decl == nil {
		c.setType(e, types.ErrorType)
		c.errors.undeclaredReference(e.ID(), c.location(e), c.env.container.Name(), name)
		return
	}
	c.setType(e, decl.Type())
	c.setReference(e, ast.NewIdentReference(decl.Name(), decl.Value()))
	// Replace the identifier with its fully qualified name.
	e.SetKindCase(c.NewIdent(e.ID(), decl.Name()))
}
// checkSelect type-checks a select expression. The expression is first tried as a
// qualified identifier; if that resolves, the node is rewritten into an identifier
// reference. Otherwise it is checked as a field selection on its operand.
func (c *checker) checkSelect(e ast.Expr) {
	// Try the qualified-name interpretation before descending into the operand.
	// Test-only selects need no special handling here: per the original note, the
	// qualified-name conversion never succeeds for them.
	if qualified, ok := containers.ToQualifiedName(e); ok {
		if decl := c.env.LookupIdent(qualified); decl != nil {
			// Turn the node into a reference to the fully qualified variable.
			c.setType(e, decl.Type())
			c.setReference(e, ast.NewIdentReference(decl.Name(), decl.Value()))
			e.SetKindCase(c.NewIdent(e.ID(), decl.Name()))
			return
		}
	}

	sel := e.AsSelect()
	selected := c.checkSelectField(e, sel.Operand(), sel.FieldName(), false)
	if sel.IsTestOnly() {
		// A test-only select is typed as bool rather than as the field's type.
		selected = types.BoolType
	}
	c.setType(e, substitute(c.mappings, selected, false))
}
// checkOptSelect type-checks an optional field selection, which arrives as a call
// whose first argument is the operand and whose second argument is the field name.
//
// The field argument must be a string literal. If it is not, an error is reported
// and the expression is typed as an error, so that an enclosing expression reading
// this node's type gets an error type instead of a nil type. This matches how the
// other error paths in this file record types.ErrorType.
func (c *checker) checkOptSelect(e ast.Expr) {
	// Collect metadata related to the opt select call packaged by the parser.
	call := e.AsCall()
	operand := call.Args()[0]
	field := call.Args()[1]
	fieldName, isString := maybeUnwrapString(field)
	if !isString {
		c.errors.notAnOptionalFieldSelection(field.ID(), c.location(field), field)
		// Record an error type so later getType(e) calls do not return nil.
		c.setType(e, types.ErrorType)
		return
	}

	// Perform type-checking using the field selection logic; the result is always
	// wrapped as optional because the selection itself is optional.
	resultType := c.checkSelectField(e, operand, fieldName, true)
	c.setType(e, substitute(c.mappings, resultType, false))
	c.setReference(e, ast.NewFunctionReference("select_optional_field"))
}
// checkSelectField checks operand and computes the type produced by selecting field
// from it. e is the expression that errors are reported against. The result is
// wrapped in an optional type when optional is true or when the operand's type was
// itself optional.
func (c *checker) checkSelectField(e, operand ast.Expr, field string, optional bool) *types.Type {
	// Check the operand first, then look through an optional wrapper if present.
	c.check(operand)
	target, wasOptional := maybeUnwrapOptional(substitute(c.mappings, c.getType(operand), false))

	var selected *types.Type
	switch target.Kind() {
	case types.MapKind:
		// Selecting from a map yields the map's value type.
		selected = target.Parameters()[1]
	case types.StructKind:
		// Selecting from a struct yields the declared field type; an undefined field
		// leaves the result as an error type.
		if ft, ok := c.lookupFieldType(e.ID(), target.TypeName(), field); ok {
			selected = ft
		} else {
			selected = types.ErrorType
		}
	case types.TypeParamKind:
		// Bind the type parameter to DYN so it cannot later be assigned a possibly
		// incorrect type; isAssignable records the substitution. The result is DYN.
		c.isAssignable(types.DynType, target)
		selected = types.DynType
	default:
		// DYN and error operands select to DYN so checking can continue; any other
		// type is reported as not supporting field selection.
		if !isDynOrError(target) {
			c.errors.typeDoesNotSupportFieldSelection(e.ID(), c.location(e), target)
		}
		selected = types.DynType
	}

	// Optional in, optional out.
	if wasOptional || optional {
		return types.NewOptionalType(selected)
	}
	return selected
}
// checkCall type-checks a function call. Optional selection is delegated to
// checkOptSelect. Other calls are resolved, in order, as a global function, as a
// namespaced function whose qualifier was parsed as the receiver, or as a receiver
// (member) function.
//
// Note: similar logic exists within `interpreter/planner.go`. When changing this
// function, consider the impact on planner.go and consolidate or mirror the code as
// appropriate.
func (c *checker) checkCall(e ast.Expr) {
	call := e.AsCall()
	name := call.FunctionName()
	if name == operators.OptSelect {
		c.checkOptSelect(e)
		return
	}

	// Check every argument up front.
	args := call.Args()
	for i := range args {
		c.check(args[i])
	}

	// Global call with a simple name.
	if !call.IsMemberFunction() {
		decl := c.env.LookupFunction(name)
		if decl == nil {
			c.errors.undeclaredReference(e.ID(), c.location(e), c.env.container.Name(), name)
			c.setType(e, types.ErrorType)
			return
		}
		// Rewrite the call with the fully qualified function name, then resolve.
		e.SetKindCase(c.NewCall(e.ID(), decl.Name(), args...))
		c.resolveOverloadOrError(e, decl, nil, args)
		return
	}

	// With a receiver present, a.b.c() is either the namespaced function a.b.c or the
	// function c invoked on target a.b, never both. Try the namespaced reading first.
	target := call.Target()
	if prefix, ok := containers.ToQualifiedName(target); ok {
		if decl := c.env.LookupFunction(prefix + "." + name); decl != nil {
			// The target was really a namespace, so keeping it as an operand would
			// misrepresent evaluation. Rewrite as a receiver-less call to the fully
			// qualified function.
			e.SetKindCase(c.NewCall(e.ID(), decl.Name(), args...))
			c.resolveOverloadOrError(e, decl, nil, args)
			return
		}
	}

	// Ordinary receiver call: check the target, then resolve against it.
	c.check(target)
	decl := c.env.LookupFunction(name)
	if decl == nil {
		// The function name is not declared.
		c.setType(e, types.ErrorType)
		c.errors.undeclaredReference(e.ID(), c.location(e), c.env.container.Name(), name)
		return
	}
	c.resolveOverloadOrError(e, decl, target, args)
}
// resolveOverloadOrError resolves the overload for a call and records the outcome on
// e: the resolved type and reference on success, or an error type when no overload
// matches (resolveOverload has already reported the error in that case).
func (c *checker) resolveOverloadOrError(
	e ast.Expr, fn *decls.FunctionDecl, target ast.Expr, args []ast.Expr) {
	if res := c.resolveOverload(e, fn, target, args); res != nil {
		c.setType(e, res.Type)
		c.setReference(e, res.Reference)
		return
	}
	c.setType(e, types.ErrorType)
}
// resolveOverload finds the overloads of fn that are compatible with the types
// already recorded for target (if any) and args.
//
// It returns a resolution whose reference lists the ID of every matching overload
// and whose type is the call's result type. When no overload matches, a
// no-matching-overload error is reported against call and nil is returned. For the
// logical AND/OR operators, nil is also returned when any argument is not assignable
// to bool.
func (c *checker) resolveOverload(
call ast.Expr, fn *decls.FunctionDecl, target ast.Expr, args []ast.Expr) *overloadResolution {
// Gather the argument types; a receiver target, when present, comes first.
var argTypes []*types.Type
if target != nil {
argTypes = append(argTypes, c.getType(target))
}
for _, arg := range args {
argTypes = append(argTypes, c.getType(arg))
}
var resultType *types.Type
var checkedRef *ast.ReferenceInfo
for _, overload := range fn.OverloadDecls() {
// Determine whether the overload is currently considered.
if c.env.isOverloadDisabled(overload.ID()) {
continue
}
// Ensure the call style for the overload matches.
if (target == nil && overload.IsMemberFunction()) ||
(target != nil && !overload.IsMemberFunction()) {
// not a compatible call style.
continue
}
// Alternative type-checking behavior when the logical operators are compacted into
// variadic AST representations.
// Every argument must be assignable to bool; each offending argument is reported
// individually, and the resolution fails if any was reported.
if fn.Name() == operators.LogicalAnd || fn.Name() == operators.LogicalOr {
checkedRef = ast.NewFunctionReference(overload.ID())
for i, argType := range argTypes {
if !c.isAssignable(argType, types.BoolType) {
c.errors.typeMismatch(
args[i].ID(),
c.locationByID(args[i].ID()),
types.BoolType,
argType)
resultType = types.ErrorType
}
}
if isError(resultType) {
return nil
}
return newResolution(checkedRef, types.BoolType)
}
// Model the overload as a function type: parameter 0 is the result type and the
// remaining parameters are the argument types (see the indexing below).
overloadType := newFunctionType(overload.ResultType(), overload.ArgTypes()...)
typeParams := overload.TypeParams()
if len(typeParams) != 0 {
// Instantiate overload's type with fresh type variables.
substitutions := newMapping()
for _, typePar := range typeParams {
substitutions.add(types.NewTypeParamType(typePar), c.newTypeVar())
}
overloadType = substitute(substitutions, overloadType, false)
}
candidateArgTypes := overloadType.Parameters()[1:]
// A successful check also updates c.mappings with any new substitutions.
if c.isAssignableList(argTypes, candidateArgTypes) {
if checkedRef == nil {
checkedRef = ast.NewFunctionReference(overload.ID())
} else {
checkedRef.AddOverload(overload.ID())
}
// First matching overload, determines result type.
// A later match whose result type is not exactly the same widens the result to DYN.
fnResultType := substitute(c.mappings, overloadType.Parameters()[0], false)
if resultType == nil {
resultType = fnResultType
} else if !isDyn(resultType) && !fnResultType.IsExactType(resultType) {
resultType = types.DynType
}
}
}
// No overload matched: substitute the argument types before reporting them.
if resultType == nil {
for i, argType := range argTypes {
argTypes[i] = substitute(c.mappings, argType, true)
}
c.errors.noMatchingOverload(call.ID(), c.location(call), fn.Name(), argTypes, target != nil)
return nil
}
return newResolution(checkedRef, resultType)
}
// checkCreateList type-checks a list literal. Each element is checked and its type
// joined into a single element type. Elements at optional indices must be optional
// (or DYN) and contribute their unwrapped type. An empty list gets a fresh type
// variable as its element type.
func (c *checker) checkCreateList(e ast.Expr) {
	list := e.AsList()

	// Index the optional positions for constant-time lookup.
	optIdx := list.OptionalIndices()
	isOptionalElem := make(map[int32]bool, len(optIdx))
	for _, idx := range optIdx {
		isOptionalElem[idx] = true
	}

	var joined *types.Type
	for i, elem := range list.Elements() {
		c.check(elem)
		elemType := c.getType(elem)
		if isOptionalElem[int32(i)] {
			unwrapped, wasOptional := maybeUnwrapOptional(elemType)
			elemType = unwrapped
			if !wasOptional && !isDyn(elemType) {
				c.errors.typeMismatch(elem.ID(), c.location(elem), types.NewOptionalType(elemType), elemType)
			}
		}
		joined = c.joinTypes(elem, joined, elemType)
	}
	if joined == nil {
		// Nothing was joined (empty list): use a free type variable.
		joined = c.newTypeVar()
	}
	c.setType(e, types.NewListType(joined))
}
// checkCreateMap type-checks a map literal. Keys and values are checked entry by
// entry and joined into one key type and one value type. The value of an optional
// entry must be optional (or DYN) and contributes its unwrapped type. When no key
// type was determined, fresh type variables are used for both key and value.
func (c *checker) checkCreateMap(e ast.Expr) {
	var keyType, valueType *types.Type
	for _, entryExpr := range e.AsMap().Entries() {
		entry := entryExpr.AsMapEntry()

		k := entry.Key()
		c.check(k)
		keyType = c.joinTypes(k, keyType, c.getType(k))

		v := entry.Value()
		c.check(v)
		vType := c.getType(v)
		if entry.IsOptional() {
			unwrapped, wasOptional := maybeUnwrapOptional(vType)
			vType = unwrapped
			if !wasOptional && !isDyn(vType) {
				c.errors.typeMismatch(v.ID(), c.location(v), types.NewOptionalType(vType), vType)
			}
		}
		valueType = c.joinTypes(v, valueType, vType)
	}
	if keyType == nil {
		// Empty map: the key variable is allocated first, then the value variable.
		keyType, valueType = c.newTypeVar(), c.newTypeVar()
	}
	c.setType(e, types.NewMapType(keyType, valueType))
}
// checkCreateStruct type-checks a struct (message) construction expression. It
// resolves the type name, rewrites the node with the fully qualified name when that
// differs, determines the resulting type, and then checks each field initializer
// against the declared field type.
func (c *checker) checkCreateStruct(e ast.Expr) {
	msg := e.AsStruct()
	declared := msg.TypeName()

	// Resolve the message type name.
	ident := c.env.LookupIdent(declared)
	if ident == nil {
		c.errors.undeclaredReference(
			e.ID(), c.location(e), c.env.container.Name(), declared)
		c.setType(e, types.ErrorType)
		return
	}

	// Make sure the AST carries the fully qualified type name.
	typeName := ident.Name()
	if declared != typeName {
		e.SetKindCase(c.NewStruct(e.ID(), typeName, msg.Fields()))
		msg = e.AsStruct()
	}
	c.setReference(e, ast.NewIdentReference(typeName, nil))

	// The result stays an error type unless the identifier denotes a usable type.
	resultType := types.ErrorType
	identType := ident.Type()
	switch identType.Kind() {
	case types.ErrorKind:
		// Already an error; nothing more to report.
	case types.TypeKind:
		resultType = identType.Parameters()[0]
		// Backwards compatibility between well-known types and message types: the
		// type is being instantiated by its protobuf name, which is not ideal or
		// recommended, but some users expect it to work.
		switch {
		case isWellKnownType(resultType):
			typeName = getWellKnownTypeName(resultType)
		case resultType.Kind() == types.StructKind:
			typeName = resultType.DeclaredTypeName()
		default:
			c.errors.notAMessageType(e.ID(), c.location(e), resultType.DeclaredTypeName())
			resultType = types.ErrorType
		}
	default:
		c.errors.notAType(e.ID(), c.location(e), identType.DeclaredTypeName())
	}
	c.setType(e, resultType)

	// Check each field initializer against the declared field type.
	for _, f := range msg.Fields() {
		field := f.AsStructField()
		name := field.Name()
		value := field.Value()
		c.check(value)

		// An unknown field leaves the expected type as an error type.
		expected := types.ErrorType
		if ft, ok := c.lookupFieldType(f.ID(), typeName, name); ok {
			expected = ft
		}

		actual := c.getType(value)
		if field.IsOptional() {
			unwrapped, wasOptional := maybeUnwrapOptional(actual)
			actual = unwrapped
			if !wasOptional && !isDyn(actual) {
				c.errors.typeMismatch(value.ID(), c.location(value), types.NewOptionalType(actual), actual)
			}
		}
		if !c.isAssignable(expected, actual) {
			c.errors.fieldTypeMismatch(f.ID(), c.locationByID(f.ID()), name, expected, actual)
		}
	}
}
// checkComprehension type-checks a comprehension. The iteration range and the
// accumulator initializer are checked in the enclosing scope. The loop condition and
// step are checked in a nested scope that declares both the accumulator and the
// iteration variable. The result is checked with only the accumulator in scope.
func (c *checker) checkComprehension(e ast.Expr) {
	comp := e.AsComprehension()
	iterRange, accuInit := comp.IterRange(), comp.AccuInit()
	c.check(iterRange)
	c.check(accuInit)
	accuType := c.getType(accuInit)
	rangeType := substitute(c.mappings, c.getType(iterRange), false)

	// Derive the iteration variable's type from the range type.
	var iterVarType *types.Type
	switch rangeType.Kind() {
	case types.ListKind, types.MapKind:
		// Lists range over elements and maps over keys; both are type parameter 0.
		iterVarType = rangeType.Parameters()[0]
	case types.DynKind, types.ErrorKind, types.TypeParamKind:
		// Bind the range type to DYN so it cannot later be assigned a possibly
		// incorrect type; isAssignable records the substitution for a type param.
		// The iteration variable is DYN as well.
		c.isAssignable(types.DynType, rangeType)
		iterVarType = types.DynType
	default:
		c.errors.notAComprehensionRange(iterRange.ID(), c.location(iterRange), rangeType)
		iterVarType = types.ErrorType
	}

	// Outer scope: holds the accumulator variable used to compute the result.
	c.env = c.env.enterScope()
	c.env.AddIdents(decls.NewVariable(comp.AccuVar(), accuType))
	// Inner block scope: holds the loop's iteration variable.
	c.env = c.env.enterScope()
	c.env.AddIdents(decls.NewVariable(comp.IterVar(), iterVarType))

	// The condition must be bool and the step must produce the accumulator's type.
	cond, step := comp.LoopCondition(), comp.LoopStep()
	c.check(cond)
	c.assertType(cond, types.BoolType)
	c.check(step)
	c.assertType(step, accuType)

	// Leave the loop scope before checking the result, then leave the outer scope.
	c.env = c.env.exitScope()
	result := comp.Result()
	c.check(result)
	c.env = c.env.exitScope()

	c.setType(e, substitute(c.mappings, c.getType(result), false))
}
// joinTypes combines the type seen so far (previous) with the next one (current)
// and returns the most general common type. With no previous type, current is
// returned as-is. Incompatible types join to DYN when dynamic aggregate literal
// element types are enabled; otherwise a mismatch is reported against e and the
// error type is returned.
func (c *checker) joinTypes(e ast.Expr, previous, current *types.Type) *types.Type {
	switch {
	case previous == nil:
		return current
	case c.isAssignable(previous, current):
		return mostGeneral(previous, current)
	case c.dynAggregateLiteralElementTypesEnabled():
		return types.DynType
	default:
		c.errors.typeMismatch(e.ID(), c.location(e), previous, current)
		return types.ErrorType
	}
}
// dynAggregateLiteralElementTypesEnabled reports whether the environment's aggregate
// literal element type setting equals dynElementType.
func (c *checker) dynAggregateLiteralElementTypesEnabled() bool {
	setting := c.env.aggLitElemType
	return setting == dynElementType
}
// newTypeVar returns a fresh type parameter named "_var<N>", where N is the current
// value of the checker's counter, and then advances the counter.
func (c *checker) newTypeVar() *types.Type {
	name := fmt.Sprintf("_var%d", c.freeTypeVarCounter)
	c.freeTypeVarCounter++
	return types.NewTypeParamType(name)
}
// isAssignable runs the package-level assignability check for t1 and t2 against the
// checker's current mappings. On success the returned mappings replace the
// checker's and true is returned; on failure the mappings are left untouched.
func (c *checker) isAssignable(t1, t2 *types.Type) bool {
	updated := isAssignable(c.mappings, t1, t2)
	if updated == nil {
		return false
	}
	c.mappings = updated
	return true
}
// isAssignableList runs the package-level list assignability check for l1 and l2
// against the checker's current mappings. On success the returned mappings replace
// the checker's and true is returned; on failure the mappings are left untouched.
func (c *checker) isAssignableList(l1, l2 []*types.Type) bool {
	updated := isAssignableList(c.mappings, l1, l2)
	if updated == nil {
		return false
	}
	c.mappings = updated
	return true
}
// maybeUnwrapString returns the Go string held by e and true when e is a literal
// whose value is a types.String; otherwise it returns "" and false.
func maybeUnwrapString(e ast.Expr) (string, bool) {
	if e.Kind() != ast.LiteralKind {
		return "", false
	}
	if s, ok := e.AsLiteral().(types.String); ok {
		return string(s), true
	}
	return "", false
}
// setType records t as the type of e. If a different (non-exact) type was already
// recorded for e, an incompatible-type error is reported and the existing entry is
// kept.
func (c *checker) setType(e ast.Expr, t *types.Type) {
	id := e.ID()
	prior, seen := c.TypeMap()[id]
	if seen && !prior.IsExactType(t) {
		c.errors.incompatibleType(id, c.location(e), e, prior, t)
		return
	}
	c.SetType(id, t)
}
// getType returns the type recorded for e in the type map, or nil when none has
// been recorded.
func (c *checker) getType(e ast.Expr) *types.Type {
	recorded := c.TypeMap()
	return recorded[e.ID()]
}
// setReference records r as the reference for e. If a reference that is not equal
// to r was already recorded for e, a redefinition error is reported and the existing
// entry is kept.
func (c *checker) setReference(e ast.Expr, r *ast.ReferenceInfo) {
	id := e.ID()
	prior, seen := c.ReferenceMap()[id]
	if seen && !prior.Equals(r) {
		c.errors.referenceRedefinition(id, c.location(e), e, prior, r)
		return
	}
	c.SetReference(id, r)
}
// assertType reports a type mismatch against e when the type recorded for e is not
// assignable to t.
func (c *checker) assertType(e ast.Expr, t *types.Type) {
	actual := c.getType(e)
	if c.isAssignable(t, actual) {
		return
	}
	c.errors.typeMismatch(e.ID(), c.location(e), t, actual)
}
// overloadResolution is the outcome of a successful overload resolution for a call.
type overloadResolution struct {
// Type is the result type determined for the call.
Type *types.Type
// Reference identifies the overload(s) the call resolved to.
Reference *ast.ReferenceInfo
}
// newResolution builds an overloadResolution from a function reference and a result
// type.
func newResolution(r *ast.ReferenceInfo, t *types.Type) *overloadResolution {
	res := new(overloadResolution)
	res.Type = t
	res.Reference = r
	return res
}
// location returns the start location of e, looked up by the expression's ID.
func (c *checker) location(e ast.Expr) common.Location {
	id := e.ID()
	return c.locationByID(id)
}
// locationByID returns the start location that the AST's source info holds for the
// expression with the given ID.
func (c *checker) locationByID(id int64) common.Location {
	info := c.SourceInfo()
	return info.GetStartLocation(id)
}
// lookupFieldType returns the declared type of fieldName on structType and true when
// the field exists. If the provider does not know the struct type, or the field is
// not defined on it, the corresponding error is reported against exprID and
// (nil, false) is returned.
func (c *checker) lookupFieldType(exprID int64, structType, fieldName string) (*types.Type, bool) {
	provider := c.env.provider
	if _, known := provider.FindStructType(structType); !known {
		// Not expected to happen, but report it rather than fail silently.
		c.errors.unexpectedFailedResolution(exprID, c.locationByID(exprID), structType)
		return nil, false
	}
	ft, found := provider.FindStructFieldType(structType, fieldName)
	if !found {
		c.errors.undefinedField(exprID, c.locationByID(exprID), fieldName)
		return nil, false
	}
	return ft.Type, true
}
// isWellKnownType reports whether t is treated as a well-known type:
//   - any, timestamp, duration, dyn and null kinds always are;
//   - bool, bytes, double, int, string and uint kinds are when t accepts null;
//   - a list is when its element type is DynType;
//   - a map is when its key type is StringType and its value type is DynType.
func isWellKnownType(t *types.Type) bool {
	switch t.Kind() {
	case types.ListKind:
		elem := t.Parameters()[0]
		return elem == types.DynType
	case types.MapKind:
		params := t.Parameters()
		return params[0] == types.StringType && params[1] == types.DynType
	case types.BoolKind, types.BytesKind, types.DoubleKind, types.IntKind, types.StringKind, types.UintKind:
		return t.IsAssignableType(types.NullType)
	case types.AnyKind, types.TimestampKind, types.DurationKind, types.DynKind, types.NullTypeKind:
		return true
	default:
		return false
	}
}
// getWellKnownTypeName returns the protobuf type name registered in wellKnownTypes
// for t's kind, or "" when the kind has no entry.
func getWellKnownTypeName(t *types.Type) string {
	// Indexing a map with a missing key yields the zero value, here "".
	return wellKnownTypes[t.Kind()]
}
// wellKnownTypes maps a type kind to the fully qualified protobuf message name used
// for it by getWellKnownTypeName.
var wellKnownTypes = map[types.Kind]string{
	// Wrapper messages for primitive values.
	types.BoolKind:   "google.protobuf.BoolValue",
	types.BytesKind:  "google.protobuf.BytesValue",
	types.DoubleKind: "google.protobuf.DoubleValue",
	types.IntKind:    "google.protobuf.Int64Value",
	types.StringKind: "google.protobuf.StringValue",
	types.UintKind:   "google.protobuf.UInt64Value",

	// Dynamic / JSON-style values.
	types.DynKind:      "google.protobuf.Value",
	types.ListKind:     "google.protobuf.ListValue",
	types.MapKind:      "google.protobuf.Struct",
	types.NullTypeKind: "google.protobuf.NullValue",

	// Other well-known messages.
	types.AnyKind:       "google.protobuf.Any",
	types.DurationKind:  "google.protobuf.Duration",
	types.TimestampKind: "google.protobuf.Timestamp",
}