// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package checker

import (
	"fmt"
	"strings"

	"github.com/google/cel-go/common/containers"
	"github.com/google/cel-go/common/decls"
	"github.com/google/cel-go/common/overloads"
	"github.com/google/cel-go/common/types"
	"github.com/google/cel-go/parser"
)

type aggregateLiteralElementType int

const (
	dynElementType        aggregateLiteralElementType = iota
	homogenousElementType aggregateLiteralElementType = 1 << iota
)

var (
	crossTypeNumericComparisonOverloads = map[string]struct{}{
		// double <-> int | uint
		overloads.LessDoubleInt64:           {},
		overloads.LessDoubleUint64:          {},
		overloads.LessEqualsDoubleInt64:     {},
		overloads.LessEqualsDoubleUint64:    {},
		overloads.GreaterDoubleInt64:        {},
		overloads.GreaterDoubleUint64:       {},
		overloads.GreaterEqualsDoubleInt64:  {},
		overloads.GreaterEqualsDoubleUint64: {},
		// int <-> double | uint
		overloads.LessInt64Double:          {},
		overloads.LessInt64Uint64:          {},
		overloads.LessEqualsInt64Double:    {},
		overloads.LessEqualsInt64Uint64:    {},
		overloads.GreaterInt64Double:       {},
		overloads.GreaterInt64Uint64:       {},
		overloads.GreaterEqualsInt64Double: {},
		overloads.GreaterEqualsInt64Uint64: {},
		// uint <-> double | int
		overloads.LessUint64Double:          {},
		overloads.LessUint64Int64:           {},
		overloads.LessEqualsUint64Double:    {},
		overloads.LessEqualsUint64Int64:     {},
		overloads.GreaterUint64Double:       {},
		overloads.GreaterUint64Int64:        {},
		overloads.GreaterEqualsUint64Double: {},
		overloads.GreaterEqualsUint64Int64:  {},
	}
)

// Env is the environment for type checking.
//
// The Env is comprised of a container, type provider, declarations, and other related objects
// which can be used to assist with type-checking.
type Env struct {
	container           *containers.Container
	provider            types.Provider
	declarations        *Scopes
	aggLitElemType      aggregateLiteralElementType
	filteredOverloadIDs map[string]struct{}
}

// NewEnv returns a new *Env with the given parameters.
func NewEnv(container *containers.Container, provider types.Provider, opts ...Option) (*Env, error) {
	declarations := newScopes()
	declarations.Push()

	envOptions := &options{}
	for _, opt := range opts {
		if err := opt(envOptions); err != nil {
			return nil, err
		}
	}
	aggLitElemType := dynElementType
	if envOptions.homogeneousAggregateLiterals {
		aggLitElemType = homogenousElementType
	}
	filteredOverloadIDs := crossTypeNumericComparisonOverloads
	if envOptions.crossTypeNumericComparisons {
		filteredOverloadIDs = make(map[string]struct{})
	}
	if envOptions.validatedDeclarations != nil {
		declarations = envOptions.validatedDeclarations.Copy()
	}
	return &Env{
		container:           container,
		provider:            provider,
		declarations:        declarations,
		aggLitElemType:      aggLitElemType,
		filteredOverloadIDs: filteredOverloadIDs,
	}, nil
}

// AddIdents configures the checker with a list of variable declarations.
//
// If there are overlapping declarations, the method will error.
func (e *Env) AddIdents(declarations ...*decls.VariableDecl) error {
	errMsgs := make([]errorMsg, 0)
	for _, d := range declarations {
		errMsgs = append(errMsgs, e.addIdent(d))
	}
	return formatError(errMsgs)
}

// AddFunctions configures the checker with a list of function declarations.
//
// If there are overlapping declarations, the method will error.
func (e *Env) AddFunctions(declarations ...*decls.FunctionDecl) error {
	errMsgs := make([]errorMsg, 0)
	for _, d := range declarations {
		errMsgs = append(errMsgs, e.setFunction(d)...)
	}
	return formatError(errMsgs)
}

// LookupIdent returns a Decl proto for typeName as an identifier in the Env.
// Returns nil if no such identifier is found in the Env.
func (e *Env) LookupIdent(name string) *decls.VariableDecl {
	for _, candidate := range e.container.ResolveCandidateNames(name) {
		if ident := e.declarations.FindIdent(candidate); ident != nil {
			return ident
		}

		// Next try to import the name as a reference to a message type. If found,
		// the declaration is added to the outest (global) scope of the
		// environment, so next time we can access it faster.
		if t, found := e.provider.FindStructType(candidate); found {
			decl := decls.NewVariable(candidate, t)
			e.declarations.AddIdent(decl)
			return decl
		}

		if i, found := e.provider.FindIdent(candidate); found {
			if t, ok := i.(*types.Type); ok {
				decl := decls.NewVariable(candidate, types.NewTypeTypeWithParam(t))
				e.declarations.AddIdent(decl)
				return decl
			}
		}

		// Next try to import this as an enum value by splitting the name in a type prefix and
		// the enum inside.
		if enumValue := e.provider.EnumValue(candidate); enumValue.Type() != types.ErrType {
			decl := decls.NewConstant(candidate, types.IntType, enumValue)
			e.declarations.AddIdent(decl)
			return decl
		}
	}
	return nil
}

// LookupFunction returns a Decl proto for typeName as a function in env.
// Returns nil if no such function is found in env.
func (e *Env) LookupFunction(name string) *decls.FunctionDecl {
	for _, candidate := range e.container.ResolveCandidateNames(name) {
		if fn := e.declarations.FindFunction(candidate); fn != nil {
			return fn
		}
	}
	return nil
}

// setFunction adds the function Decl to the Env.
// Adds a function decl if one doesn't already exist, then adds all overloads from the Decl.
// If overload overlaps with an existing overload, adds to the errors  in the Env instead.
func (e *Env) setFunction(fn *decls.FunctionDecl) []errorMsg {
	errMsgs := make([]errorMsg, 0)
	current := e.declarations.FindFunction(fn.Name())
	if current != nil {
		var err error
		current, err = current.Merge(fn)
		if err != nil {
			return append(errMsgs, errorMsg(err.Error()))
		}
	} else {
		current = fn
	}
	for _, overload := range current.OverloadDecls() {
		for _, macro := range parser.AllMacros {
			if macro.Function() == current.Name() &&
				macro.IsReceiverStyle() == overload.IsMemberFunction() &&
				macro.ArgCount() == len(overload.ArgTypes()) {
				errMsgs = append(errMsgs, overlappingMacroError(current.Name(), macro.ArgCount()))
			}
		}
		if len(errMsgs) > 0 {
			return errMsgs
		}
	}
	e.declarations.SetFunction(current)
	return errMsgs
}

// addIdent adds the Decl to the declarations in the Env.
// Returns a non-empty errorMsg if the identifier is already declared in the scope.
func (e *Env) addIdent(decl *decls.VariableDecl) errorMsg {
	current := e.declarations.FindIdentInScope(decl.Name())
	if current != nil {
		if current.DeclarationIsEquivalent(decl) {
			return ""
		}
		return overlappingIdentifierError(decl.Name())
	}
	e.declarations.AddIdent(decl)
	return ""
}

// isOverloadDisabled returns whether the overloadID is disabled in the current environment.
func (e *Env) isOverloadDisabled(overloadID string) bool {
	_, found := e.filteredOverloadIDs[overloadID]
	return found
}

// validatedDeclarations returns a reference to the validated variable and function declaration scope stack.
// must be copied before use.
func (e *Env) validatedDeclarations() *Scopes {
	return e.declarations
}

// enterScope creates a new Env instance with a new innermost declaration scope.
func (e *Env) enterScope() *Env {
	childDecls := e.declarations.Push()
	return &Env{
		declarations:   childDecls,
		container:      e.container,
		provider:       e.provider,
		aggLitElemType: e.aggLitElemType,
	}
}

// exitScope creates a new Env instance with the nearest outer declaration scope.
func (e *Env) exitScope() *Env {
	parentDecls := e.declarations.Pop()
	return &Env{
		declarations:   parentDecls,
		container:      e.container,
		provider:       e.provider,
		aggLitElemType: e.aggLitElemType,
	}
}

// errorMsg is a type alias meant to represent error-based return values which
// may be accumulated into an error at a later point in execution.
type errorMsg string

func overlappingIdentifierError(name string) errorMsg {
	return errorMsg(fmt.Sprintf("overlapping identifier for name '%s'", name))
}

func overlappingMacroError(name string, argCount int) errorMsg {
	return errorMsg(fmt.Sprintf(
		"overlapping macro for name '%s' with %d args", name, argCount))
}

func formatError(errMsgs []errorMsg) error {
	errStrs := make([]string, 0)
	if len(errMsgs) > 0 {
		for i := 0; i < len(errMsgs); i++ {
			if errMsgs[i] != "" {
				errStrs = append(errStrs, string(errMsgs[i]))
			}
		}
	}
	if len(errStrs) > 0 {
		return fmt.Errorf("%s", strings.Join(errStrs, "\n"))
	}
	return nil
}