// Copyright 2019 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package parser import ( "errors" "fmt" "strconv" "strings" "github.com/google/cel-go/common/operators" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) // Unparse takes an input expression and source position information and generates a human-readable // expression. // // Note, unparsing an AST will often generate the same expression as was originally parsed, but some // formatting may be lost in translation, notably: // // - All quoted literals are doubled quoted. // - Byte literals are represented as octal escapes (same as Google SQL). // - Floating point values are converted to the small number of digits needed to represent the value. // - Spacing around punctuation marks may be lost. // - Parentheses will only be applied when they affect operator precedence. // // This function optionally takes in one or more UnparserOption to alter the unparsing behavior, such as // performing word wrapping on expressions. func Unparse(expr *exprpb.Expr, info *exprpb.SourceInfo, opts ...UnparserOption) (string, error) { unparserOpts := &unparserOption{ wrapOnColumn: defaultWrapOnColumn, wrapAfterColumnLimit: defaultWrapAfterColumnLimit, operatorsToWrapOn: defaultOperatorsToWrapOn, } var err error for _, opt := range opts { unparserOpts, err = opt(unparserOpts) if err != nil { return "", err } } un := &unparser{ info: info, options: unparserOpts, } err = un.visit(expr) if err != nil { return "", err } return un.str.String(), nil } // unparser visits an expression to reconstruct a human-readable string from an AST. type unparser struct { str strings.Builder info *exprpb.SourceInfo options *unparserOption lastWrappedIndex int } func (un *unparser) visit(expr *exprpb.Expr) error { if expr == nil { return errors.New("unsupported expression") } visited, err := un.visitMaybeMacroCall(expr) if visited || err != nil { return err } switch expr.GetExprKind().(type) { case *exprpb.Expr_CallExpr: return un.visitCall(expr) case *exprpb.Expr_ConstExpr: return un.visitConst(expr) case *exprpb.Expr_IdentExpr: return un.visitIdent(expr) case *exprpb.Expr_ListExpr: return un.visitList(expr) case *exprpb.Expr_SelectExpr: return un.visitSelect(expr) case *exprpb.Expr_StructExpr: return un.visitStruct(expr) default: return fmt.Errorf("unsupported expression: %v", expr) } } func (un *unparser) visitCall(expr *exprpb.Expr) error { c := expr.GetCallExpr() fun := c.GetFunction() switch fun { // ternary operator case operators.Conditional: return un.visitCallConditional(expr) // optional select operator case operators.OptSelect: return un.visitOptSelect(expr) // index operator case operators.Index: return un.visitCallIndex(expr) // optional index operator case operators.OptIndex: return un.visitCallOptIndex(expr) // unary operators case operators.LogicalNot, operators.Negate: return un.visitCallUnary(expr) // binary operators case operators.Add, operators.Divide, operators.Equals, operators.Greater, operators.GreaterEquals, operators.In, operators.Less, operators.LessEquals, operators.LogicalAnd, operators.LogicalOr, operators.Modulo, operators.Multiply, operators.NotEquals, operators.OldIn, operators.Subtract: return un.visitCallBinary(expr) // standard function calls. default: return un.visitCallFunc(expr) } } func (un *unparser) visitCallBinary(expr *exprpb.Expr) error { c := expr.GetCallExpr() fun := c.GetFunction() args := c.GetArgs() lhs := args[0] // add parens if the current operator is lower precedence than the lhs expr operator. lhsParen := isComplexOperatorWithRespectTo(fun, lhs) rhs := args[1] // add parens if the current operator is lower precedence than the rhs expr operator, // or the same precedence and the operator is left recursive. rhsParen := isComplexOperatorWithRespectTo(fun, rhs) if !rhsParen && isLeftRecursive(fun) { rhsParen = isSamePrecedence(fun, rhs) } err := un.visitMaybeNested(lhs, lhsParen) if err != nil { return err } unmangled, found := operators.FindReverseBinaryOperator(fun) if !found { return fmt.Errorf("cannot unmangle operator: %s", fun) } un.writeOperatorWithWrapping(fun, unmangled) return un.visitMaybeNested(rhs, rhsParen) } func (un *unparser) visitCallConditional(expr *exprpb.Expr) error { c := expr.GetCallExpr() args := c.GetArgs() // add parens if operand is a conditional itself. nested := isSamePrecedence(operators.Conditional, args[0]) || isComplexOperator(args[0]) err := un.visitMaybeNested(args[0], nested) if err != nil { return err } un.writeOperatorWithWrapping(operators.Conditional, "?") // add parens if operand is a conditional itself. nested = isSamePrecedence(operators.Conditional, args[1]) || isComplexOperator(args[1]) err = un.visitMaybeNested(args[1], nested) if err != nil { return err } un.str.WriteString(" : ") // add parens if operand is a conditional itself. nested = isSamePrecedence(operators.Conditional, args[2]) || isComplexOperator(args[2]) return un.visitMaybeNested(args[2], nested) } func (un *unparser) visitCallFunc(expr *exprpb.Expr) error { c := expr.GetCallExpr() fun := c.GetFunction() args := c.GetArgs() if c.GetTarget() != nil { nested := isBinaryOrTernaryOperator(c.GetTarget()) err := un.visitMaybeNested(c.GetTarget(), nested) if err != nil { return err } un.str.WriteString(".") } un.str.WriteString(fun) un.str.WriteString("(") for i, arg := range args { err := un.visit(arg) if err != nil { return err } if i < len(args)-1 { un.str.WriteString(", ") } } un.str.WriteString(")") return nil } func (un *unparser) visitCallIndex(expr *exprpb.Expr) error { return un.visitCallIndexInternal(expr, "[") } func (un *unparser) visitCallOptIndex(expr *exprpb.Expr) error { return un.visitCallIndexInternal(expr, "[?") } func (un *unparser) visitCallIndexInternal(expr *exprpb.Expr, op string) error { c := expr.GetCallExpr() args := c.GetArgs() nested := isBinaryOrTernaryOperator(args[0]) err := un.visitMaybeNested(args[0], nested) if err != nil { return err } un.str.WriteString(op) err = un.visit(args[1]) if err != nil { return err } un.str.WriteString("]") return nil } func (un *unparser) visitCallUnary(expr *exprpb.Expr) error { c := expr.GetCallExpr() fun := c.GetFunction() args := c.GetArgs() unmangled, found := operators.FindReverse(fun) if !found { return fmt.Errorf("cannot unmangle operator: %s", fun) } un.str.WriteString(unmangled) nested := isComplexOperator(args[0]) return un.visitMaybeNested(args[0], nested) } func (un *unparser) visitConst(expr *exprpb.Expr) error { c := expr.GetConstExpr() switch c.GetConstantKind().(type) { case *exprpb.Constant_BoolValue: un.str.WriteString(strconv.FormatBool(c.GetBoolValue())) case *exprpb.Constant_BytesValue: // bytes constants are surrounded with b"" b := c.GetBytesValue() un.str.WriteString(`b"`) un.str.WriteString(bytesToOctets(b)) un.str.WriteString(`"`) case *exprpb.Constant_DoubleValue: // represent the float using the minimum required digits d := strconv.FormatFloat(c.GetDoubleValue(), 'g', -1, 64) un.str.WriteString(d) if !strings.Contains(d, ".") { un.str.WriteString(".0") } case *exprpb.Constant_Int64Value: i := strconv.FormatInt(c.GetInt64Value(), 10) un.str.WriteString(i) case *exprpb.Constant_NullValue: un.str.WriteString("null") case *exprpb.Constant_StringValue: // strings will be double quoted with quotes escaped. un.str.WriteString(strconv.Quote(c.GetStringValue())) case *exprpb.Constant_Uint64Value: // uint literals have a 'u' suffix. ui := strconv.FormatUint(c.GetUint64Value(), 10) un.str.WriteString(ui) un.str.WriteString("u") default: return fmt.Errorf("unsupported constant: %v", expr) } return nil } func (un *unparser) visitIdent(expr *exprpb.Expr) error { un.str.WriteString(expr.GetIdentExpr().GetName()) return nil } func (un *unparser) visitList(expr *exprpb.Expr) error { l := expr.GetListExpr() elems := l.GetElements() optIndices := make(map[int]bool, len(elems)) for _, idx := range l.GetOptionalIndices() { optIndices[int(idx)] = true } un.str.WriteString("[") for i, elem := range elems { if optIndices[i] { un.str.WriteString("?") } err := un.visit(elem) if err != nil { return err } if i < len(elems)-1 { un.str.WriteString(", ") } } un.str.WriteString("]") return nil } func (un *unparser) visitOptSelect(expr *exprpb.Expr) error { c := expr.GetCallExpr() args := c.GetArgs() operand := args[0] field := args[1].GetConstExpr().GetStringValue() return un.visitSelectInternal(operand, false, ".?", field) } func (un *unparser) visitSelect(expr *exprpb.Expr) error { sel := expr.GetSelectExpr() return un.visitSelectInternal(sel.GetOperand(), sel.GetTestOnly(), ".", sel.GetField()) } func (un *unparser) visitSelectInternal(operand *exprpb.Expr, testOnly bool, op string, field string) error { // handle the case when the select expression was generated by the has() macro. if testOnly { un.str.WriteString("has(") } nested := !testOnly && isBinaryOrTernaryOperator(operand) err := un.visitMaybeNested(operand, nested) if err != nil { return err } un.str.WriteString(op) un.str.WriteString(field) if testOnly { un.str.WriteString(")") } return nil } func (un *unparser) visitStruct(expr *exprpb.Expr) error { s := expr.GetStructExpr() // If the message name is non-empty, then this should be treated as message construction. if s.GetMessageName() != "" { return un.visitStructMsg(expr) } // Otherwise, build a map. return un.visitStructMap(expr) } func (un *unparser) visitStructMsg(expr *exprpb.Expr) error { m := expr.GetStructExpr() entries := m.GetEntries() un.str.WriteString(m.GetMessageName()) un.str.WriteString("{") for i, entry := range entries { f := entry.GetFieldKey() if entry.GetOptionalEntry() { un.str.WriteString("?") } un.str.WriteString(f) un.str.WriteString(": ") v := entry.GetValue() err := un.visit(v) if err != nil { return err } if i < len(entries)-1 { un.str.WriteString(", ") } } un.str.WriteString("}") return nil } func (un *unparser) visitStructMap(expr *exprpb.Expr) error { m := expr.GetStructExpr() entries := m.GetEntries() un.str.WriteString("{") for i, entry := range entries { k := entry.GetMapKey() if entry.GetOptionalEntry() { un.str.WriteString("?") } err := un.visit(k) if err != nil { return err } un.str.WriteString(": ") v := entry.GetValue() err = un.visit(v) if err != nil { return err } if i < len(entries)-1 { un.str.WriteString(", ") } } un.str.WriteString("}") return nil } func (un *unparser) visitMaybeMacroCall(expr *exprpb.Expr) (bool, error) { macroCalls := un.info.GetMacroCalls() call, found := macroCalls[expr.GetId()] if !found { return false, nil } return true, un.visit(call) } func (un *unparser) visitMaybeNested(expr *exprpb.Expr, nested bool) error { if nested { un.str.WriteString("(") } err := un.visit(expr) if err != nil { return err } if nested { un.str.WriteString(")") } return nil } // isLeftRecursive indicates whether the parser resolves the call in a left-recursive manner as // this can have an effect of how parentheses affect the order of operations in the AST. func isLeftRecursive(op string) bool { return op != operators.LogicalAnd && op != operators.LogicalOr } // isSamePrecedence indicates whether the precedence of the input operator is the same as the // precedence of the (possible) operation represented in the input Expr. // // If the expr is not a Call, the result is false. func isSamePrecedence(op string, expr *exprpb.Expr) bool { if expr.GetCallExpr() == nil { return false } c := expr.GetCallExpr() other := c.GetFunction() return operators.Precedence(op) == operators.Precedence(other) } // isLowerPrecedence indicates whether the precedence of the input operator is lower precedence // than the (possible) operation represented in the input Expr. // // If the expr is not a Call, the result is false. func isLowerPrecedence(op string, expr *exprpb.Expr) bool { c := expr.GetCallExpr() other := c.GetFunction() return operators.Precedence(op) < operators.Precedence(other) } // Indicates whether the expr is a complex operator, i.e., a call expression // with 2 or more arguments. func isComplexOperator(expr *exprpb.Expr) bool { if expr.GetCallExpr() != nil && len(expr.GetCallExpr().GetArgs()) >= 2 { return true } return false } // Indicates whether it is a complex operation compared to another. // expr is *not* considered complex if it is not a call expression or has // less than two arguments, or if it has a higher precedence than op. func isComplexOperatorWithRespectTo(op string, expr *exprpb.Expr) bool { if expr.GetCallExpr() == nil || len(expr.GetCallExpr().GetArgs()) < 2 { return false } return isLowerPrecedence(op, expr) } // Indicate whether this is a binary or ternary operator. func isBinaryOrTernaryOperator(expr *exprpb.Expr) bool { if expr.GetCallExpr() == nil || len(expr.GetCallExpr().GetArgs()) < 2 { return false } _, isBinaryOp := operators.FindReverseBinaryOperator(expr.GetCallExpr().GetFunction()) return isBinaryOp || isSamePrecedence(operators.Conditional, expr) } // bytesToOctets converts byte sequences to a string using a three digit octal encoded value // per byte. func bytesToOctets(byteVal []byte) string { var b strings.Builder for _, c := range byteVal { fmt.Fprintf(&b, "\\%03o", c) } return b.String() } // writeOperatorWithWrapping outputs the operator and inserts a newline for operators configured // in the unparser options. func (un *unparser) writeOperatorWithWrapping(fun string, unmangled string) bool { _, wrapOperatorExists := un.options.operatorsToWrapOn[fun] lineLength := un.str.Len() - un.lastWrappedIndex + len(fun) if wrapOperatorExists && lineLength >= un.options.wrapOnColumn { un.lastWrappedIndex = un.str.Len() // wrapAfterColumnLimit flag dictates whether the newline is placed // before or after the operator if un.options.wrapAfterColumnLimit { // Input: a && b // Output: a &&\nb un.str.WriteString(" ") un.str.WriteString(unmangled) un.str.WriteString("\n") } else { // Input: a && b // Output: a\n&& b un.str.WriteString("\n") un.str.WriteString(unmangled) un.str.WriteString(" ") } return true } un.str.WriteString(" ") un.str.WriteString(unmangled) un.str.WriteString(" ") return false } // Defined defaults for the unparser options var ( defaultWrapOnColumn = 80 defaultWrapAfterColumnLimit = true defaultOperatorsToWrapOn = map[string]bool{ operators.LogicalAnd: true, operators.LogicalOr: true, } ) // UnparserOption is a functional option for configuring the output formatting // of the Unparse function. type UnparserOption func(*unparserOption) (*unparserOption, error) // Internal representation of the UnparserOption type type unparserOption struct { wrapOnColumn int operatorsToWrapOn map[string]bool wrapAfterColumnLimit bool } // WrapOnColumn wraps the output expression when its string length exceeds a specified limit // for operators set by WrapOnOperators function or by default, "&&" and "||" will be wrapped. // // Example usage: // // Unparse(expr, sourceInfo, WrapOnColumn(40), WrapOnOperators(Operators.LogicalAnd)) // // This will insert a newline immediately after the logical AND operator for the below example input: // // Input: // 'my-principal-group' in request.auth.claims && request.auth.claims.iat > now - duration('5m') // // Output: // 'my-principal-group' in request.auth.claims && // request.auth.claims.iat > now - duration('5m') func WrapOnColumn(col int) UnparserOption { return func(opt *unparserOption) (*unparserOption, error) { if col < 1 { return nil, fmt.Errorf("Invalid unparser option. Wrap column value must be greater than or equal to 1. Got %v instead", col) } opt.wrapOnColumn = col return opt, nil } } // WrapOnOperators specifies which operators to perform word wrapping on an output expression when its string length // exceeds the column limit set by WrapOnColumn function. // // Word wrapping is supported on non-unary symbolic operators. Refer to operators.go for the full list // // This will replace any previously supplied operators instead of merging them. func WrapOnOperators(symbols ...string) UnparserOption { return func(opt *unparserOption) (*unparserOption, error) { opt.operatorsToWrapOn = make(map[string]bool) for _, symbol := range symbols { _, found := operators.FindReverse(symbol) if !found { return nil, fmt.Errorf("Invalid unparser option. Unsupported operator: %s", symbol) } arity := operators.Arity(symbol) if arity < 2 { return nil, fmt.Errorf("Invalid unparser option. Unary operators are unsupported: %s", symbol) } opt.operatorsToWrapOn[symbol] = true } return opt, nil } } // WrapAfterColumnLimit dictates whether to insert a newline before or after the specified operator // when word wrapping is performed. // // Example usage: // // Unparse(expr, sourceInfo, WrapOnColumn(40), WrapOnOperators(Operators.LogicalAnd), WrapAfterColumnLimit(false)) // // This will insert a newline immediately before the logical AND operator for the below example input, ensuring // that the length of a line never exceeds the specified column limit: // // Input: // 'my-principal-group' in request.auth.claims && request.auth.claims.iat > now - duration('5m') // // Output: // 'my-principal-group' in request.auth.claims // && request.auth.claims.iat > now - duration('5m') func WrapAfterColumnLimit(wrapAfter bool) UnparserOption { return func(opt *unparserOption) (*unparserOption, error) { opt.wrapAfterColumnLimit = wrapAfter return opt, nil } }