1
0
mirror of https://github.com/ceph/ceph-csi.git synced 2025-01-22 20:59:30 +00:00
ceph-csi/vendor/github.com/google/cel-go/cel/validator.go

376 lines
12 KiB
Go
Raw Normal View History

// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cel
import (
"fmt"
"reflect"
"regexp"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/overloads"
)
const (
homogeneousValidatorName = "cel.lib.std.validate.types.homogeneous"
// HomogeneousAggregateLiteralExemptFunctions is the ValidatorConfig key used to configure
// the set of function names which are exempt from homogeneous type checks. The expected type
// is a string list of function names.
//
// As an example, the `<string>.format([args])` call expects the input arguments list to be
// comprised of a variety of types which correspond to the types expected by the format control
// clauses; however, all other uses of a mixed element type list, would be unexpected.
HomogeneousAggregateLiteralExemptFunctions = homogeneousValidatorName + ".exempt"
)
// ASTValidators configures a set of ASTValidator instances into the target environment.
//
// Validators are applied in the order in which the are specified and are treated as singletons.
// The same ASTValidator with a given name will not be applied more than once.
func ASTValidators(validators ...ASTValidator) EnvOption {
return func(e *Env) (*Env, error) {
for _, v := range validators {
if !e.HasValidator(v.Name()) {
e.validators = append(e.validators, v)
}
}
return e, nil
}
}
// ASTValidator defines a singleton interface for validating a type-checked Ast against an environment.
//
// Note: the Issues argument is mutable in the sense that it is intended to collect errors which will be
// reported to the caller.
type ASTValidator interface {
// Name returns the name of the validator. Names must be unique.
Name() string
// Validate validates a given Ast within an Environment and collects a set of potential issues.
//
// The ValidatorConfig is generated from the set of ASTValidatorConfigurer instances prior to
// the invocation of the Validate call. The expectation is that the validator configuration
// is created in sequence and immutable once provided to the Validate call.
//
// See individual validators for more information on their configuration keys and configuration
// properties.
Validate(*Env, ValidatorConfig, *ast.AST, *Issues)
}
// ValidatorConfig provides an accessor method for querying validator configuration state.
type ValidatorConfig interface {
GetOrDefault(name string, value any) any
}
// MutableValidatorConfig provides mutation methods for querying and updating validator configuration
// settings.
type MutableValidatorConfig interface {
ValidatorConfig
Set(name string, value any) error
}
// ASTValidatorConfigurer indicates that this object, currently expected to be an ASTValidator,
// participates in validator configuration settings.
//
// This interface may be split from the expectation of being an ASTValidator instance in the future.
type ASTValidatorConfigurer interface {
Configure(MutableValidatorConfig) error
}
// validatorConfig implements the ValidatorConfig and MutableValidatorConfig interfaces.
type validatorConfig struct {
data map[string]any
}
// newValidatorConfig initializes the validator config with default values for core CEL validators.
func newValidatorConfig() *validatorConfig {
return &validatorConfig{
data: map[string]any{
HomogeneousAggregateLiteralExemptFunctions: []string{},
},
}
}
// GetOrDefault returns the configured value for the name, if present, else the input default value.
//
// Note, the type-agreement between the input default and configured value is not checked on read.
func (config *validatorConfig) GetOrDefault(name string, value any) any {
v, found := config.data[name]
if !found {
return value
}
return v
}
// Set configures a validator option with the given name and value.
//
// If the value had previously been set, the new value must have the same reflection type as the old one,
// or the call will error.
func (config *validatorConfig) Set(name string, value any) error {
v, found := config.data[name]
if found && reflect.TypeOf(v) != reflect.TypeOf(value) {
return fmt.Errorf("incompatible configuration type for %s, got %T, wanted %T", name, value, v)
}
config.data[name] = value
return nil
}
// ExtendedValidations collects a set of common AST validations which reduce the likelihood of runtime errors.
//
// - Validate duration and timestamp literals
// - Ensure regex strings are valid
// - Disable mixed type list and map literals
func ExtendedValidations() EnvOption {
return ASTValidators(
ValidateDurationLiterals(),
ValidateTimestampLiterals(),
ValidateRegexLiterals(),
ValidateHomogeneousAggregateLiterals(),
)
}
// ValidateDurationLiterals ensures that duration literal arguments are valid immediately after type-check.
func ValidateDurationLiterals() ASTValidator {
return newFormatValidator(overloads.TypeConvertDuration, 0, evalCall)
}
// ValidateTimestampLiterals ensures that timestamp literal arguments are valid immediately after type-check.
func ValidateTimestampLiterals() ASTValidator {
return newFormatValidator(overloads.TypeConvertTimestamp, 0, evalCall)
}
// ValidateRegexLiterals ensures that regex patterns are validated after type-check.
func ValidateRegexLiterals() ASTValidator {
return newFormatValidator(overloads.Matches, 0, compileRegex)
}
// ValidateHomogeneousAggregateLiterals checks that all list and map literals entries have the same types, i.e.
// no mixed list element types or mixed map key or map value types.
//
// Note: the string format call relies on a mixed element type list for ease of use, so this check skips all
// literals which occur within string format calls.
func ValidateHomogeneousAggregateLiterals() ASTValidator {
return homogeneousAggregateLiteralValidator{}
}
// ValidateComprehensionNestingLimit ensures that comprehension nesting does not exceed the specified limit.
//
// This validator can be useful for preventing arbitrarily nested comprehensions which can take high polynomial
// time to complete.
//
// Note, this limit does not apply to comprehensions with an empty iteration range, as these comprehensions have
// no actual looping cost. The cel.bind() utilizes the comprehension structure to perform local variable
// assignments and supplies an empty iteration range, so they won't count against the nesting limit either.
func ValidateComprehensionNestingLimit(limit int) ASTValidator {
return nestingLimitValidator{limit: limit}
}
type argChecker func(env *Env, call, arg ast.Expr) error
func newFormatValidator(funcName string, argNum int, check argChecker) formatValidator {
return formatValidator{
funcName: funcName,
check: check,
argNum: argNum,
}
}
type formatValidator struct {
funcName string
argNum int
check argChecker
}
// Name returns the unique name of this function format validator.
func (v formatValidator) Name() string {
return fmt.Sprintf("cel.lib.std.validate.functions.%s", v.funcName)
}
// Validate searches the AST for uses of a given function name with a constant argument and performs a check
// on whether the argument is a valid literal value.
func (v formatValidator) Validate(e *Env, _ ValidatorConfig, a *ast.AST, iss *Issues) {
root := ast.NavigateAST(a)
funcCalls := ast.MatchDescendants(root, ast.FunctionMatcher(v.funcName))
for _, call := range funcCalls {
callArgs := call.AsCall().Args()
if len(callArgs) <= v.argNum {
continue
}
litArg := callArgs[v.argNum]
if litArg.Kind() != ast.LiteralKind {
continue
}
if err := v.check(e, call, litArg); err != nil {
iss.ReportErrorAtID(litArg.ID(), "invalid %s argument", v.funcName)
}
}
}
func evalCall(env *Env, call, arg ast.Expr) error {
ast := &Ast{impl: ast.NewAST(call, ast.NewSourceInfo(nil))}
prg, err := env.Program(ast)
if err != nil {
return err
}
_, _, err = prg.Eval(NoVars())
return err
}
func compileRegex(_ *Env, _, arg ast.Expr) error {
pattern := arg.AsLiteral().Value().(string)
_, err := regexp.Compile(pattern)
return err
}
type homogeneousAggregateLiteralValidator struct{}
// Name returns the unique name of the homogeneous type validator.
func (homogeneousAggregateLiteralValidator) Name() string {
return homogeneousValidatorName
}
// Validate validates that all lists and map literals have homogeneous types, i.e. don't contain dyn types.
//
// This validator makes an exception for list and map literals which occur at any level of nesting within
// string format calls.
func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig, a *ast.AST, iss *Issues) {
var exemptedFunctions []string
exemptedFunctions = c.GetOrDefault(HomogeneousAggregateLiteralExemptFunctions, exemptedFunctions).([]string)
root := ast.NavigateAST(a)
listExprs := ast.MatchDescendants(root, ast.KindMatcher(ast.ListKind))
for _, listExpr := range listExprs {
if inExemptFunction(listExpr, exemptedFunctions) {
continue
}
l := listExpr.AsList()
elements := l.Elements()
optIndices := l.OptionalIndices()
var elemType *Type
for i, e := range elements {
et := a.GetType(e.ID())
if isOptionalIndex(i, optIndices) {
et = et.Parameters()[0]
}
if elemType == nil {
elemType = et
continue
}
if !elemType.IsEquivalentType(et) {
v.typeMismatch(iss, e.ID(), elemType, et)
break
}
}
}
mapExprs := ast.MatchDescendants(root, ast.KindMatcher(ast.MapKind))
for _, mapExpr := range mapExprs {
if inExemptFunction(mapExpr, exemptedFunctions) {
continue
}
m := mapExpr.AsMap()
entries := m.Entries()
var keyType, valType *Type
for _, e := range entries {
mapEntry := e.AsMapEntry()
key, val := mapEntry.Key(), mapEntry.Value()
kt, vt := a.GetType(key.ID()), a.GetType(val.ID())
if mapEntry.IsOptional() {
vt = vt.Parameters()[0]
}
if keyType == nil && valType == nil {
keyType, valType = kt, vt
continue
}
if !keyType.IsEquivalentType(kt) {
v.typeMismatch(iss, key.ID(), keyType, kt)
}
if !valType.IsEquivalentType(vt) {
v.typeMismatch(iss, val.ID(), valType, vt)
}
}
}
}
func inExemptFunction(e ast.NavigableExpr, exemptFunctions []string) bool {
parent, found := e.Parent()
for found {
if parent.Kind() == ast.CallKind {
fnName := parent.AsCall().FunctionName()
for _, exempt := range exemptFunctions {
if exempt == fnName {
return true
}
}
}
parent, found = parent.Parent()
}
return false
}
func isOptionalIndex(i int, optIndices []int32) bool {
for _, optInd := range optIndices {
if i == int(optInd) {
return true
}
}
return false
}
func (homogeneousAggregateLiteralValidator) typeMismatch(iss *Issues, id int64, expected, actual *Type) {
iss.ReportErrorAtID(id, "expected type '%s' but found '%s'", FormatCELType(expected), FormatCELType(actual))
}
type nestingLimitValidator struct {
limit int
}
func (v nestingLimitValidator) Name() string {
return "cel.lib.std.validate.comprehension_nesting_limit"
}
func (v nestingLimitValidator) Validate(e *Env, _ ValidatorConfig, a *ast.AST, iss *Issues) {
root := ast.NavigateAST(a)
comprehensions := ast.MatchDescendants(root, ast.KindMatcher(ast.ComprehensionKind))
if len(comprehensions) <= v.limit {
return
}
for _, comp := range comprehensions {
count := 0
e := comp
hasParent := true
for hasParent {
// When the expression is not a comprehension, continue to the next ancestor.
if e.Kind() != ast.ComprehensionKind {
e, hasParent = e.Parent()
continue
}
// When the comprehension has an empty range, continue to the next ancestor
// as this comprehension does not have any associated cost.
iterRange := e.AsComprehension().IterRange()
if iterRange.Kind() == ast.ListKind && iterRange.AsList().Size() == 0 {
e, hasParent = e.Parent()
continue
}
// Otherwise check the nesting limit.
count++
if count > v.limit {
iss.ReportErrorAtID(comp.ID(), "comprehension exceeds nesting limit")
break
}
e, hasParent = e.Parent()
}
}
}