// Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package ext import ( "fmt" "github.com/google/cel-go/cel" "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/traits" "github.com/google/cel-go/parser" ) const ( mapInsert = "cel.@mapInsert" mapInsertOverloadMap = "@mapInsert_map_map" mapInsertOverloadKeyValue = "@mapInsert_map_key_value" ) // TwoVarComprehensions introduces support for two-variable comprehensions. // // The two-variable form of comprehensions looks similar to the one-variable counterparts. // Where possible, the same macro names were used and additional macro signatures added. // The notable distinction for two-variable comprehensions is the introduction of // `transformList`, `transformMap`, and `transformMapEntry` support for list and map types // rather than the more traditional `map` and `filter` macros. // // # All // // Comprehension which tests whether all elements in the list or map satisfy a given // predicate. The `all` macro evaluates in a manner consistent with logical AND and will // short-circuit when encountering a `false` value. // // .all(indexVar, valueVar, ) -> bool // .all(keyVar, valueVar, ) -> bool // // Examples: // // [1, 2, 3].all(i, j, i < j) // returns true // {'hello': 'world', 'taco': 'taco'}.all(k, v, k != v) // returns false // // // Combines two-variable comprehension with single variable // {'h': ['hello', 'hi'], 'j': ['joke', 'jog']} // .all(k, vals, vals.all(v, v.startsWith(k))) // returns true // // # Exists // // Comprehension which tests whether any element in a list or map exists which satisfies // a given predicate. The `exists` macro evaluates in a manner consistent with logical OR // and will short-circuit when encountering a `true` value. // // .exists(indexVar, valueVar, ) -> bool // .exists(keyVar, valueVar, ) -> bool // // Examples: // // {'greeting': 'hello', 'farewell': 'goodbye'} // .exists(k, v, k.startsWith('good') || v.endsWith('bye')) // returns true // [1, 2, 4, 8, 16].exists(i, v, v == 1024 && i == 10) // returns false // // # ExistsOne // // Comprehension which tests whether exactly one element in a list or map exists which // satisfies a given predicate expression. This comprehension does not short-circuit in // keeping with the one-variable exists one macro semantics. // // .existsOne(indexVar, valueVar, ) // .existsOne(keyVar, valueVar, ) // // This macro may also be used with the `exists_one` function name, for compatibility // with the one-variable macro of the same name. // // Examples: // // [1, 2, 1, 3, 1, 4].existsOne(i, v, i == 1 || v == 1) // returns false // [1, 1, 2, 2, 3, 3].existsOne(i, v, i == 2 && v == 2) // returns true // {'i': 0, 'j': 1, 'k': 2}.existsOne(i, v, i == 'l' || v == 1) // returns true // // # TransformList // // Comprehension which converts a map or a list into a list value. The output expression // of the comprehension determines the contents of the output list. Elements in the list // may optionally be filtered according to a predicate expression, where elements that // satisfy the predicate are transformed. // // .transformList(indexVar, valueVar, ) // .transformList(indexVar, valueVar, , ) // .transformList(keyVar, valueVar, ) // .transformList(keyVar, valueVar, , ) // // Examples: // // [1, 2, 3].transformList(indexVar, valueVar, // (indexVar * valueVar) + valueVar) // returns [1, 4, 9] // [1, 2, 3].transformList(indexVar, valueVar, indexVar % 2 == 0 // (indexVar * valueVar) + valueVar) // returns [1, 9] // {'greeting': 'hello', 'farewell': 'goodbye'} // .transformList(k, _, k) // returns ['greeting', 'farewell'] // {'greeting': 'hello', 'farewell': 'goodbye'} // .transformList(_, v, v) // returns ['hello', 'goodbye'] // // # TransformMap // // Comprehension which converts a map or a list into a map value. The output expression // of the comprehension determines the value of the output map entry; however, the key // remains fixed. Elements in the map may optionally be filtered according to a predicate // expression, where elements that satisfy the predicate are transformed. // // .transformMap(indexVar, valueVar, ) // .transformMap(indexVar, valueVar, , ) // .transformMap(keyVar, valueVar, ) // .transformMap(keyVar, valueVar, , ) // // Examples: // // [1, 2, 3].transformMap(indexVar, valueVar, // (indexVar * valueVar) + valueVar) // returns {0: 1, 1: 4, 2: 9} // [1, 2, 3].transformMap(indexVar, valueVar, indexVar % 2 == 0 // (indexVar * valueVar) + valueVar) // returns {0: 1, 2: 9} // {'greeting': 'hello'}.transformMap(k, v, v + '!') // returns {'greeting': 'hello!'} // // # TransformMapEntry // // Comprehension which converts a map or a list into a map value; however, this transform // expects the entry expression be a map literal. If the tranform produces an entry which // duplicates a key in the target map, the comprehension will error. Note, that key // equality is determined using CEL equality which asserts that numeric values which are // equal, even if they don't have the same type will cause a key collision. // // Elements in the map may optionally be filtered according to a predicate expression, where // elements that satisfy the predicate are transformed. // // .transformMap(indexVar, valueVar, ) // .transformMap(indexVar, valueVar, , ) // .transformMap(keyVar, valueVar, ) // .transformMap(keyVar, valueVar, , ) // // Examples: // // // returns {'hello': 'greeting'} // {'greeting': 'hello'}.transformMapEntry(keyVar, valueVar, {valueVar: keyVar}) // // reverse lookup, require all values in list be unique // [1, 2, 3].transformMapEntry(indexVar, valueVar, {valueVar: indexVar}) // // {'greeting': 'aloha', 'farewell': 'aloha'} // .transformMapEntry(keyVar, valueVar, {valueVar: keyVar}) // error, duplicate key func TwoVarComprehensions() cel.EnvOption { return cel.Lib(compreV2Lib{}) } type compreV2Lib struct{} // LibraryName implements that SingletonLibrary interface method. func (compreV2Lib) LibraryName() string { return "cel.lib.ext.comprev2" } // CompileOptions implements the cel.Library interface method. func (compreV2Lib) CompileOptions() []cel.EnvOption { kType := cel.TypeParamType("K") vType := cel.TypeParamType("V") mapKVType := cel.MapType(kType, vType) opts := []cel.EnvOption{ cel.Macros( cel.ReceiverMacro("all", 3, quantifierAll), cel.ReceiverMacro("exists", 3, quantifierExists), cel.ReceiverMacro("existsOne", 3, quantifierExistsOne), cel.ReceiverMacro("exists_one", 3, quantifierExistsOne), cel.ReceiverMacro("transformList", 3, transformList), cel.ReceiverMacro("transformList", 4, transformList), cel.ReceiverMacro("transformMap", 3, transformMap), cel.ReceiverMacro("transformMap", 4, transformMap), cel.ReceiverMacro("transformMapEntry", 3, transformMapEntry), cel.ReceiverMacro("transformMapEntry", 4, transformMapEntry), ), cel.Function(mapInsert, cel.Overload(mapInsertOverloadKeyValue, []*cel.Type{mapKVType, kType, vType}, mapKVType, cel.FunctionBinding(func(args ...ref.Val) ref.Val { m := args[0].(traits.Mapper) k := args[1] v := args[2] return types.InsertMapKeyValue(m, k, v) })), cel.Overload(mapInsertOverloadMap, []*cel.Type{mapKVType, mapKVType}, mapKVType, cel.BinaryBinding(func(targetMap, updateMap ref.Val) ref.Val { tm := targetMap.(traits.Mapper) um := updateMap.(traits.Mapper) umIt := um.Iterator() for umIt.HasNext() == types.True { k := umIt.Next() updateOrErr := types.InsertMapKeyValue(tm, k, um.Get(k)) if types.IsError(updateOrErr) { return updateOrErr } tm = updateOrErr.(traits.Mapper) } return tm })), ), } return opts } // ProgramOptions implements the cel.Library interface method func (compreV2Lib) ProgramOptions() []cel.ProgramOption { return []cel.ProgramOption{} } func quantifierAll(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) { iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1]) if err != nil { return nil, err } return mef.NewComprehensionTwoVar( target, iterVar1, iterVar2, parser.AccumulatorName, /*accuInit=*/ mef.NewLiteral(types.True), /*condition=*/ mef.NewCall(operators.NotStrictlyFalse, mef.NewAccuIdent()), /*step=*/ mef.NewCall(operators.LogicalAnd, mef.NewAccuIdent(), args[2]), /*result=*/ mef.NewAccuIdent(), ), nil } func quantifierExists(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) { iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1]) if err != nil { return nil, err } return mef.NewComprehensionTwoVar( target, iterVar1, iterVar2, parser.AccumulatorName, /*accuInit=*/ mef.NewLiteral(types.False), /*condition=*/ mef.NewCall(operators.NotStrictlyFalse, mef.NewCall(operators.LogicalNot, mef.NewAccuIdent())), /*step=*/ mef.NewCall(operators.LogicalOr, mef.NewAccuIdent(), args[2]), /*result=*/ mef.NewAccuIdent(), ), nil } func quantifierExistsOne(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) { iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1]) if err != nil { return nil, err } return mef.NewComprehensionTwoVar( target, iterVar1, iterVar2, parser.AccumulatorName, /*accuInit=*/ mef.NewLiteral(types.Int(0)), /*condition=*/ mef.NewLiteral(types.True), /*step=*/ mef.NewCall(operators.Conditional, args[2], mef.NewCall(operators.Add, mef.NewAccuIdent(), mef.NewLiteral(types.Int(1))), mef.NewAccuIdent()), /*result=*/ mef.NewCall(operators.Equals, mef.NewAccuIdent(), mef.NewLiteral(types.Int(1))), ), nil } func transformList(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) { iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1]) if err != nil { return nil, err } var transform ast.Expr var filter ast.Expr if len(args) == 4 { filter = args[2] transform = args[3] } else { filter = nil transform = args[2] } // __result__ = __result__ + [transform] step := mef.NewCall(operators.Add, mef.NewAccuIdent(), mef.NewList(transform)) if filter != nil { // __result__ = (filter) ? __result__ + [transform] : __result__ step = mef.NewCall(operators.Conditional, filter, step, mef.NewAccuIdent()) } return mef.NewComprehensionTwoVar( target, iterVar1, iterVar2, parser.AccumulatorName, /*accuInit=*/ mef.NewList(), /*condition=*/ mef.NewLiteral(types.True), step, /*result=*/ mef.NewAccuIdent(), ), nil } func transformMap(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) { iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1]) if err != nil { return nil, err } var transform ast.Expr var filter ast.Expr if len(args) == 4 { filter = args[2] transform = args[3] } else { filter = nil transform = args[2] } // __result__ = cel.@mapInsert(__result__, iterVar1, transform) step := mef.NewCall(mapInsert, mef.NewAccuIdent(), mef.NewIdent(iterVar1), transform) if filter != nil { // __result__ = (filter) ? cel.@mapInsert(__result__, iterVar1, transform) : __result__ step = mef.NewCall(operators.Conditional, filter, step, mef.NewAccuIdent()) } return mef.NewComprehensionTwoVar( target, iterVar1, iterVar2, parser.AccumulatorName, /*accuInit=*/ mef.NewMap(), /*condition=*/ mef.NewLiteral(types.True), step, /*result=*/ mef.NewAccuIdent(), ), nil } func transformMapEntry(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) { iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1]) if err != nil { return nil, err } var transform ast.Expr var filter ast.Expr if len(args) == 4 { filter = args[2] transform = args[3] } else { filter = nil transform = args[2] } // __result__ = cel.@mapInsert(__result__, transform) step := mef.NewCall(mapInsert, mef.NewAccuIdent(), transform) if filter != nil { // __result__ = (filter) ? cel.@mapInsert(__result__, transform) : __result__ step = mef.NewCall(operators.Conditional, filter, step, mef.NewAccuIdent()) } return mef.NewComprehensionTwoVar( target, iterVar1, iterVar2, parser.AccumulatorName, /*accuInit=*/ mef.NewMap(), /*condition=*/ mef.NewLiteral(types.True), step, /*result=*/ mef.NewAccuIdent(), ), nil } func extractIterVars(mef cel.MacroExprFactory, arg0, arg1 ast.Expr) (string, string, *cel.Error) { iterVar1, err := extractIterVar(mef, arg0) if err != nil { return "", "", err } iterVar2, err := extractIterVar(mef, arg1) if err != nil { return "", "", err } if iterVar1 == iterVar2 { return "", "", mef.NewError(arg1.ID(), fmt.Sprintf("duplicate variable name: %s", iterVar1)) } if iterVar1 == parser.AccumulatorName { return "", "", mef.NewError(arg0.ID(), "iteration variable overwrites accumulator variable") } if iterVar2 == parser.AccumulatorName { return "", "", mef.NewError(arg1.ID(), "iteration variable overwrites accumulator variable") } return iterVar1, iterVar2, nil } func extractIterVar(mef cel.MacroExprFactory, target ast.Expr) (string, *cel.Error) { iterVar, found := extractIdent(target) if !found { return "", mef.NewError(target.ID(), "argument must be a simple name") } return iterVar, nil }