/*-
 * Copyright 2014 Square Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package jose

import (
	"crypto/elliptic"
	"crypto/x509"
	"encoding/base64"
	"errors"
	"fmt"

	"gopkg.in/square/go-jose.v2/json"
)

// Named string types used throughout the package. Keeping them distinct lets
// the compiler reject, for example, a SignatureAlgorithm passed where a
// KeyAlgorithm is expected.
type (
	// KeyAlgorithm represents a key management algorithm.
	KeyAlgorithm string

	// SignatureAlgorithm represents a signature (or MAC) algorithm.
	SignatureAlgorithm string

	// ContentEncryption represents a content encryption algorithm.
	ContentEncryption string

	// CompressionAlgorithm represents an algorithm used for plaintext compression.
	CompressionAlgorithm string

	// ContentType represents the type of the contained data.
	ContentType string
)

// ErrCryptoFailure represents an error in a cryptographic primitive. This
// occurs when, for example, a message had an invalid authentication tag or
// could not be decrypted.
var ErrCryptoFailure = errors.New("square/go-jose: error in cryptographic primitive")

// ErrUnsupportedAlgorithm indicates that a selected algorithm is not
// supported, e.g. when instantiating an encrypter for an algorithm that is
// not yet implemented.
var ErrUnsupportedAlgorithm = errors.New("square/go-jose: unknown/unsupported algorithm")

// ErrUnsupportedKeyType indicates that the given key type/format is not
// supported, e.g. when an encrypter is given a key of an unrecognized type or
// with unsupported parameters, such as an RSA private key with more than two
// primes.
var ErrUnsupportedKeyType = errors.New("square/go-jose: unsupported key type/format")

// ErrInvalidKeySize indicates that the given key is not the correct size for
// the selected algorithm, e.g. encrypting with AES-256 while supplying only a
// 128-bit key.
var ErrInvalidKeySize = errors.New("square/go-jose: invalid key size for algorithm")

// ErrNotSupported indicates that an object cannot be serialized in the
// requested form. This occurs when trying to compact-serialize an object that
// can't be represented in compact form.
var ErrNotSupported = errors.New("square/go-jose: compact serialization not supported for object")

// ErrUnprotectedNonce indicates that, while parsing a JWS or JWE object, a
// nonce header parameter was found in an unprotected header object.
var ErrUnprotectedNonce = errors.New("square/go-jose: Nonce parameter included in unprotected header")

// Key management algorithms, identified by their "alg" header values.
const (
	ED25519            KeyAlgorithm = "ED25519"            // Ed25519 (NOTE(review): presumably an internal key identifier rather than a JWE "alg" value — confirm)
	RSA1_5             KeyAlgorithm = "RSA1_5"             // RSA-PKCS1v1.5
	RSA_OAEP           KeyAlgorithm = "RSA-OAEP"           // RSA-OAEP-SHA1
	RSA_OAEP_256       KeyAlgorithm = "RSA-OAEP-256"       // RSA-OAEP-SHA256
	A128KW             KeyAlgorithm = "A128KW"             // AES key wrap (128)
	A192KW             KeyAlgorithm = "A192KW"             // AES key wrap (192)
	A256KW             KeyAlgorithm = "A256KW"             // AES key wrap (256)
	DIRECT             KeyAlgorithm = "dir"                // Direct encryption
	ECDH_ES            KeyAlgorithm = "ECDH-ES"            // ECDH-ES
	ECDH_ES_A128KW     KeyAlgorithm = "ECDH-ES+A128KW"     // ECDH-ES + AES key wrap (128)
	ECDH_ES_A192KW     KeyAlgorithm = "ECDH-ES+A192KW"     // ECDH-ES + AES key wrap (192)
	ECDH_ES_A256KW     KeyAlgorithm = "ECDH-ES+A256KW"     // ECDH-ES + AES key wrap (256)
	A128GCMKW          KeyAlgorithm = "A128GCMKW"          // AES-GCM key wrap (128)
	A192GCMKW          KeyAlgorithm = "A192GCMKW"          // AES-GCM key wrap (192)
	A256GCMKW          KeyAlgorithm = "A256GCMKW"          // AES-GCM key wrap (256)
	PBES2_HS256_A128KW KeyAlgorithm = "PBES2-HS256+A128KW" // PBES2 + HMAC-SHA256 + AES key wrap (128)
	PBES2_HS384_A192KW KeyAlgorithm = "PBES2-HS384+A192KW" // PBES2 + HMAC-SHA384 + AES key wrap (192)
	PBES2_HS512_A256KW KeyAlgorithm = "PBES2-HS512+A256KW" // PBES2 + HMAC-SHA512 + AES key wrap (256)
)

// Signature (and MAC) algorithms, identified by their "alg" header values.
const (
	EdDSA SignatureAlgorithm = "EdDSA" // Edwards-curve DSA (RFC 8037)
	HS256 SignatureAlgorithm = "HS256" // HMAC using SHA-256
	HS384 SignatureAlgorithm = "HS384" // HMAC using SHA-384
	HS512 SignatureAlgorithm = "HS512" // HMAC using SHA-512
	RS256 SignatureAlgorithm = "RS256" // RSASSA-PKCS-v1.5 using SHA-256
	RS384 SignatureAlgorithm = "RS384" // RSASSA-PKCS-v1.5 using SHA-384
	RS512 SignatureAlgorithm = "RS512" // RSASSA-PKCS-v1.5 using SHA-512
	ES256 SignatureAlgorithm = "ES256" // ECDSA using P-256 and SHA-256
	ES384 SignatureAlgorithm = "ES384" // ECDSA using P-384 and SHA-384
	ES512 SignatureAlgorithm = "ES512" // ECDSA using P-521 and SHA-512
	PS256 SignatureAlgorithm = "PS256" // RSASSA-PSS using SHA256 and MGF1-SHA256
	PS384 SignatureAlgorithm = "PS384" // RSASSA-PSS using SHA384 and MGF1-SHA384
	PS512 SignatureAlgorithm = "PS512" // RSASSA-PSS using SHA512 and MGF1-SHA512
)

// Content encryption algorithms, identified by their "enc" header values.
const (
	A128CBC_HS256 ContentEncryption = "A128CBC-HS256" // AES-CBC + HMAC-SHA256 (128)
	A192CBC_HS384 ContentEncryption = "A192CBC-HS384" // AES-CBC + HMAC-SHA384 (192)
	A256CBC_HS512 ContentEncryption = "A256CBC-HS512" // AES-CBC + HMAC-SHA512 (256)
	A128GCM       ContentEncryption = "A128GCM"       // AES-GCM (128)
	A192GCM       ContentEncryption = "A192GCM"       // AES-GCM (192)
	A256GCM       ContentEncryption = "A256GCM"       // AES-GCM (256)
)

// Compression algorithms, identified by their "zip" header values. The empty
// value means no compression.
const (
	NONE    CompressionAlgorithm = ""    // No compression
	DEFLATE CompressionAlgorithm = "DEF" // DEFLATE (RFC 1951)
)

// HeaderKey is a key in the protected header of a JWS object. Use of the
// Header... constants is preferred to enhance type safety.
type HeaderKey string

// Well-known JOSE header parameter names. The trailing comment on each entry
// records the Go type associated with that parameter's value.
//
// NOTE(review): only HeaderType is declared with the explicit HeaderKey type.
// Every other entry supplies its own value without a type, and Go repeats a
// previous type only when the expression is omitted, so those entries are
// untyped string constants. They still convert implicitly wherever a
// HeaderKey is expected (e.g. rawHeader lookups), but HeaderContentType is not
// itself a HeaderKey. Retyping it would be an exported API change.
const (
	HeaderType        HeaderKey = "typ" // string
	HeaderContentType           = "cty" // string

	// These are set by go-jose and shouldn't need to be set by consumers of the
	// library.
	headerAlgorithm   = "alg"  // string
	headerEncryption  = "enc"  // ContentEncryption
	headerCompression = "zip"  // CompressionAlgorithm
	headerCritical    = "crit" // []string

	headerAPU = "apu" // *byteBuffer
	headerAPV = "apv" // *byteBuffer
	headerEPK = "epk" // *JSONWebKey
	headerIV  = "iv"  // *byteBuffer
	headerTag = "tag" // *byteBuffer
	headerX5c = "x5c" // []*x509.Certificate (encoded as []string of base64 DER)

	headerJWK   = "jwk"   // *JSONWebKey
	headerKeyID = "kid"   // string
	headerNonce = "nonce" // string

	headerP2C = "p2c" // int (decoded directly into an int, not a byteBuffer)
	headerP2S = "p2s" // *byteBuffer ([]byte)

)

// rawHeader represents the JOSE header for JWE/JWS objects (used for parsing).
//
// The decoding of the constituent items is deferred because we want to marshal
// some members into particular structs rather than generic maps, but at the
// same time we need to receive any extra fields unhandled by this library to
// pass through to consuming code in case it wants to examine them. A missing
// key and a key mapped to a nil *json.RawMessage are both treated as "absent"
// by the accessor methods.
type rawHeader map[HeaderKey]*json.RawMessage

// Header represents the read-only JOSE header for JWE/JWS objects. It is
// populated from a rawHeader by its sanitized method.
type Header struct {
	KeyID      string      // value of the "kid" header
	JSONWebKey *JSONWebKey // value of the "jwk" header, nil if absent
	Algorithm  string      // value of the "alg" header
	Nonce      string      // value of the "nonce" header

	// Unverified certificate chain parsed from x5c header. Callers obtain a
	// verified chain through Certificates.
	certificates []*x509.Certificate

	// Any headers not recognised above get unmarshaled
	// from JSON in a generic manner and placed in this map.
	ExtraHeaders map[HeaderKey]interface{}
}

// Certificates verifies the certificate chain carried in the message's x5c
// header and returns the verified chains produced by x509.Certificate.Verify.
//
// The first x5c certificate is treated as the leaf. When opts.Intermediates
// is nil, the remaining x5c certificates are placed in a fresh pool and used
// as intermediates; a caller-supplied pool is used as-is and the extra x5c
// certificates are ignored. An error is returned when no x5c header was
// present or when verification fails.
func (h Header) Certificates(opts x509.VerifyOptions) ([][]*x509.Certificate, error) {
	chain := h.certificates
	if len(chain) == 0 {
		return nil, errors.New("square/go-jose: no x5c header present in message")
	}

	if opts.Intermediates == nil {
		pool := x509.NewCertPool()
		for i := 1; i < len(chain); i++ {
			pool.AddCert(chain[i])
		}
		opts.Intermediates = pool
	}

	return chain[0].Verify(opts)
}

// set JSON-encodes v and stores the encoding under k, replacing any previous
// value. If encoding fails, the marshal error is returned and the header is
// left unchanged.
func (parsed rawHeader) set(k HeaderKey, v interface{}) error {
	encoded, err := json.Marshal(v)
	if err != nil {
		return err
	}

	raw := json.RawMessage(encoded)
	parsed[k] = &raw
	return nil
}

// getString decodes the value stored under k as a JSON string. It returns ""
// when k is absent, maps to nil, or holds something that is not a JSON string;
// decoding errors are deliberately swallowed.
func (parsed rawHeader) getString(k HeaderKey) string {
	raw := parsed[k]
	if raw == nil {
		return ""
	}

	var s string
	if err := json.Unmarshal(*raw, &s); err != nil {
		return ""
	}
	return s
}

// getByteBuffer decodes the value stored under k as a *byteBuffer. It returns
// (nil, nil) when k is absent or maps to nil, and (nil, err) when the JSON
// cannot be decoded.
func (parsed rawHeader) getByteBuffer(k HeaderKey) (*byteBuffer, error) {
	raw, present := parsed[k]
	if !present || raw == nil {
		return nil, nil
	}

	var buf *byteBuffer
	if err := json.Unmarshal(*raw, &buf); err != nil {
		return nil, err
	}
	return buf, nil
}

// getAlgorithm returns the "alg" header interpreted as a KeyAlgorithm. The
// result is empty when the header is absent or is not a JSON string.
func (parsed rawHeader) getAlgorithm() KeyAlgorithm {
	alg := parsed.getString(headerAlgorithm)
	return KeyAlgorithm(alg)
}

// getSignatureAlgorithm returns the "alg" header interpreted as a
// SignatureAlgorithm. The result is empty when the header is absent or is not
// a JSON string.
func (parsed rawHeader) getSignatureAlgorithm() SignatureAlgorithm {
	alg := parsed.getString(headerAlgorithm)
	return SignatureAlgorithm(alg)
}

// getEncryption returns the "enc" header as a ContentEncryption, or "" when
// the header is absent or is not a JSON string.
func (parsed rawHeader) getEncryption() ContentEncryption {
	enc := parsed.getString(headerEncryption)
	return ContentEncryption(enc)
}

// getCompression returns the "zip" header as a CompressionAlgorithm. An absent
// or non-string header yields "", which is the NONE value.
func (parsed rawHeader) getCompression() CompressionAlgorithm {
	zip := parsed.getString(headerCompression)
	return CompressionAlgorithm(zip)
}

// getNonce returns the "nonce" header, or "" when it is absent or is not a
// JSON string.
func (parsed rawHeader) getNonce() string {
	nonce := parsed.getString(headerNonce)
	return nonce
}

// getJSONWebKey decodes the value stored under k as a *JSONWebKey. It returns
// (nil, nil) when k is absent or maps to nil, and (nil, err) when the JSON
// cannot be decoded. It is the JSONWebKey counterpart of getByteBuffer and
// backs both getEPK and getJWK, which previously duplicated this logic.
func (parsed rawHeader) getJSONWebKey(k HeaderKey) (*JSONWebKey, error) {
	v := parsed[k]
	if v == nil {
		return nil, nil
	}
	var jwk *JSONWebKey
	if err := json.Unmarshal(*v, &jwk); err != nil {
		return nil, err
	}
	return jwk, nil
}

// getEPK extracts the parsed "epk" header from the raw JSON. It returns
// (nil, nil) when the header is absent.
func (parsed rawHeader) getEPK() (*JSONWebKey, error) {
	return parsed.getJSONWebKey(headerEPK)
}

// getAPU extracts the parsed "apu" header from the raw JSON.
func (parsed rawHeader) getAPU() (*byteBuffer, error) {
	return parsed.getByteBuffer(headerAPU)
}

// getAPV extracts the parsed "apv" header from the raw JSON.
func (parsed rawHeader) getAPV() (*byteBuffer, error) {
	return parsed.getByteBuffer(headerAPV)
}

// getIV extracts the parsed "iv" header from the raw JSON.
func (parsed rawHeader) getIV() (*byteBuffer, error) {
	return parsed.getByteBuffer(headerIV)
}

// getTag extracts the parsed "tag" header from the raw JSON.
func (parsed rawHeader) getTag() (*byteBuffer, error) {
	return parsed.getByteBuffer(headerTag)
}

// getJWK extracts the parsed "jwk" header from the raw JSON. It returns
// (nil, nil) when the header is absent.
func (parsed rawHeader) getJWK() (*JSONWebKey, error) {
	return parsed.getJSONWebKey(headerJWK)
}

// getCritical decodes the "crit" header as a list of header parameter names.
// When the header is absent it returns a nil slice and a nil error; when the
// value is not a JSON array of strings it returns (nil, err).
func (parsed rawHeader) getCritical() ([]string, error) {
	raw, found := parsed[headerCritical]
	if !found || raw == nil {
		return nil, nil
	}

	var names []string
	if err := json.Unmarshal(*raw, &names); err != nil {
		return nil, err
	}
	return names, nil
}

// getP2C decodes the "p2c" header as an int. It returns (0, nil) when the
// header is absent and (0, err) when the value is not a JSON number that fits
// an int.
//
// NOTE(review): no range check (non-positive or very large counts) happens
// here; callers that use this as an iteration count must validate it.
func (parsed rawHeader) getP2C() (int, error) {
	raw, found := parsed[headerP2C]
	if !found || raw == nil {
		return 0, nil
	}

	var count int
	if err := json.Unmarshal(*raw, &count); err != nil {
		return 0, err
	}
	return count, nil
}

// getP2S extracts the parsed "p2s" header from the raw JSON as a byte buffer,
// returning (nil, nil) when the header is absent.
func (parsed rawHeader) getP2S() (*byteBuffer, error) {
	key := HeaderKey(headerP2S)
	return parsed.getByteBuffer(key)
}

// sanitized produces a cleaned-up Header from the raw JSON.
//
// The "jwk", "kid", "alg" and "nonce" members are decoded into the matching
// Header fields, and "x5c" is decoded into the (unverified) certificate chain.
// Every other non-nil member is decoded generically into ExtraHeaders; nil
// members are skipped. On the first decoding failure it returns a descriptive
// error along with whatever fields had already been filled in. Map iteration
// order is unspecified, so which fields those are is not deterministic.
func (parsed rawHeader) sanitized() (h Header, err error) {
	// decodeString unmarshals raw as a JSON string into *dst. On failure *dst
	// is left untouched and the error names the offending field.
	decodeString := func(raw *json.RawMessage, field string, dst *string) error {
		var s string
		if uerr := json.Unmarshal(*raw, &s); uerr != nil {
			return fmt.Errorf("failed to unmarshal %s: %v: %#v", field, uerr, string(*raw))
		}
		*dst = s
		return nil
	}

	for key, raw := range parsed {
		if raw == nil {
			continue
		}

		switch key {
		case headerJWK:
			var jwk *JSONWebKey
			if err = json.Unmarshal(*raw, &jwk); err != nil {
				return h, fmt.Errorf("failed to unmarshal JWK: %v: %#v", err, string(*raw))
			}
			h.JSONWebKey = jwk

		case headerKeyID:
			if err = decodeString(raw, "key ID", &h.KeyID); err != nil {
				return h, err
			}

		case headerAlgorithm:
			if err = decodeString(raw, "algorithm", &h.Algorithm); err != nil {
				return h, err
			}

		case headerNonce:
			if err = decodeString(raw, "nonce", &h.Nonce); err != nil {
				return h, err
			}

		case headerX5c:
			encoded := []string{}
			if err = json.Unmarshal(*raw, &encoded); err != nil {
				return h, fmt.Errorf("failed to unmarshal x5c header: %v: %#v", err, string(*raw))
			}
			// On failure parseCertificateChain returns nil, which also clears
			// h.certificates before returning.
			if h.certificates, err = parseCertificateChain(encoded); err != nil {
				return h, fmt.Errorf("failed to unmarshal x5c header: %v: %#v", err, string(*raw))
			}

		default:
			if h.ExtraHeaders == nil {
				h.ExtraHeaders = map[HeaderKey]interface{}{}
			}
			var value interface{}
			if err = json.Unmarshal(*raw, &value); err != nil {
				return h, fmt.Errorf("failed to unmarshal value: %v: %#v", err, string(*raw))
			}
			h.ExtraHeaders[key] = value
		}
	}
	return h, nil
}

// parseCertificateChain decodes every entry of chain (padded standard base64
// of a DER certificate) into an *x509.Certificate, preserving order. It stops
// at the first entry that fails to decode or parse and returns (nil, err). An
// empty chain yields an empty, non-nil slice.
func parseCertificateChain(chain []string) ([]*x509.Certificate, error) {
	certs := make([]*x509.Certificate, 0, len(chain))
	for _, encoded := range chain {
		der, err := base64.StdEncoding.DecodeString(encoded)
		if err != nil {
			return nil, err
		}
		cert, err := x509.ParseCertificate(der)
		if err != nil {
			return nil, err
		}
		certs = append(certs, cert)
	}
	return certs, nil
}

// isSet reports whether dst holds a meaningful value for k. A missing key, a
// nil value, or a value that decodes to the empty JSON string counts as unset.
// Anything else counts as set, including JSON null and values that fail to
// decode.
func (dst rawHeader) isSet(k HeaderKey) bool {
	raw := dst[k]
	if raw == nil {
		return false
	}

	var decoded interface{}
	if json.Unmarshal(*raw, &decoded) != nil {
		// Undecodable input is treated as present, so merge will not
		// overwrite it.
		return true
	}

	str, isString := decoded.(string)
	return !isString || str != ""
}

// merge copies headers from src into dst, giving precedence to headers already
// present in dst. A key from src is copied only when dst.isSet reports it
// unset: missing, nil, or the empty JSON string. A nil src is a no-op.
func (dst rawHeader) merge(src *rawHeader) {
	if src == nil {
		return
	}

	for k, v := range *src {
		if !dst.isSet(k) {
			dst[k] = v
		}
	}
}

// curveName returns the JOSE name of crv. Only the standard library's P-256,
// P-384 and P-521 curve instances are recognised (compared by identity); any
// other curve yields an error.
func curveName(crv elliptic.Curve) (string, error) {
	if crv == elliptic.P256() {
		return "P-256", nil
	}
	if crv == elliptic.P384() {
		return "P-384", nil
	}
	if crv == elliptic.P521() {
		return "P-521", nil
	}
	return "", fmt.Errorf("square/go-jose: unsupported/unknown elliptic curve")
}

// curveSize returns the size of crv in bytes: its field bit size
// (CurveParams.BitSize) rounded up to a whole number of bytes, e.g. 66 for
// P-521.
func curveSize(crv elliptic.Curve) int {
	bitSize := crv.Params().BitSize
	return (bitSize + 7) / 8
}

// makeRawMessage returns a pointer to a json.RawMessage that wraps b without
// copying it; the result shares b's backing array.
func makeRawMessage(b []byte) *json.RawMessage {
	msg := new(json.RawMessage)
	*msg = b
	return msg
}