// Copyright (c) Faye Amacker. All rights reserved.
// Licensed under the MIT License. See LICENSE in the project root for license information.

package cbor

import (
	"bytes"
	"errors"
	"fmt"
	"reflect"
	"sort"
	"strconv"
	"strings"
	"sync"
)

// encodeFuncs is the value type stored in encodeFuncCache. It pairs the
// encodeFunc (ef) and isEmptyFunc (ief) that getEncodeFuncInternal returned for
// a type, so that both can be cached and retrieved with a single lookup.
type encodeFuncs struct {
	ef  encodeFunc
	ief isEmptyFunc
}

// Package-level caches keyed by reflect.Type. Each cache is filled lazily by
// the matching get* function in this file the first time a type is requested.
// Nothing in this file removes entries once they are stored.
var (
	decodingStructTypeCache sync.Map // map[reflect.Type]*decodingStructType
	encodingStructTypeCache sync.Map // map[reflect.Type]*encodingStructType
	encodeFuncCache         sync.Map // map[reflect.Type]encodeFuncs
	typeInfoCache           sync.Map // map[reflect.Type]*typeInfo
)

// specialType classifies a type, after all pointer levels have been stripped,
// into one of the categories that newTypeInfo recognizes.
type specialType int

// Values assigned by newTypeInfo (checked in this order; the first match wins):
//
//   - specialTypeEmptyIface: an interface type with no methods.
//   - specialTypeIface: an interface type with at least one method.
//   - specialTypeTag: the type equals typeTag.
//   - specialTypeTime: the type equals typeTime.
//   - specialTypeUnmarshalerIface: a pointer to the type implements typeUnmarshaler.
//   - specialTypeNone: the zero value, left in place when none of the above apply.
const (
	specialTypeNone specialType = iota
	specialTypeUnmarshalerIface
	specialTypeEmptyIface
	specialTypeIface
	specialTypeTag
	specialTypeTime
)

// typeInfo holds reflection facts about a type that are computed once by
// newTypeInfo and then shared through typeInfoCache.
//
//   - typ and kind describe the type exactly as it was passed in.
//   - nonPtrType and nonPtrKind describe the type reached after following every
//     pointer level of typ.
//   - spclType is the specialType classification of nonPtrType.
//   - elemTypeInfo is set when nonPtrKind is Array, Slice or Map and describes
//     the element type.
//   - keyTypeInfo is set only when nonPtrKind is Map and describes the key type.
type typeInfo struct {
	elemTypeInfo *typeInfo
	keyTypeInfo  *typeInfo
	typ          reflect.Type
	kind         reflect.Kind
	nonPtrType   reflect.Type
	nonPtrKind   reflect.Kind
	spclType     specialType
}

// newTypeInfo builds the typeInfo for t without consulting typeInfoCache for t
// itself. Element and key types of arrays, slices and maps are resolved through
// getTypeInfo, so those nested lookups are cached.
//
// NOTE(review): for a self-referential container type (for example
// `type T []T`) the nested getTypeInfo call reaches this function again before
// anything has been stored for t, so the recursion looks unbounded. Confirm
// whether such types can reach this code path.
func newTypeInfo(t reflect.Type) *typeInfo {
	info := &typeInfo{typ: t, kind: t.Kind()}

	// Strip every level of pointer indirection.
	base := t
	for base.Kind() == reflect.Ptr {
		base = base.Elem()
	}
	baseKind := base.Kind()
	info.nonPtrType = base
	info.nonPtrKind = baseKind

	// Classify the underlying type; the first matching case wins and
	// specialTypeNone (the zero value) remains when nothing matches.
	switch {
	case baseKind == reflect.Interface && base.NumMethod() == 0:
		info.spclType = specialTypeEmptyIface
	case baseKind == reflect.Interface:
		info.spclType = specialTypeIface
	case base == typeTag:
		info.spclType = specialTypeTag
	case base == typeTime:
		info.spclType = specialTypeTime
	case reflect.PtrTo(base).Implements(typeUnmarshaler):
		info.spclType = specialTypeUnmarshalerIface
	}

	// Containers also record info about their key (maps only) and element types.
	if baseKind == reflect.Map {
		info.keyTypeInfo = getTypeInfo(base.Key())
	}
	if baseKind == reflect.Array || baseKind == reflect.Slice || baseKind == reflect.Map {
		info.elemTypeInfo = getTypeInfo(base.Elem())
	}

	return info
}

// decodingStructType is the per-struct decoding metadata built by
// getDecodingStructType and cached in decodingStructTypeCache.
//
//   - fields is the field list returned by getFields, with nameAsInt and
//     typInfo filled in by getDecodingStructType.
//   - fieldIndicesByName maps a field name to its index in fields; when several
//     fields share a name, the first one is kept.
//   - err holds the problems found while building the metadata (an unparsable
//     keyasint name, duplicate field names), or nil if there were none.
//   - toArray reports whether the struct options contain "toarray".
type decodingStructType struct {
	fields             fields
	fieldIndicesByName map[string]int
	err                error
	toArray            bool
}

// multierror is a minimal aggregate of several errors. It exists because
// errors.Join only arrived in Go 1.20 and this package still supports Go 1.17.
type multierror []error

// Error returns the messages of all contained errors, in order, separated by ", ".
func (m multierror) Error() string {
	msgs := make([]string, len(m))
	for i := range m {
		msgs[i] = m[i].Error()
	}
	return strings.Join(msgs, ", ")
}

// getDecodingStructType returns the decoding metadata for struct type t, taking
// it from decodingStructTypeCache when present and otherwise building and
// caching it.
//
// Problems found while building (an unparsable keyasint field name, duplicate
// field names) do not prevent a result from being returned; they are recorded
// in the result's err field, as a single error or as a multierror.
func getDecodingStructType(t reflect.Type) *decodingStructType {
	if cached, _ := decodingStructTypeCache.Load(t); cached != nil {
		return cached.(*decodingStructType)
	}

	flds, structOptions := getFields(t)

	var problems multierror

	// Resolve integer keys and attach type info to each field. Processing stops
	// at the first keyasint name that does not parse, so the remaining fields
	// are left without typInfo being assigned here.
	for i := range flds {
		if flds[i].keyAsInt {
			parsed, convErr := strconv.Atoi(flds[i].name)
			if convErr != nil {
				problems = append(problems, errors.New("cbor: failed to parse field name \""+flds[i].name+"\" to int ("+convErr.Error()+")"))
				break
			}
			flds[i].nameAsInt = int64(parsed)
		}

		flds[i].typInfo = getTypeInfo(flds[i].typ)
	}

	// Index fields by name; the first field with a given name wins and every
	// later duplicate is reported.
	indexByName := make(map[string]int, len(flds))
	for i := range flds {
		name := flds[i].name
		if _, seen := indexByName[name]; seen {
			problems = append(problems, fmt.Errorf("cbor: two or more fields of %v have the same name %q", t, name))
			continue
		}
		indexByName[name] = i
	}

	// Collapse the collected problems: none -> nil, one -> that error itself,
	// several -> the aggregate.
	var err error
	switch len(problems) {
	case 0:
		// Leave err as a nil interface.
	case 1:
		err = problems[0]
	default:
		err = problems
	}

	result := &decodingStructType{
		fields:             flds,
		fieldIndicesByName: indexByName,
		err:                err,
		toArray:            hasToArrayOption(structOptions),
	}
	decodingStructTypeCache.Store(t, result)
	return result
}

// encodingStructType is the per-struct encoding metadata cached in
// encodingStructTypeCache.
//
//   - fields is the field list in the order returned by getFields.
//   - bytewiseFields is a copy of fields sorted by bytewise comparison of each
//     field's encoded key (cborName).
//   - lengthFirstFields is fields sorted by encoded key length, then bytewise;
//     it may share its backing array with bytewiseFields.
//   - omitEmptyFieldsIdx lists the indices into fields of the fields that have
//     omitEmpty set.
//   - err is the error hit while building the metadata; when it is non-nil the
//     other members are left at their zero values.
//   - toArray is true for structs handled by getEncodingStructToArrayType, which
//     populates only fields and toArray.
type encodingStructType struct {
	fields             fields
	bytewiseFields     fields
	lengthFirstFields  fields
	omitEmptyFieldsIdx []int
	err                error
	toArray            bool
}

// getFields picks the field list that matches the sort mode configured in em:
// the unsorted list for SortNone and SortFastShuffle, the length-first list for
// SortLengthFirst, and the bytewise list for every other mode.
func (st *encodingStructType) getFields(em *encMode) fields {
	mode := em.sort
	if mode == SortNone || mode == SortFastShuffle {
		return st.fields
	}
	if mode == SortLengthFirst {
		return st.lengthFirstFields
	}
	return st.bytewiseFields
}

// bytewiseFieldSorter implements sort.Interface over a field list, ordering
// fields by bytewise comparison of their encoded keys (cborName).
type bytewiseFieldSorter struct {
	fields fields
}

// Len returns the number of fields being sorted.
func (x *bytewiseFieldSorter) Len() int {
	return len(x.fields)
}

// Swap exchanges the fields at indices i and j.
func (x *bytewiseFieldSorter) Swap(i, j int) {
	x.fields[i], x.fields[j] = x.fields[j], x.fields[i]
}

// Less reports whether field i's encoded key sorts strictly before field j's.
// The comparison is strict so that fields with equal keys compare as equal in
// both directions, as the sort.Interface contract expects.
func (x *bytewiseFieldSorter) Less(i, j int) bool {
	return bytes.Compare(x.fields[i].cborName, x.fields[j].cborName) < 0
}

// lengthFirstFieldSorter implements sort.Interface over a field list, ordering
// fields by the length of their encoded keys (cborName) and breaking ties with
// a bytewise comparison.
type lengthFirstFieldSorter struct {
	fields fields
}

// Len returns the number of fields being sorted.
func (x *lengthFirstFieldSorter) Len() int {
	return len(x.fields)
}

// Swap exchanges the fields at indices i and j.
func (x *lengthFirstFieldSorter) Swap(i, j int) {
	x.fields[i], x.fields[j] = x.fields[j], x.fields[i]
}

// Less reports whether field i's encoded key sorts strictly before field j's:
// a shorter key sorts first, and keys of equal length are compared bytewise.
// The comparison is strict so that fields with equal keys compare as equal in
// both directions, as the sort.Interface contract expects.
func (x *lengthFirstFieldSorter) Less(i, j int) bool {
	a, b := x.fields[i].cborName, x.fields[j].cborName
	if len(a) != len(b) {
		return len(a) < len(b)
	}
	return bytes.Compare(a, b) < 0
}

// getEncodingStructType returns the encoding metadata for struct type t together
// with the error recorded in it. The result comes from encodingStructTypeCache
// when present; otherwise it is built and cached, including on failure.
//
// Structs whose options contain "toarray" are delegated to
// getEncodingStructToArrayType. For every other struct, each field receives its
// encode/isEmpty functions and a pre-encoded map key (cborName), and the field
// list is stored unsorted, in bytewise key order and in length-first key order.
//
// The returned error is an *UnsupportedTypeError when a field has no encode
// function, or an error naming the field when a keyasint field name cannot be
// parsed as an int.
func getEncodingStructType(t reflect.Type) (*encodingStructType, error) {
	if v, _ := encodingStructTypeCache.Load(t); v != nil {
		structType := v.(*encodingStructType)
		return structType, structType.err
	}

	flds, structOptions := getFields(t)

	if hasToArrayOption(structOptions) {
		return getEncodingStructToArrayType(t, flds)
	}

	var err error
	// Track which kinds of encoded key occur, to decide below whether the
	// length-first order can differ from the bytewise order.
	var hasPosIntKey, hasNegIntKey, hasStrKey bool
	var omitEmptyIdx []int
	e := getEncodeBuffer() // scratch buffer for encoding key heads
	for i := 0; i < len(flds); i++ {
		// Get field's encodeFunc
		flds[i].ef, flds[i].ief = getEncodeFunc(flds[i].typ)
		if flds[i].ef == nil {
			err = &UnsupportedTypeError{t}
			break
		}

		// Encode field name
		if flds[i].keyAsInt {
			nameAsInt, numErr := strconv.Atoi(flds[i].name)
			if numErr != nil {
				err = errors.New("cbor: failed to parse field name \"" + flds[i].name + "\" to int (" + numErr.Error() + ")")
				break
			}
			flds[i].nameAsInt = int64(nameAsInt)
			if nameAsInt >= 0 {
				encodeHead(e, byte(cborTypePositiveInt), uint64(nameAsInt))
				hasPosIntKey = true
			} else {
				// A negative integer v is encoded with the argument -1-v.
				n := nameAsInt*(-1) - 1
				encodeHead(e, byte(cborTypeNegativeInt), uint64(n))
				hasNegIntKey = true
			}
			flds[i].cborName = make([]byte, e.Len())
			copy(flds[i].cborName, e.Bytes())
			e.Reset()
		} else {
			// cborName is the text string head followed by the raw name bytes.
			encodeHead(e, byte(cborTypeTextString), uint64(len(flds[i].name)))
			flds[i].cborName = make([]byte, e.Len()+len(flds[i].name))
			n := copy(flds[i].cborName, e.Bytes())
			copy(flds[i].cborName[n:], flds[i].name)
			e.Reset()

			// If cborName contains a text string, then cborNameByteString contains a
			// string that has the byte string major type but is otherwise identical to
			// cborName.
			flds[i].cborNameByteString = make([]byte, len(flds[i].cborName))
			copy(flds[i].cborNameByteString, flds[i].cborName)
			// Reset encoded CBOR type to byte string, preserving the "additional
			// information" bits:
			flds[i].cborNameByteString[0] = byte(cborTypeByteString) |
				getAdditionalInformation(flds[i].cborNameByteString[0])

			hasStrKey = true
		}

		// Check if field can be omitted when empty
		if flds[i].omitEmpty {
			omitEmptyIdx = append(omitEmptyIdx, i)
		}
	}
	putEncodeBuffer(e)

	if err != nil {
		// Cache the failure so the same type is not analyzed again.
		structType := &encodingStructType{err: err}
		encodingStructTypeCache.Store(t, structType)
		return structType, structType.err
	}

	// Sort fields by canonical order
	bytewiseFields := make(fields, len(flds))
	copy(bytewiseFields, flds)
	sort.Sort(&bytewiseFieldSorter{bytewiseFields})

	// Within one CBOR major type a longer encoding always compares bytewise
	// greater, so the bytewise order is also the length-first order and the
	// slice can be shared. Once keys of different major types are mixed that no
	// longer holds: e.g. -1 encodes as 0x20 (1 byte) and 24 as 0x18 0x18
	// (2 bytes), so bytewise puts 24 first while length-first puts -1 first.
	// This applies to negative/positive integer mixes as well as to
	// integer/text mixes.
	mixedKeyTypes := (hasPosIntKey && hasNegIntKey) ||
		(hasStrKey && (hasPosIntKey || hasNegIntKey))

	lengthFirstFields := bytewiseFields
	if mixedKeyTypes {
		lengthFirstFields = make(fields, len(flds))
		copy(lengthFirstFields, flds)
		sort.Sort(&lengthFirstFieldSorter{lengthFirstFields})
	}

	structType := &encodingStructType{
		fields:             flds,
		bytewiseFields:     bytewiseFields,
		lengthFirstFields:  lengthFirstFields,
		omitEmptyFieldsIdx: omitEmptyIdx,
	}

	encodingStructTypeCache.Store(t, structType)
	return structType, structType.err
}

// getEncodingStructToArrayType builds and caches the encoding metadata for a
// struct type t whose options contain "toarray". Each field in flds is given its
// encode/isEmpty functions. If any field has no encode function, the cached and
// returned metadata carries only an *UnsupportedTypeError for t.
func getEncodingStructToArrayType(t reflect.Type, flds fields) (*encodingStructType, error) {
	result := &encodingStructType{
		fields:  flds,
		toArray: true,
	}

	for i := range flds {
		// Get field's encodeFunc
		flds[i].ef, flds[i].ief = getEncodeFunc(flds[i].typ)
		if flds[i].ef == nil {
			// Replace the result with an error-only entry and stop looking.
			result = &encodingStructType{err: &UnsupportedTypeError{t}}
			break
		}
	}

	encodingStructTypeCache.Store(t, result)
	return result, result.err
}

// getEncodeFunc returns the encodeFunc and isEmptyFunc for type t. The pair is
// taken from encodeFuncCache when present; otherwise it is obtained from
// getEncodeFuncInternal and stored in the cache before being returned.
func getEncodeFunc(t reflect.Type) (encodeFunc, isEmptyFunc) {
	if cached, _ := encodeFuncCache.Load(t); cached != nil {
		pair := cached.(encodeFuncs)
		return pair.ef, pair.ief
	}

	var fresh encodeFuncs
	fresh.ef, fresh.ief = getEncodeFuncInternal(t)
	encodeFuncCache.Store(t, fresh)
	return fresh.ef, fresh.ief
}

// getTypeInfo returns the typeInfo for t, building it with newTypeInfo and
// storing it in typeInfoCache the first time t is requested.
func getTypeInfo(t reflect.Type) *typeInfo {
	cached, _ := typeInfoCache.Load(t)
	if cached == nil {
		info := newTypeInfo(t)
		typeInfoCache.Store(t, info)
		return info
	}
	return cached.(*typeInfo)
}

// hasToArrayOption reports whether tag contains the option "toarray", i.e. the
// text ",toarray" followed either by the end of the tag or by another comma.
//
// Every occurrence of ",toarray" is examined, so an earlier option that merely
// starts with "toarray" (such as ",toarrayx") does not hide a real "toarray"
// option later in the tag. A tag with no leading comma before "toarray" does
// not match, because the first comma-separated element is not an option.
func hasToArrayOption(tag string) bool {
	const opt = ",toarray"
	for {
		idx := strings.Index(tag, opt)
		if idx < 0 {
			return false
		}
		rest := tag[idx+len(opt):]
		if rest == "" || rest[0] == ',' {
			return true
		}
		// Only a prefix matched; keep scanning after this occurrence.
		tag = rest
	}
}