// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cel

import (
	"github.com/google/cel-go/common"
	"github.com/google/cel-go/common/ast"
	"github.com/google/cel-go/common/types"
	"github.com/google/cel-go/common/types/ref"
)

// StaticOptimizer contains a sequence of ASTOptimizer instances which will be applied in order.
//
// The static optimizer normalizes expression ids and type-checking run between optimization
// passes to ensure that the final optimized output is a valid expression with metadata consistent
// with what would have been generated from a parsed and checked expression.
//
// Note: source position information is best-effort and likely wrong, but optimized expressions
// should be suitable for calls to parser.Unparse.
type StaticOptimizer struct {
	optimizers []ASTOptimizer
}

// NewStaticOptimizer creates a StaticOptimizer with a sequence of ASTOptimizer's to be applied
// to a checked expression.
func NewStaticOptimizer(optimizers ...ASTOptimizer) *StaticOptimizer {
	return &StaticOptimizer{
		optimizers: optimizers,
	}
}

// Optimize applies a sequence of optimizations to an Ast within a given environment.
//
// If issues are encountered, the Issues.Err() return value will be non-nil.
func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) {
	// Make a copy of the AST to be optimized.
	optimized := ast.Copy(a.impl)
	ids := newIDGenerator(ast.MaxID(a.impl))

	// Create the optimizer context, could be pooled in the future.
	issues := NewIssues(common.NewErrors(a.Source()))
	baseFac := ast.NewExprFactory()
	exprFac := &optimizerExprFactory{
		idGenerator: ids,
		fac:         baseFac,
		sourceInfo:  optimized.SourceInfo(),
	}
	ctx := &OptimizerContext{
		optimizerExprFactory: exprFac,
		Env:                  env,
		Issues:               issues,
	}

	// Apply the optimizations sequentially.
	for _, o := range opt.optimizers {
		optimized = o.Optimize(ctx, optimized)
		if issues.Err() != nil {
			return nil, issues
		}
		// Normalize expression id metadata including coordination with macro call metadata.
		freshIDGen := newIDGenerator(0)
		info := optimized.SourceInfo()
		expr := optimized.Expr()
		normalizeIDs(freshIDGen.renumberStable, expr, info)
		cleanupMacroRefs(expr, info)

		// Recheck the updated expression for any possible type-agreement or validation errors.
		parsed := &Ast{
			source: a.Source(),
			impl:   ast.NewAST(expr, info)}
		checked, iss := ctx.Check(parsed)
		if iss.Err() != nil {
			return nil, iss
		}
		optimized = checked.impl
	}

	// Return the optimized result.
	return &Ast{
		source: a.Source(),
		impl:   optimized,
	}, nil
}

// normalizeIDs ensures that the metadata present with an AST is reset in a manner such
// that the ids within the expression correspond to the ids within macros.
func normalizeIDs(idGen ast.IDGenerator, optimized ast.Expr, info *ast.SourceInfo) {
	optimized.RenumberIDs(idGen)

	if len(info.MacroCalls()) == 0 {
		return
	}

	// First, update the macro call ids themselves.
	callIDMap := map[int64]int64{}
	for id := range info.MacroCalls() {
		callIDMap[id] = idGen(id)
	}
	// Then update the macro call definitions which refer to these ids, but
	// ensure that the updates don't collide and remove macro entries which haven't
	// been visited / updated yet.
	type macroUpdate struct {
		id   int64
		call ast.Expr
	}
	macroUpdates := []macroUpdate{}
	for oldID, newID := range callIDMap {
		call, found := info.GetMacroCall(oldID)
		if !found {
			continue
		}
		call.RenumberIDs(idGen)
		macroUpdates = append(macroUpdates, macroUpdate{id: newID, call: call})
		info.ClearMacroCall(oldID)
	}
	for _, u := range macroUpdates {
		info.SetMacroCall(u.id, u.call)
	}
}

func cleanupMacroRefs(expr ast.Expr, info *ast.SourceInfo) {
	if len(info.MacroCalls()) == 0 {
		return
	}
	// Sanitize the macro call references once the optimized expression has been computed
	// and the ids normalized between the expression and the macros.
	exprRefMap := make(map[int64]struct{})
	ast.PostOrderVisit(expr, ast.NewExprVisitor(func(e ast.Expr) {
		if e.ID() == 0 {
			return
		}
		exprRefMap[e.ID()] = struct{}{}
	}))
	// Update the macro call id references to ensure that macro pointers are
	// updated consistently across macros.
	for _, call := range info.MacroCalls() {
		ast.PostOrderVisit(call, ast.NewExprVisitor(func(e ast.Expr) {
			if e.ID() == 0 {
				return
			}
			exprRefMap[e.ID()] = struct{}{}
		}))
	}
	for id := range info.MacroCalls() {
		if _, found := exprRefMap[id]; !found {
			info.ClearMacroCall(id)
		}
	}
}

// newIDGenerator ensures that new ids are only created the first time they are encountered.
func newIDGenerator(seed int64) *idGenerator {
	return &idGenerator{
		idMap: make(map[int64]int64),
		seed:  seed,
	}
}

type idGenerator struct {
	idMap map[int64]int64
	seed  int64
}

func (gen *idGenerator) nextID() int64 {
	gen.seed++
	return gen.seed
}

func (gen *idGenerator) renumberStable(id int64) int64 {
	if id == 0 {
		return 0
	}
	if newID, found := gen.idMap[id]; found {
		return newID
	}
	nextID := gen.nextID()
	gen.idMap[id] = nextID
	return nextID
}

// OptimizerContext embeds Env and Issues instances to make it easy to type-check and evaluate
// subexpressions and report any errors encountered along the way. The context also embeds the
// optimizerExprFactory which can be used to generate new sub-expressions with expression ids
// consistent with the expectations of a parsed expression.
type OptimizerContext struct {
	*Env
	*optimizerExprFactory
	*Issues
}

// ASTOptimizer applies an optimization over an AST and returns the optimized result.
type ASTOptimizer interface {
	// Optimize optimizes a type-checked AST within an Environment and accumulates any issues.
	Optimize(*OptimizerContext, *ast.AST) *ast.AST
}

type optimizerExprFactory struct {
	*idGenerator
	fac        ast.ExprFactory
	sourceInfo *ast.SourceInfo
}

// NewAST creates an AST from the current expression using the tracked source info which
// is modified and managed by the OptimizerContext.
func (opt *optimizerExprFactory) NewAST(expr ast.Expr) *ast.AST {
	return ast.NewAST(expr, opt.sourceInfo)
}

// CopyAST creates a renumbered copy of `Expr` and `SourceInfo` values of the input AST, where the
// renumbering uses the same scheme as the core optimizer logic ensuring there are no collisions
// between copies.
//
// Use this method before attempting to merge the expression from AST into another.
func (opt *optimizerExprFactory) CopyAST(a *ast.AST) (ast.Expr, *ast.SourceInfo) {
	idGen := newIDGenerator(opt.nextID())
	defer func() { opt.seed = idGen.nextID() }()
	copyExpr := opt.fac.CopyExpr(a.Expr())
	copyInfo := ast.CopySourceInfo(a.SourceInfo())
	normalizeIDs(idGen.renumberStable, copyExpr, copyInfo)
	return copyExpr, copyInfo
}

// CopyASTAndMetadata copies the input AST and propagates the macro metadata into the AST being
// optimized.
func (opt *optimizerExprFactory) CopyASTAndMetadata(a *ast.AST) ast.Expr {
	copyExpr, copyInfo := opt.CopyAST(a)
	for macroID, call := range copyInfo.MacroCalls() {
		opt.SetMacroCall(macroID, call)
	}
	return copyExpr
}

// ClearMacroCall clears the macro at the given expression id.
func (opt *optimizerExprFactory) ClearMacroCall(id int64) {
	opt.sourceInfo.ClearMacroCall(id)
}

// SetMacroCall sets the macro call metadata for the given macro id within the tracked source info
// metadata.
func (opt *optimizerExprFactory) SetMacroCall(id int64, expr ast.Expr) {
	opt.sourceInfo.SetMacroCall(id, expr)
}

// NewBindMacro creates an AST expression representing the expanded bind() macro, and a macro expression
// representing the unexpanded call signature to be inserted into the source info macro call metadata.
func (opt *optimizerExprFactory) NewBindMacro(macroID int64, varName string, varInit, remaining ast.Expr) (astExpr, macroExpr ast.Expr) {
	varID := opt.nextID()
	remainingID := opt.nextID()
	remaining = opt.fac.CopyExpr(remaining)
	remaining.RenumberIDs(func(id int64) int64 {
		if id == macroID {
			return remainingID
		}
		return id
	})
	if call, exists := opt.sourceInfo.GetMacroCall(macroID); exists {
		opt.SetMacroCall(remainingID, opt.fac.CopyExpr(call))
	}

	astExpr = opt.fac.NewComprehension(macroID,
		opt.fac.NewList(opt.nextID(), []ast.Expr{}, []int32{}),
		"#unused",
		varName,
		opt.fac.CopyExpr(varInit),
		opt.fac.NewLiteral(opt.nextID(), types.False),
		opt.fac.NewIdent(varID, varName),
		remaining)

	macroExpr = opt.fac.NewMemberCall(0, "bind",
		opt.fac.NewIdent(opt.nextID(), "cel"),
		opt.fac.NewIdent(varID, varName),
		opt.fac.CopyExpr(varInit),
		opt.fac.CopyExpr(remaining))
	opt.sanitizeMacro(macroID, macroExpr)
	return
}

// NewCall creates a global function call invocation expression.
//
// Example:
//
// countByField(list, fieldName)
// - function: countByField
// - args: [list, fieldName]
func (opt *optimizerExprFactory) NewCall(function string, args ...ast.Expr) ast.Expr {
	return opt.fac.NewCall(opt.nextID(), function, args...)
}

// NewMemberCall creates a member function call invocation expression where 'target' is the receiver of the call.
//
// Example:
//
// list.countByField(fieldName)
// - function: countByField
// - target: list
// - args: [fieldName]
func (opt *optimizerExprFactory) NewMemberCall(function string, target ast.Expr, args ...ast.Expr) ast.Expr {
	return opt.fac.NewMemberCall(opt.nextID(), function, target, args...)
}

// NewIdent creates a new identifier expression.
//
// Examples:
//
// - simple_var_name
// - qualified.subpackage.var_name
func (opt *optimizerExprFactory) NewIdent(name string) ast.Expr {
	return opt.fac.NewIdent(opt.nextID(), name)
}

// NewLiteral creates a new literal expression value.
//
// The range of valid values for a literal generated during optimization is different than for expressions
// generated via parsing / type-checking, as the ref.Val may be _any_ CEL value so long as the value can
// be converted back to a literal-like form.
func (opt *optimizerExprFactory) NewLiteral(value ref.Val) ast.Expr {
	return opt.fac.NewLiteral(opt.nextID(), value)
}

// NewList creates a list expression with a set of optional indices.
//
// Examples:
//
// [a, b]
// - elems: [a, b]
// - optIndices: []
//
// [a, ?b, ?c]
// - elems: [a, b, c]
// - optIndices: [1, 2]
func (opt *optimizerExprFactory) NewList(elems []ast.Expr, optIndices []int32) ast.Expr {
	return opt.fac.NewList(opt.nextID(), elems, optIndices)
}

// NewMap creates a map from a set of entry expressions which contain a key and value expression.
func (opt *optimizerExprFactory) NewMap(entries []ast.EntryExpr) ast.Expr {
	return opt.fac.NewMap(opt.nextID(), entries)
}

// NewMapEntry creates a map entry with a key and value expression and a flag to indicate whether the
// entry is optional.
//
// Examples:
//
// {a: b}
// - key: a
// - value: b
// - optional: false
//
// {?a: ?b}
// - key: a
// - value: b
// - optional: true
func (opt *optimizerExprFactory) NewMapEntry(key, value ast.Expr, isOptional bool) ast.EntryExpr {
	return opt.fac.NewMapEntry(opt.nextID(), key, value, isOptional)
}

// NewHasMacro generates a test-only select expression to be included within an AST and an unexpanded
// has() macro call signature to be inserted into the source info macro call metadata.
func (opt *optimizerExprFactory) NewHasMacro(macroID int64, s ast.Expr) (astExpr, macroExpr ast.Expr) {
	sel := s.AsSelect()
	astExpr = opt.fac.NewPresenceTest(macroID, sel.Operand(), sel.FieldName())
	macroExpr = opt.fac.NewCall(0, "has",
		opt.NewSelect(opt.fac.CopyExpr(sel.Operand()), sel.FieldName()))
	opt.sanitizeMacro(macroID, macroExpr)
	return
}

// NewSelect creates a select expression where a field value is selected from an operand.
//
// Example:
//
// msg.field_name
// - operand: msg
// - field: field_name
func (opt *optimizerExprFactory) NewSelect(operand ast.Expr, field string) ast.Expr {
	return opt.fac.NewSelect(opt.nextID(), operand, field)
}

// NewStruct creates a new typed struct value with an set of field initializations.
//
// Example:
//
// pkg.TypeName{field: value}
// - typeName: pkg.TypeName
// - fields: [{field: value}]
func (opt *optimizerExprFactory) NewStruct(typeName string, fields []ast.EntryExpr) ast.Expr {
	return opt.fac.NewStruct(opt.nextID(), typeName, fields)
}

// NewStructField creates a struct field initialization.
//
// Examples:
//
// {count: 3u}
// - field: count
// - value: 3u
// - optional: false
//
// {?count: x}
// - field: count
// - value: x
// - optional: true
func (opt *optimizerExprFactory) NewStructField(field string, value ast.Expr, isOptional bool) ast.EntryExpr {
	return opt.fac.NewStructField(opt.nextID(), field, value, isOptional)
}

// UpdateExpr updates the target expression with the updated content while preserving macro metadata.
//
// There are four scenarios during the update to consider:
// 1. target is not macro, updated is not macro
// 2. target is macro, updated is not macro
// 3. target is macro, updated is macro
// 4. target is not macro, updated is macro
//
// When the target is a macro already, it may either be updated to a new macro function
// body if the update is also a macro, or it may be removed altogether if the update is
// a macro.
//
// When the update is a macro, then the target references within other macros must be
// updated to point to the new updated macro. Otherwise, other macros which pointed to
// the target body must be replaced with copies of the updated expression body.
func (opt *optimizerExprFactory) UpdateExpr(target, updated ast.Expr) {
	// Update the expression
	target.SetKindCase(updated)

	// Early return if there's no macros present sa the source info reflects the
	// macro set from the target and updated expressions.
	if len(opt.sourceInfo.MacroCalls()) == 0 {
		return
	}
	// Determine whether the target expression was a macro.
	_, targetIsMacro := opt.sourceInfo.GetMacroCall(target.ID())

	// Determine whether the updated expression was a macro.
	updatedMacro, updatedIsMacro := opt.sourceInfo.GetMacroCall(updated.ID())

	if updatedIsMacro {
		// If the updated call was a macro, then updated id maps to target id,
		// and the updated macro moves into the target id slot.
		opt.sourceInfo.ClearMacroCall(updated.ID())
		opt.sourceInfo.SetMacroCall(target.ID(), updatedMacro)
	} else if targetIsMacro {
		// Otherwise if the target expr was a macro, but is no longer, clear
		// the macro reference.
		opt.sourceInfo.ClearMacroCall(target.ID())
	}

	// Punch holes in the updated value where macros references exist.
	macroExpr := opt.fac.CopyExpr(target)
	macroRefVisitor := ast.NewExprVisitor(func(e ast.Expr) {
		if _, exists := opt.sourceInfo.GetMacroCall(e.ID()); exists {
			e.SetKindCase(nil)
		}
	})
	ast.PostOrderVisit(macroExpr, macroRefVisitor)

	// Update any references to the expression within a macro
	macroVisitor := ast.NewExprVisitor(func(call ast.Expr) {
		// Update the target expression to point to the macro expression which
		// will be empty if the updated expression was a macro.
		if call.ID() == target.ID() {
			call.SetKindCase(opt.fac.CopyExpr(macroExpr))
		}
		// Update the macro call expression if it refers to the updated expression
		// id which has since been remapped to the target id.
		if call.ID() == updated.ID() {
			// Either ensure the expression is a macro reference or a populated with
			// the relevant sub-expression if the updated expr was not a macro.
			if updatedIsMacro {
				call.SetKindCase(nil)
			} else {
				call.SetKindCase(opt.fac.CopyExpr(macroExpr))
			}
			// Since SetKindCase does not renumber the id, ensure the references to
			// the old 'updated' id are mapped to the target id.
			call.RenumberIDs(func(id int64) int64 {
				if id == updated.ID() {
					return target.ID()
				}
				return id
			})
		}
	})
	for _, call := range opt.sourceInfo.MacroCalls() {
		ast.PostOrderVisit(call, macroVisitor)
	}
}

func (opt *optimizerExprFactory) sanitizeMacro(macroID int64, macroExpr ast.Expr) {
	macroRefVisitor := ast.NewExprVisitor(func(e ast.Expr) {
		if _, exists := opt.sourceInfo.GetMacroCall(e.ID()); exists && e.ID() != macroID {
			e.SetKindCase(nil)
		}
	})
	ast.PostOrderVisit(macroExpr, macroRefVisitor)
}