ceph-csi/vendor/github.com/google/cel-go/common/ast/conversion.go

633 lines
17 KiB
Go
Raw Permalink Normal View History

// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ast
import (
"fmt"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
structpb "google.golang.org/protobuf/types/known/structpb"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// ToProto converts an AST to a CheckedExpr protobuf.
//
// The reference map, the type map, the root expression, and the source info are
// converted in that order. The first conversion error is returned unchanged and
// no CheckedExpr is produced.
func ToProto(ast *AST) (*exprpb.CheckedExpr, error) {
	astRefs := ast.ReferenceMap()
	pbRefs := make(map[int64]*exprpb.Reference, len(astRefs))
	for exprID, refInfo := range astRefs {
		pbRef, err := ReferenceInfoToProto(refInfo)
		if err != nil {
			return nil, err
		}
		pbRefs[exprID] = pbRef
	}

	astTypes := ast.TypeMap()
	pbTypes := make(map[int64]*exprpb.Type, len(astTypes))
	for exprID, celType := range astTypes {
		pbType, err := types.TypeToExprType(celType)
		if err != nil {
			return nil, err
		}
		pbTypes[exprID] = pbType
	}

	pbExpr, err := ExprToProto(ast.Expr())
	if err != nil {
		return nil, err
	}
	pbInfo, err := SourceInfoToProto(ast.SourceInfo())
	if err != nil {
		return nil, err
	}

	checked := &exprpb.CheckedExpr{
		Expr:         pbExpr,
		SourceInfo:   pbInfo,
		ReferenceMap: pbRefs,
		TypeMap:      pbTypes,
	}
	return checked, nil
}
// ToAST rebuilds a native AST from a CheckedExpr protobuf.
//
// The reference map and type map are converted first, followed by the source
// info and then the root expression. The first conversion error is returned
// unchanged and no AST is produced.
func ToAST(checked *exprpb.CheckedExpr) (*AST, error) {
	pbRefs := checked.GetReferenceMap()
	refs := make(map[int64]*ReferenceInfo, len(pbRefs))
	for exprID, pbRef := range pbRefs {
		refInfo, err := ProtoToReferenceInfo(pbRef)
		if err != nil {
			return nil, err
		}
		refs[exprID] = refInfo
	}

	pbTypes := checked.GetTypeMap()
	celTypes := make(map[int64]*types.Type, len(pbTypes))
	for exprID, pbType := range pbTypes {
		celType, err := types.ExprTypeToType(pbType)
		if err != nil {
			return nil, err
		}
		celTypes[exprID] = celType
	}

	srcInfo, err := ProtoToSourceInfo(checked.GetSourceInfo())
	if err != nil {
		return nil, err
	}
	rootExpr, err := ProtoToExpr(checked.GetExpr())
	if err != nil {
		return nil, err
	}
	return NewCheckedAST(NewAST(rootExpr, srcInfo), celTypes, refs), nil
}
// ProtoToExpr converts a protobuf Expr value to an ast.Expr value.
//
// A fresh ExprFactory is created for each call and used to build every node of
// the resulting expression tree.
func ProtoToExpr(e *exprpb.Expr) (Expr, error) {
	return exprInternal(NewExprFactory(), e)
}
// ProtoToEntryExpr converts a protobuf struct/map entry to an ast.EntryExpr.
//
// Entries keyed by a field name become struct fields and entries keyed by an
// expression become map entries. An entry with any other key kind yields an
// error.
func ProtoToEntryExpr(e *exprpb.Expr_CreateStruct_Entry) (EntryExpr, error) {
	exprFactory := NewExprFactory()
	switch e.GetKeyKind().(type) {
	case *exprpb.Expr_CreateStruct_Entry_MapKey:
		return exprMapEntry(exprFactory, e.GetId(), e)
	case *exprpb.Expr_CreateStruct_Entry_FieldKey:
		return exprStructField(exprFactory, e.GetId(), e)
	default:
		return nil, fmt.Errorf("unsupported expr entry kind: %v", e)
	}
}
// exprInternal converts a protobuf Expr into a native Expr using the supplied
// factory, dispatching on the expression kind set in the protobuf.
//
// A struct-creation expression with an empty message name is treated as a map
// literal; one with a message name is treated as a struct literal. When no
// recognized kind is set, an unspecified expression carrying only the id is
// returned.
func exprInternal(factory ExprFactory, e *exprpb.Expr) (Expr, error) {
	exprID := e.GetId()
	switch kind := e.GetExprKind().(type) {
	case *exprpb.Expr_CallExpr:
		return exprCall(factory, exprID, kind.CallExpr)
	case *exprpb.Expr_ComprehensionExpr:
		return exprComprehension(factory, exprID, kind.ComprehensionExpr)
	case *exprpb.Expr_ConstExpr:
		return exprLiteral(factory, exprID, kind.ConstExpr)
	case *exprpb.Expr_IdentExpr:
		return exprIdent(factory, exprID, kind.IdentExpr)
	case *exprpb.Expr_ListExpr:
		return exprList(factory, exprID, kind.ListExpr)
	case *exprpb.Expr_SelectExpr:
		return exprSelect(factory, exprID, kind.SelectExpr)
	case *exprpb.Expr_StructExpr:
		create := kind.StructExpr
		if create.GetMessageName() == "" {
			return exprMap(factory, exprID, create)
		}
		return exprStruct(factory, exprID, create)
	default:
		return factory.NewUnspecifiedExpr(exprID), nil
	}
}
// exprCall converts a protobuf call into a native call expression.
//
// Arguments are converted first, in order, and then the target if one is set.
// A call with a target becomes a member call; otherwise it becomes a global
// call.
func exprCall(factory ExprFactory, id int64, call *exprpb.Expr_Call) (Expr, error) {
	pbArgs := call.GetArgs()
	callArgs := make([]Expr, len(pbArgs))
	for i := range pbArgs {
		arg, err := exprInternal(factory, pbArgs[i])
		if err != nil {
			return nil, err
		}
		callArgs[i] = arg
	}

	fnName := call.GetFunction()
	if pbTarget := call.GetTarget(); pbTarget != nil {
		receiver, err := exprInternal(factory, pbTarget)
		if err != nil {
			return nil, err
		}
		return factory.NewMemberCall(id, fnName, receiver, callArgs...), nil
	}
	return factory.NewCall(id, fnName, callArgs...), nil
}
// exprComprehension converts a protobuf comprehension into a native
// comprehension expression.
//
// The sub-expressions are converted in the order: iteration range, accumulator
// initializer, loop condition, loop step, result. The first failure aborts the
// conversion.
func exprComprehension(factory ExprFactory, id int64, comp *exprpb.Expr_Comprehension) (Expr, error) {
	// Positions of each sub-expression in the conversion tables below.
	const (
		iterRangeIdx = iota
		accuInitIdx
		loopCondIdx
		loopStepIdx
		resultIdx
	)
	pbParts := [...]*exprpb.Expr{
		iterRangeIdx: comp.GetIterRange(),
		accuInitIdx:  comp.GetAccuInit(),
		loopCondIdx:  comp.GetLoopCondition(),
		loopStepIdx:  comp.GetLoopStep(),
		resultIdx:    comp.GetResult(),
	}
	parts := make([]Expr, len(pbParts))
	for i := range pbParts {
		part, err := exprInternal(factory, pbParts[i])
		if err != nil {
			return nil, err
		}
		parts[i] = part
	}
	return factory.NewComprehension(
		id,
		parts[iterRangeIdx],
		comp.GetIterVar(),
		comp.GetAccuVar(),
		parts[accuInitIdx],
		parts[loopCondIdx],
		parts[loopStepIdx],
		parts[resultIdx],
	), nil
}
// exprLiteral converts a protobuf Constant into a native literal expression,
// returning the error from ConstantToVal if the constant cannot be converted.
func exprLiteral(factory ExprFactory, id int64, c *exprpb.Constant) (Expr, error) {
	literal, convErr := ConstantToVal(c)
	if convErr != nil {
		return nil, convErr
	}
	lit := factory.NewLiteral(id, literal)
	return lit, nil
}
// exprIdent converts a protobuf identifier into a native identifier expression.
// It never returns an error.
func exprIdent(factory ExprFactory, id int64, i *exprpb.Expr_Ident) (Expr, error) {
	identName := i.GetName()
	return factory.NewIdent(id, identName), nil
}
// exprList converts a protobuf list-creation expression into a native list
// expression, converting each element in order and carrying over the optional
// indices unchanged.
func exprList(factory ExprFactory, id int64, l *exprpb.Expr_CreateList) (Expr, error) {
	pbElems := l.GetElements()
	converted := make([]Expr, len(pbElems))
	for idx := range pbElems {
		var err error
		if converted[idx], err = exprInternal(factory, pbElems[idx]); err != nil {
			return nil, err
		}
	}
	return factory.NewList(id, converted, l.GetOptionalIndices()), nil
}
// exprMap converts a protobuf struct-creation expression into a native map
// expression, converting every entry as a map entry in order.
func exprMap(factory ExprFactory, id int64, s *exprpb.Expr_CreateStruct) (Expr, error) {
	pbEntries := s.GetEntries()
	mapEntries := make([]EntryExpr, len(pbEntries))
	for idx := range pbEntries {
		pbEntry := pbEntries[idx]
		converted, err := exprMapEntry(factory, pbEntry.GetId(), pbEntry)
		if err != nil {
			return nil, err
		}
		mapEntries[idx] = converted
	}
	return factory.NewMap(id, mapEntries), nil
}
// exprMapEntry converts a protobuf entry into a native map entry. The key
// expression is converted before the value expression, and the optional flag
// is carried over unchanged.
func exprMapEntry(factory ExprFactory, id int64, e *exprpb.Expr_CreateStruct_Entry) (EntryExpr, error) {
	keyExpr, err := exprInternal(factory, e.GetMapKey())
	if err != nil {
		return nil, err
	}
	valueExpr, err := exprInternal(factory, e.GetValue())
	if err != nil {
		return nil, err
	}
	optional := e.GetOptionalEntry()
	return factory.NewMapEntry(id, keyExpr, valueExpr, optional), nil
}
// exprSelect converts a protobuf select into either a presence test (when the
// test-only flag is set) or a plain field selection on the converted operand.
func exprSelect(factory ExprFactory, id int64, s *exprpb.Expr_Select) (Expr, error) {
	operand, err := exprInternal(factory, s.GetOperand())
	if err != nil {
		return nil, err
	}
	if !s.GetTestOnly() {
		return factory.NewSelect(id, operand, s.GetField()), nil
	}
	return factory.NewPresenceTest(id, operand, s.GetField()), nil
}
// exprStruct converts a protobuf struct-creation expression into a native
// struct expression for the named message, converting every entry as a struct
// field in order.
func exprStruct(factory ExprFactory, id int64, s *exprpb.Expr_CreateStruct) (Expr, error) {
	pbEntries := s.GetEntries()
	structFields := make([]EntryExpr, len(pbEntries))
	for idx := range pbEntries {
		pbField := pbEntries[idx]
		converted, err := exprStructField(factory, pbField.GetId(), pbField)
		if err != nil {
			return nil, err
		}
		structFields[idx] = converted
	}
	return factory.NewStruct(id, s.GetMessageName(), structFields), nil
}
// exprStructField converts a protobuf entry into a native struct field, using
// the entry's field key as the field name and carrying over the optional flag.
func exprStructField(factory ExprFactory, id int64, f *exprpb.Expr_CreateStruct_Entry) (EntryExpr, error) {
	valueExpr, err := exprInternal(factory, f.GetValue())
	if err != nil {
		return nil, err
	}
	fieldName, optional := f.GetFieldKey(), f.GetOptionalEntry()
	return factory.NewStructField(id, fieldName, valueExpr, optional), nil
}
// ExprToProto serializes an ast.Expr value to a protobuf Expr representation.
//
// A nil expression serializes to an empty protobuf Expr. An expression whose
// kind is not handled below yields an error.
func ExprToProto(e Expr) (*exprpb.Expr, error) {
	if e == nil {
		return &exprpb.Expr{}, nil
	}
	exprID := e.ID()
	switch e.Kind() {
	case UnspecifiedExprKind:
		// A nested macro 'pointer' is a non-zero expression id with no kind
		// set, so only the id is emitted. A zero id leaves the message empty.
		return &exprpb.Expr{Id: exprID}, nil
	case CallKind:
		return protoCall(exprID, e.AsCall())
	case ComprehensionKind:
		return protoComprehension(exprID, e.AsComprehension())
	case IdentKind:
		return protoIdent(exprID, e.AsIdent())
	case ListKind:
		return protoList(exprID, e.AsList())
	case LiteralKind:
		return protoLiteral(exprID, e.AsLiteral())
	case MapKind:
		return protoMap(exprID, e.AsMap())
	case SelectKind:
		return protoSelect(exprID, e.AsSelect())
	case StructKind:
		return protoStruct(exprID, e.AsStruct())
	default:
		return nil, fmt.Errorf("unsupported expr kind: %v", e)
	}
}
// EntryExprToProto converts an ast.EntryExpr to a protobuf CreateStruct entry.
//
// A nil entry, like an entry of unspecified kind, serializes to an empty
// protobuf entry; this mirrors how ExprToProto treats a nil expression instead
// of panicking on the method call. Map entries and struct fields are converted
// with their ids preserved. Any other kind yields an error.
func EntryExprToProto(e EntryExpr) (*exprpb.Expr_CreateStruct_Entry, error) {
	// Guard against a nil interface value: calling e.Kind() on it would panic.
	if e == nil {
		return &exprpb.Expr_CreateStruct_Entry{}, nil
	}
	switch e.Kind() {
	case MapEntryKind:
		return protoMapEntry(e.ID(), e.AsMapEntry())
	case StructFieldKind:
		return protoStructField(e.ID(), e.AsStructField())
	case UnspecifiedEntryExprKind:
		return &exprpb.Expr_CreateStruct_Entry{}, nil
	}
	return nil, fmt.Errorf("unsupported expr entry kind: %v", e)
}
// protoCall serializes a native call expression to a protobuf call.
//
// The target is serialized only for member functions, and before the
// arguments. For a global call the protobuf target is left nil.
func protoCall(id int64, call CallExpr) (*exprpb.Expr, error) {
	var pbTarget *exprpb.Expr
	if call.IsMemberFunction() {
		receiver, err := ExprToProto(call.Target())
		if err != nil {
			return nil, err
		}
		pbTarget = receiver
	}

	astArgs := call.Args()
	pbArgs := make([]*exprpb.Expr, len(astArgs))
	for idx := range astArgs {
		pbArg, err := ExprToProto(astArgs[idx])
		if err != nil {
			return nil, err
		}
		pbArgs[idx] = pbArg
	}

	pbCall := &exprpb.Expr_Call{
		Function: call.FunctionName(),
		Target:   pbTarget,
		Args:     pbArgs,
	}
	return &exprpb.Expr{
		Id:       id,
		ExprKind: &exprpb.Expr_CallExpr{CallExpr: pbCall},
	}, nil
}
// protoComprehension serializes a native comprehension expression to a
// protobuf comprehension.
//
// The sub-expressions are serialized in the order: iteration range,
// accumulator initializer, loop condition, loop step, result. The first
// failure aborts the conversion.
func protoComprehension(id int64, comp ComprehensionExpr) (*exprpb.Expr, error) {
	// Positions of each sub-expression in the conversion tables below.
	const (
		iterRangeIdx = iota
		accuInitIdx
		loopCondIdx
		loopStepIdx
		resultIdx
	)
	parts := [...]Expr{
		iterRangeIdx: comp.IterRange(),
		accuInitIdx:  comp.AccuInit(),
		loopCondIdx:  comp.LoopCondition(),
		loopStepIdx:  comp.LoopStep(),
		resultIdx:    comp.Result(),
	}
	pbParts := make([]*exprpb.Expr, len(parts))
	for i := range parts {
		pbPart, err := ExprToProto(parts[i])
		if err != nil {
			return nil, err
		}
		pbParts[i] = pbPart
	}

	pbComp := &exprpb.Expr_Comprehension{
		IterVar:       comp.IterVar(),
		IterRange:     pbParts[iterRangeIdx],
		AccuVar:       comp.AccuVar(),
		AccuInit:      pbParts[accuInitIdx],
		LoopCondition: pbParts[loopCondIdx],
		LoopStep:      pbParts[loopStepIdx],
		Result:        pbParts[resultIdx],
	}
	return &exprpb.Expr{
		Id:       id,
		ExprKind: &exprpb.Expr_ComprehensionExpr{ComprehensionExpr: pbComp},
	}, nil
}
// protoIdent wraps an identifier name in a protobuf identifier expression with
// the given id. It never returns an error.
func protoIdent(id int64, name string) (*exprpb.Expr, error) {
	pbIdent := &exprpb.Expr_Ident{Name: name}
	pbExpr := &exprpb.Expr{
		Id:       id,
		ExprKind: &exprpb.Expr_IdentExpr{IdentExpr: pbIdent},
	}
	return pbExpr, nil
}
// protoList serializes a native list expression to a protobuf list-creation
// expression, serializing each element in order and carrying over the optional
// indices unchanged.
func protoList(id int64, list ListExpr) (*exprpb.Expr, error) {
	// The output slice is sized from Size() and filled by ranging Elements().
	pbElems := make([]*exprpb.Expr, list.Size())
	for idx, elem := range list.Elements() {
		pbElem, err := ExprToProto(elem)
		if err != nil {
			return nil, err
		}
		pbElems[idx] = pbElem
	}

	pbList := &exprpb.Expr_CreateList{
		Elements:        pbElems,
		OptionalIndices: list.OptionalIndices(),
	}
	return &exprpb.Expr{
		Id:       id,
		ExprKind: &exprpb.Expr_ListExpr{ListExpr: pbList},
	}, nil
}
// protoLiteral serializes a literal value to a protobuf constant expression,
// returning the error from ValToConstant if the value cannot be converted.
func protoLiteral(id int64, val ref.Val) (*exprpb.Expr, error) {
	constant, convErr := ValToConstant(val)
	if convErr != nil {
		return nil, convErr
	}
	pbExpr := &exprpb.Expr{
		Id:       id,
		ExprKind: &exprpb.Expr_ConstExpr{ConstExpr: constant},
	}
	return pbExpr, nil
}
// protoMap serializes a native map expression to a protobuf struct-creation
// expression with no message name, serializing each entry in order.
func protoMap(id int64, m MapExpr) (*exprpb.Expr, error) {
	mapEntries := m.Entries()
	pbEntries := make([]*exprpb.Expr_CreateStruct_Entry, len(mapEntries))
	for idx := range mapEntries {
		pbEntry, err := EntryExprToProto(mapEntries[idx])
		if err != nil {
			return nil, err
		}
		pbEntries[idx] = pbEntry
	}

	pbStruct := &exprpb.Expr_CreateStruct{Entries: pbEntries}
	return &exprpb.Expr{
		Id:       id,
		ExprKind: &exprpb.Expr_StructExpr{StructExpr: pbStruct},
	}, nil
}
// protoMapEntry serializes a native map entry to a protobuf entry keyed by an
// expression. The key is serialized before the value, and the optional flag is
// carried over unchanged.
func protoMapEntry(id int64, e MapEntry) (*exprpb.Expr_CreateStruct_Entry, error) {
	pbKey, err := ExprToProto(e.Key())
	if err != nil {
		return nil, err
	}
	pbValue, err := ExprToProto(e.Value())
	if err != nil {
		return nil, err
	}

	pbEntry := &exprpb.Expr_CreateStruct_Entry{
		Id:            id,
		KeyKind:       &exprpb.Expr_CreateStruct_Entry_MapKey{MapKey: pbKey},
		Value:         pbValue,
		OptionalEntry: e.IsOptional(),
	}
	return pbEntry, nil
}
// protoSelect serializes a native select expression to a protobuf select,
// carrying over the field name and the test-only flag alongside the serialized
// operand.
func protoSelect(id int64, s SelectExpr) (*exprpb.Expr, error) {
	pbOperand, err := ExprToProto(s.Operand())
	if err != nil {
		return nil, err
	}

	pbSelect := &exprpb.Expr_Select{
		Operand:  pbOperand,
		Field:    s.FieldName(),
		TestOnly: s.IsTestOnly(),
	}
	return &exprpb.Expr{
		Id:       id,
		ExprKind: &exprpb.Expr_SelectExpr{SelectExpr: pbSelect},
	}, nil
}
// protoStruct serializes a native struct expression to a protobuf
// struct-creation expression, using the struct's type name as the message name
// and serializing each field in order.
func protoStruct(id int64, s StructExpr) (*exprpb.Expr, error) {
	structFields := s.Fields()
	pbEntries := make([]*exprpb.Expr_CreateStruct_Entry, len(structFields))
	for idx := range structFields {
		pbEntry, err := EntryExprToProto(structFields[idx])
		if err != nil {
			return nil, err
		}
		pbEntries[idx] = pbEntry
	}

	pbStruct := &exprpb.Expr_CreateStruct{
		MessageName: s.TypeName(),
		Entries:     pbEntries,
	}
	return &exprpb.Expr{
		Id:       id,
		ExprKind: &exprpb.Expr_StructExpr{StructExpr: pbStruct},
	}, nil
}
// protoStructField serializes a native struct field to a protobuf entry keyed
// by the field name, carrying over the optional flag alongside the serialized
// value.
func protoStructField(id int64, f StructField) (*exprpb.Expr_CreateStruct_Entry, error) {
	pbValue, err := ExprToProto(f.Value())
	if err != nil {
		return nil, err
	}

	pbEntry := &exprpb.Expr_CreateStruct_Entry{
		Id:            id,
		KeyKind:       &exprpb.Expr_CreateStruct_Entry_FieldKey{FieldKey: f.Name()},
		Value:         pbValue,
		OptionalEntry: f.IsOptional(),
	}
	return pbEntry, nil
}
// SourceInfoToProto serializes an ast.SourceInfo value to a protobuf SourceInfo object.
//
// A nil info serializes to an empty protobuf SourceInfo. Only the start of each
// offset range is written to the positions map. Macro calls are serialized
// with ExprToProto, and the first failure aborts the conversion.
func SourceInfoToProto(info *SourceInfo) (*exprpb.SourceInfo, error) {
	if info == nil {
		return &exprpb.SourceInfo{}, nil
	}

	offsetRanges := info.OffsetRanges()
	positions := make(map[int64]int32, len(offsetRanges))
	for exprID, rng := range offsetRanges {
		positions[exprID] = rng.Start
	}

	macros := info.MacroCalls()
	pbMacros := make(map[int64]*exprpb.Expr, len(macros))
	for exprID, macro := range macros {
		pbMacro, err := ExprToProto(macro)
		if err != nil {
			return nil, err
		}
		pbMacros[exprID] = pbMacro
	}

	return &exprpb.SourceInfo{
		SyntaxVersion: info.SyntaxVersion(),
		Location:      info.Description(),
		LineOffsets:   info.LineOffsets(),
		Positions:     positions,
		MacroCalls:    pbMacros,
	}, nil
}
// ProtoToSourceInfo deserializes the protobuf into a native SourceInfo value.
//
// Each stored position becomes an OffsetRange whose Start and Stop are both
// that position. Macro calls are converted with ProtoToExpr, and the first
// failure aborts the conversion.
func ProtoToSourceInfo(info *exprpb.SourceInfo) (*SourceInfo, error) {
	pbPositions := info.GetPositions()
	pbMacros := info.GetMacroCalls()
	native := &SourceInfo{
		syntax:       info.GetSyntaxVersion(),
		desc:         info.GetLocation(),
		lines:        info.GetLineOffsets(),
		offsetRanges: make(map[int64]OffsetRange, len(pbPositions)),
		macroCalls:   make(map[int64]Expr, len(pbMacros)),
	}

	for exprID, pos := range pbPositions {
		native.SetOffsetRange(exprID, OffsetRange{Start: pos, Stop: pos})
	}
	for exprID, pbMacro := range pbMacros {
		macro, err := ProtoToExpr(pbMacro)
		if err != nil {
			return nil, err
		}
		native.SetMacroCall(exprID, macro)
	}
	return native, nil
}
// ReferenceInfoToProto converts a ReferenceInfo instance to a protobuf Reference suitable for serialization.
//
// A nil info serializes to an empty protobuf Reference rather than panicking,
// matching how SourceInfoToProto and ExprToProto treat nil inputs. Otherwise
// the name and overload ids are copied and the value is converted with
// ValToConstant, whose error is returned unchanged.
func ReferenceInfoToProto(info *ReferenceInfo) (*exprpb.Reference, error) {
	// Guard against a nil pointer: reading info.Value below would panic.
	if info == nil {
		return &exprpb.Reference{}, nil
	}
	c, err := ValToConstant(info.Value)
	if err != nil {
		return nil, err
	}
	return &exprpb.Reference{
		Name:       info.Name,
		OverloadId: info.OverloadIDs,
		Value:      c,
	}, nil
}
// ProtoToReferenceInfo converts a protobuf Reference into a CEL-native ReferenceInfo instance.
//
// The name and overload ids are copied from the protobuf and the constant
// value is converted with ConstantToVal, whose error is returned unchanged.
func ProtoToReferenceInfo(ref *exprpb.Reference) (*ReferenceInfo, error) {
	constVal, convErr := ConstantToVal(ref.GetValue())
	if convErr != nil {
		return nil, convErr
	}
	refInfo := &ReferenceInfo{
		Name:        ref.GetName(),
		OverloadIDs: ref.GetOverloadId(),
		Value:       constVal,
	}
	return refInfo, nil
}
// ValToConstant converts a CEL-native ref.Val to a protobuf Constant.
//
// Only simple scalar types are supported by this method: bool, bytes, double,
// int, null, string, and uint. A nil value converts to a nil Constant with no
// error; any other type yields an error.
func ValToConstant(v ref.Val) (*exprpb.Constant, error) {
	if v == nil {
		return nil, nil
	}
	constant := &exprpb.Constant{}
	switch valType := v.Type(); valType {
	case types.BoolType:
		constant.ConstantKind = &exprpb.Constant_BoolValue{BoolValue: v.Value().(bool)}
	case types.BytesType:
		constant.ConstantKind = &exprpb.Constant_BytesValue{BytesValue: v.Value().([]byte)}
	case types.DoubleType:
		constant.ConstantKind = &exprpb.Constant_DoubleValue{DoubleValue: v.Value().(float64)}
	case types.IntType:
		constant.ConstantKind = &exprpb.Constant_Int64Value{Int64Value: v.Value().(int64)}
	case types.NullType:
		constant.ConstantKind = &exprpb.Constant_NullValue{NullValue: structpb.NullValue_NULL_VALUE}
	case types.StringType:
		constant.ConstantKind = &exprpb.Constant_StringValue{StringValue: v.Value().(string)}
	case types.UintType:
		constant.ConstantKind = &exprpb.Constant_Uint64Value{Uint64Value: v.Value().(uint64)}
	default:
		return nil, fmt.Errorf("unsupported constant kind: %v", valType)
	}
	return constant, nil
}
// ConstantToVal converts a protobuf Constant to a CEL-native ref.Val.
//
// A nil Constant converts to a nil value with no error. A Constant whose kind
// is not one of bool, bytes, double, int64, null, string, or uint64 yields an
// error.
func ConstantToVal(c *exprpb.Constant) (ref.Val, error) {
	if c == nil {
		return nil, nil
	}
	switch kind := c.GetConstantKind().(type) {
	case *exprpb.Constant_BoolValue:
		return types.Bool(kind.BoolValue), nil
	case *exprpb.Constant_BytesValue:
		return types.Bytes(kind.BytesValue), nil
	case *exprpb.Constant_DoubleValue:
		return types.Double(kind.DoubleValue), nil
	case *exprpb.Constant_Int64Value:
		return types.Int(kind.Int64Value), nil
	case *exprpb.Constant_NullValue:
		return types.NullValue, nil
	case *exprpb.Constant_StringValue:
		return types.String(kind.StringValue), nil
	case *exprpb.Constant_Uint64Value:
		return types.Uint(kind.Uint64Value), nil
	default:
		return nil, fmt.Errorf("unsupported constant kind: %v", kind)
	}
}