/*-
 * Copyright 2014 Square Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package jose

import (
	"crypto/ecdsa"
	"crypto/rsa"
	"errors"
	"fmt"
	"reflect"

	"gopkg.in/square/go-jose.v2/json"
)

// Encrypter represents an encrypter which produces an encrypted JWE object.
type Encrypter interface {
	Encrypt(plaintext []byte) (*JSONWebEncryption, error)
	EncryptWithAuthData(plaintext []byte, aad []byte) (*JSONWebEncryption, error)
	Options() EncrypterOptions
}

// A generic content cipher
type contentCipher interface {
	keySize() int
	encrypt(cek []byte, aad, plaintext []byte) (*aeadParts, error)
	decrypt(cek []byte, aad []byte, parts *aeadParts) ([]byte, error)
}

// A key generator (for generating/getting a CEK)
type keyGenerator interface {
	keySize() int
	genKey() ([]byte, rawHeader, error)
}

// A generic key encrypter
type keyEncrypter interface {
	encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) // Encrypt a key
}

// A generic key decrypter
type keyDecrypter interface {
	decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) // Decrypt a key
}

// A generic encrypter based on the given key encrypter and content cipher.
type genericEncrypter struct {
	contentAlg     ContentEncryption
	compressionAlg CompressionAlgorithm
	cipher         contentCipher
	recipients     []recipientKeyInfo
	keyGenerator   keyGenerator
	extraHeaders   map[HeaderKey]interface{}
}

type recipientKeyInfo struct {
	keyID        string
	keyAlg       KeyAlgorithm
	keyEncrypter keyEncrypter
}

// EncrypterOptions represents options that can be set on new encrypters.
type EncrypterOptions struct {
	Compression CompressionAlgorithm

	// Optional map of additional keys to be inserted into the protected header
	// of a JWS object. Some specifications which make use of JWS like to insert
	// additional values here. All values must be JSON-serializable.
	ExtraHeaders map[HeaderKey]interface{}
}

// WithHeader adds an arbitrary value to the ExtraHeaders map, initializing it
// if necessary. It returns itself and so can be used in a fluent style.
func (eo *EncrypterOptions) WithHeader(k HeaderKey, v interface{}) *EncrypterOptions {
	if eo.ExtraHeaders == nil {
		eo.ExtraHeaders = map[HeaderKey]interface{}{}
	}
	eo.ExtraHeaders[k] = v
	return eo
}

// WithContentType adds a content type ("cty") header and returns the updated
// EncrypterOptions.
func (eo *EncrypterOptions) WithContentType(contentType ContentType) *EncrypterOptions {
	return eo.WithHeader(HeaderContentType, contentType)
}

// WithType adds a type ("typ") header and returns the updated EncrypterOptions.
func (eo *EncrypterOptions) WithType(typ ContentType) *EncrypterOptions {
	return eo.WithHeader(HeaderType, typ)
}

// Recipient represents an algorithm/key to encrypt messages to.
//
// PBES2Count and PBES2Salt correspond with the  "p2c" and "p2s" headers used
// on the password-based encryption algorithms PBES2-HS256+A128KW,
// PBES2-HS384+A192KW, and PBES2-HS512+A256KW. If they are not provided a safe
// default of 100000 will be used for the count and a 128-bit random salt will
// be generated.
type Recipient struct {
	Algorithm  KeyAlgorithm
	Key        interface{}
	KeyID      string
	PBES2Count int
	PBES2Salt  []byte
}

// NewEncrypter creates an appropriate encrypter based on the key type
func NewEncrypter(enc ContentEncryption, rcpt Recipient, opts *EncrypterOptions) (Encrypter, error) {
	encrypter := &genericEncrypter{
		contentAlg: enc,
		recipients: []recipientKeyInfo{},
		cipher:     getContentCipher(enc),
	}
	if opts != nil {
		encrypter.compressionAlg = opts.Compression
		encrypter.extraHeaders = opts.ExtraHeaders
	}

	if encrypter.cipher == nil {
		return nil, ErrUnsupportedAlgorithm
	}

	var keyID string
	var rawKey interface{}
	switch encryptionKey := rcpt.Key.(type) {
	case JSONWebKey:
		keyID, rawKey = encryptionKey.KeyID, encryptionKey.Key
	case *JSONWebKey:
		keyID, rawKey = encryptionKey.KeyID, encryptionKey.Key
	case OpaqueKeyEncrypter:
		keyID, rawKey = encryptionKey.KeyID(), encryptionKey
	default:
		rawKey = encryptionKey
	}

	switch rcpt.Algorithm {
	case DIRECT:
		// Direct encryption mode must be treated differently
		if reflect.TypeOf(rawKey) != reflect.TypeOf([]byte{}) {
			return nil, ErrUnsupportedKeyType
		}
		if encrypter.cipher.keySize() != len(rawKey.([]byte)) {
			return nil, ErrInvalidKeySize
		}
		encrypter.keyGenerator = staticKeyGenerator{
			key: rawKey.([]byte),
		}
		recipientInfo, _ := newSymmetricRecipient(rcpt.Algorithm, rawKey.([]byte))
		recipientInfo.keyID = keyID
		if rcpt.KeyID != "" {
			recipientInfo.keyID = rcpt.KeyID
		}
		encrypter.recipients = []recipientKeyInfo{recipientInfo}
		return encrypter, nil
	case ECDH_ES:
		// ECDH-ES (w/o key wrapping) is similar to DIRECT mode
		typeOf := reflect.TypeOf(rawKey)
		if typeOf != reflect.TypeOf(&ecdsa.PublicKey{}) {
			return nil, ErrUnsupportedKeyType
		}
		encrypter.keyGenerator = ecKeyGenerator{
			size:      encrypter.cipher.keySize(),
			algID:     string(enc),
			publicKey: rawKey.(*ecdsa.PublicKey),
		}
		recipientInfo, _ := newECDHRecipient(rcpt.Algorithm, rawKey.(*ecdsa.PublicKey))
		recipientInfo.keyID = keyID
		if rcpt.KeyID != "" {
			recipientInfo.keyID = rcpt.KeyID
		}
		encrypter.recipients = []recipientKeyInfo{recipientInfo}
		return encrypter, nil
	default:
		// Can just add a standard recipient
		encrypter.keyGenerator = randomKeyGenerator{
			size: encrypter.cipher.keySize(),
		}
		err := encrypter.addRecipient(rcpt)
		return encrypter, err
	}
}

// NewMultiEncrypter creates a multi-encrypter based on the given parameters
func NewMultiEncrypter(enc ContentEncryption, rcpts []Recipient, opts *EncrypterOptions) (Encrypter, error) {
	cipher := getContentCipher(enc)

	if cipher == nil {
		return nil, ErrUnsupportedAlgorithm
	}
	if rcpts == nil || len(rcpts) == 0 {
		return nil, fmt.Errorf("square/go-jose: recipients is nil or empty")
	}

	encrypter := &genericEncrypter{
		contentAlg: enc,
		recipients: []recipientKeyInfo{},
		cipher:     cipher,
		keyGenerator: randomKeyGenerator{
			size: cipher.keySize(),
		},
	}

	if opts != nil {
		encrypter.compressionAlg = opts.Compression
	}

	for _, recipient := range rcpts {
		err := encrypter.addRecipient(recipient)
		if err != nil {
			return nil, err
		}
	}

	return encrypter, nil
}

func (ctx *genericEncrypter) addRecipient(recipient Recipient) (err error) {
	var recipientInfo recipientKeyInfo

	switch recipient.Algorithm {
	case DIRECT, ECDH_ES:
		return fmt.Errorf("square/go-jose: key algorithm '%s' not supported in multi-recipient mode", recipient.Algorithm)
	}

	recipientInfo, err = makeJWERecipient(recipient.Algorithm, recipient.Key)
	if recipient.KeyID != "" {
		recipientInfo.keyID = recipient.KeyID
	}

	switch recipient.Algorithm {
	case PBES2_HS256_A128KW, PBES2_HS384_A192KW, PBES2_HS512_A256KW:
		if sr, ok := recipientInfo.keyEncrypter.(*symmetricKeyCipher); ok {
			sr.p2c = recipient.PBES2Count
			sr.p2s = recipient.PBES2Salt
		}
	}

	if err == nil {
		ctx.recipients = append(ctx.recipients, recipientInfo)
	}
	return err
}

func makeJWERecipient(alg KeyAlgorithm, encryptionKey interface{}) (recipientKeyInfo, error) {
	switch encryptionKey := encryptionKey.(type) {
	case *rsa.PublicKey:
		return newRSARecipient(alg, encryptionKey)
	case *ecdsa.PublicKey:
		return newECDHRecipient(alg, encryptionKey)
	case []byte:
		return newSymmetricRecipient(alg, encryptionKey)
	case string:
		return newSymmetricRecipient(alg, []byte(encryptionKey))
	case *JSONWebKey:
		recipient, err := makeJWERecipient(alg, encryptionKey.Key)
		recipient.keyID = encryptionKey.KeyID
		return recipient, err
	}
	if encrypter, ok := encryptionKey.(OpaqueKeyEncrypter); ok {
		return newOpaqueKeyEncrypter(alg, encrypter)
	}
	return recipientKeyInfo{}, ErrUnsupportedKeyType
}

// newDecrypter creates an appropriate decrypter based on the key type
func newDecrypter(decryptionKey interface{}) (keyDecrypter, error) {
	switch decryptionKey := decryptionKey.(type) {
	case *rsa.PrivateKey:
		return &rsaDecrypterSigner{
			privateKey: decryptionKey,
		}, nil
	case *ecdsa.PrivateKey:
		return &ecDecrypterSigner{
			privateKey: decryptionKey,
		}, nil
	case []byte:
		return &symmetricKeyCipher{
			key: decryptionKey,
		}, nil
	case string:
		return &symmetricKeyCipher{
			key: []byte(decryptionKey),
		}, nil
	case JSONWebKey:
		return newDecrypter(decryptionKey.Key)
	case *JSONWebKey:
		return newDecrypter(decryptionKey.Key)
	}
	if okd, ok := decryptionKey.(OpaqueKeyDecrypter); ok {
		return &opaqueKeyDecrypter{decrypter: okd}, nil
	}
	return nil, ErrUnsupportedKeyType
}

// Implementation of encrypt method producing a JWE object.
func (ctx *genericEncrypter) Encrypt(plaintext []byte) (*JSONWebEncryption, error) {
	return ctx.EncryptWithAuthData(plaintext, nil)
}

// Implementation of encrypt method producing a JWE object.
func (ctx *genericEncrypter) EncryptWithAuthData(plaintext, aad []byte) (*JSONWebEncryption, error) {
	obj := &JSONWebEncryption{}
	obj.aad = aad

	obj.protected = &rawHeader{}
	err := obj.protected.set(headerEncryption, ctx.contentAlg)
	if err != nil {
		return nil, err
	}

	obj.recipients = make([]recipientInfo, len(ctx.recipients))

	if len(ctx.recipients) == 0 {
		return nil, fmt.Errorf("square/go-jose: no recipients to encrypt to")
	}

	cek, headers, err := ctx.keyGenerator.genKey()
	if err != nil {
		return nil, err
	}

	obj.protected.merge(&headers)

	for i, info := range ctx.recipients {
		recipient, err := info.keyEncrypter.encryptKey(cek, info.keyAlg)
		if err != nil {
			return nil, err
		}

		err = recipient.header.set(headerAlgorithm, info.keyAlg)
		if err != nil {
			return nil, err
		}

		if info.keyID != "" {
			err = recipient.header.set(headerKeyID, info.keyID)
			if err != nil {
				return nil, err
			}
		}
		obj.recipients[i] = recipient
	}

	if len(ctx.recipients) == 1 {
		// Move per-recipient headers into main protected header if there's
		// only a single recipient.
		obj.protected.merge(obj.recipients[0].header)
		obj.recipients[0].header = nil
	}

	if ctx.compressionAlg != NONE {
		plaintext, err = compress(ctx.compressionAlg, plaintext)
		if err != nil {
			return nil, err
		}

		err = obj.protected.set(headerCompression, ctx.compressionAlg)
		if err != nil {
			return nil, err
		}
	}

	for k, v := range ctx.extraHeaders {
		b, err := json.Marshal(v)
		if err != nil {
			return nil, err
		}
		(*obj.protected)[k] = makeRawMessage(b)
	}

	authData := obj.computeAuthData()
	parts, err := ctx.cipher.encrypt(cek, authData, plaintext)
	if err != nil {
		return nil, err
	}

	obj.iv = parts.iv
	obj.ciphertext = parts.ciphertext
	obj.tag = parts.tag

	return obj, nil
}

func (ctx *genericEncrypter) Options() EncrypterOptions {
	return EncrypterOptions{
		Compression:  ctx.compressionAlg,
		ExtraHeaders: ctx.extraHeaders,
	}
}

// Decrypt and validate the object and return the plaintext. Note that this
// function does not support multi-recipient, if you desire multi-recipient
// decryption use DecryptMulti instead.
func (obj JSONWebEncryption) Decrypt(decryptionKey interface{}) ([]byte, error) {
	headers := obj.mergedHeaders(nil)

	if len(obj.recipients) > 1 {
		return nil, errors.New("square/go-jose: too many recipients in payload; expecting only one")
	}

	critical, err := headers.getCritical()
	if err != nil {
		return nil, fmt.Errorf("square/go-jose: invalid crit header")
	}

	if len(critical) > 0 {
		return nil, fmt.Errorf("square/go-jose: unsupported crit header")
	}

	decrypter, err := newDecrypter(decryptionKey)
	if err != nil {
		return nil, err
	}

	cipher := getContentCipher(headers.getEncryption())
	if cipher == nil {
		return nil, fmt.Errorf("square/go-jose: unsupported enc value '%s'", string(headers.getEncryption()))
	}

	generator := randomKeyGenerator{
		size: cipher.keySize(),
	}

	parts := &aeadParts{
		iv:         obj.iv,
		ciphertext: obj.ciphertext,
		tag:        obj.tag,
	}

	authData := obj.computeAuthData()

	var plaintext []byte
	recipient := obj.recipients[0]
	recipientHeaders := obj.mergedHeaders(&recipient)

	cek, err := decrypter.decryptKey(recipientHeaders, &recipient, generator)
	if err == nil {
		// Found a valid CEK -- let's try to decrypt.
		plaintext, err = cipher.decrypt(cek, authData, parts)
	}

	if plaintext == nil {
		return nil, ErrCryptoFailure
	}

	// The "zip" header parameter may only be present in the protected header.
	if comp := obj.protected.getCompression(); comp != "" {
		plaintext, err = decompress(comp, plaintext)
	}

	return plaintext, err
}

// DecryptMulti decrypts and validates the object and returns the plaintexts,
// with support for multiple recipients. It returns the index of the recipient
// for which the decryption was successful, the merged headers for that recipient,
// and the plaintext.
func (obj JSONWebEncryption) DecryptMulti(decryptionKey interface{}) (int, Header, []byte, error) {
	globalHeaders := obj.mergedHeaders(nil)

	critical, err := globalHeaders.getCritical()
	if err != nil {
		return -1, Header{}, nil, fmt.Errorf("square/go-jose: invalid crit header")
	}

	if len(critical) > 0 {
		return -1, Header{}, nil, fmt.Errorf("square/go-jose: unsupported crit header")
	}

	decrypter, err := newDecrypter(decryptionKey)
	if err != nil {
		return -1, Header{}, nil, err
	}

	encryption := globalHeaders.getEncryption()
	cipher := getContentCipher(encryption)
	if cipher == nil {
		return -1, Header{}, nil, fmt.Errorf("square/go-jose: unsupported enc value '%s'", string(encryption))
	}

	generator := randomKeyGenerator{
		size: cipher.keySize(),
	}

	parts := &aeadParts{
		iv:         obj.iv,
		ciphertext: obj.ciphertext,
		tag:        obj.tag,
	}

	authData := obj.computeAuthData()

	index := -1
	var plaintext []byte
	var headers rawHeader

	for i, recipient := range obj.recipients {
		recipientHeaders := obj.mergedHeaders(&recipient)

		cek, err := decrypter.decryptKey(recipientHeaders, &recipient, generator)
		if err == nil {
			// Found a valid CEK -- let's try to decrypt.
			plaintext, err = cipher.decrypt(cek, authData, parts)
			if err == nil {
				index = i
				headers = recipientHeaders
				break
			}
		}
	}

	if plaintext == nil || err != nil {
		return -1, Header{}, nil, ErrCryptoFailure
	}

	// The "zip" header parameter may only be present in the protected header.
	if comp := obj.protected.getCompression(); comp != "" {
		plaintext, err = decompress(comp, plaintext)
	}

	sanitized, err := headers.sanitized()
	if err != nil {
		return -1, Header{}, nil, fmt.Errorf("square/go-jose: failed to sanitize header: %v", err)
	}

	return index, sanitized, plaintext, err
}