Updated vednor files

This commit is contained in:
Serguei Bezverkhi
2018-02-15 08:50:31 -05:00
parent 18a4ce4439
commit 1f1e8cea37
3299 changed files with 834 additions and 1051200 deletions

View File

@ -1,30 +0,0 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
go_library(
name = "go_default_library",
srcs = ["ring_growing.go"],
importpath = "k8s.io/client-go/util/buffer",
visibility = ["//visibility:public"],
)
go_test(
name = "go_default_test",
srcs = ["ring_growing_test.go"],
importpath = "k8s.io/client-go/util/buffer",
library = ":go_default_library",
deps = ["//vendor/github.com/stretchr/testify/assert:go_default_library"],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)

View File

@ -1,72 +0,0 @@
/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package buffer
// RingGrowing is a growing ring buffer.
// Not thread safe.
type RingGrowing struct {
data []interface{}
n int // Size of Data
beg int // First available element
readable int // Number of data items available
}
// NewRingGrowing constructs a new RingGrowing instance with provided parameters.
func NewRingGrowing(initialSize int) *RingGrowing {
return &RingGrowing{
data: make([]interface{}, initialSize),
n: initialSize,
}
}
// ReadOne reads (consumes) first item from the buffer if it is available, otherwise returns false.
func (r *RingGrowing) ReadOne() (data interface{}, ok bool) {
if r.readable == 0 {
return nil, false
}
r.readable--
element := r.data[r.beg]
r.data[r.beg] = nil // Remove reference to the object to help GC
if r.beg == r.n-1 {
// Was the last element
r.beg = 0
} else {
r.beg++
}
return element, true
}
// WriteOne adds an item to the end of the buffer, growing it if it is full.
func (r *RingGrowing) WriteOne(data interface{}) {
if r.readable == r.n {
// Time to grow
newN := r.n * 2
newData := make([]interface{}, newN)
to := r.beg + r.readable
if to <= r.n {
copy(newData, r.data[r.beg:to])
} else {
copied := copy(newData, r.data[r.beg:])
copy(newData[copied:], r.data[:(to%r.n)])
}
r.beg = 0
r.data = newData
r.n = newN
}
r.data[(r.readable+r.beg)%r.n] = data
r.readable++
}

View File

@ -1,50 +0,0 @@
/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package buffer
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestGrowth(t *testing.T) {
t.Parallel()
x := 10
g := NewRingGrowing(1)
for i := 0; i < x; i++ {
assert.Equal(t, i, g.readable)
g.WriteOne(i)
}
read := 0
for g.readable > 0 {
v, ok := g.ReadOne()
assert.True(t, ok)
assert.Equal(t, read, v)
read++
}
assert.Equalf(t, x, read, "expected to have read %d items: %d", x, read)
assert.Zerof(t, g.readable, "expected readable to be zero: %d", g.readable)
assert.Equalf(t, g.n, 16, "expected N to be 16: %d", g.n)
}
func TestEmpty(t *testing.T) {
t.Parallel()
g := NewRingGrowing(1)
_, ok := g.ReadOne()
assert.False(t, ok)
}

View File

@ -1,48 +0,0 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
"go_test",
)
go_test(
name = "go_default_test",
srcs = [
"csr_test.go",
"pem_test.go",
],
data = glob(["testdata/**"]),
importpath = "k8s.io/client-go/util/cert",
library = ":go_default_library",
)
go_library(
name = "go_default_library",
srcs = [
"cert.go",
"csr.go",
"io.go",
"pem.go",
],
data = [
"testdata/dontUseThisKey.pem",
],
importpath = "k8s.io/client-go/util/cert",
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [
":package-srcs",
"//staging/src/k8s.io/client-go/util/cert/triple:all-srcs",
],
tags = ["automanaged"],
)

View File

@ -1,215 +0,0 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cert
import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
cryptorand "crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"fmt"
"math"
"math/big"
"net"
"time"
)
const (
rsaKeySize = 2048
duration365d = time.Hour * 24 * 365
)
// Config containes the basic fields required for creating a certificate
type Config struct {
CommonName string
Organization []string
AltNames AltNames
Usages []x509.ExtKeyUsage
}
// AltNames contains the domain names and IP addresses that will be added
// to the API Server's x509 certificate SubAltNames field. The values will
// be passed directly to the x509.Certificate object.
type AltNames struct {
DNSNames []string
IPs []net.IP
}
// NewPrivateKey creates an RSA private key
func NewPrivateKey() (*rsa.PrivateKey, error) {
return rsa.GenerateKey(cryptorand.Reader, rsaKeySize)
}
// NewSelfSignedCACert creates a CA certificate
func NewSelfSignedCACert(cfg Config, key *rsa.PrivateKey) (*x509.Certificate, error) {
now := time.Now()
tmpl := x509.Certificate{
SerialNumber: new(big.Int).SetInt64(0),
Subject: pkix.Name{
CommonName: cfg.CommonName,
Organization: cfg.Organization,
},
NotBefore: now.UTC(),
NotAfter: now.Add(duration365d * 10).UTC(),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
IsCA: true,
}
certDERBytes, err := x509.CreateCertificate(cryptorand.Reader, &tmpl, &tmpl, key.Public(), key)
if err != nil {
return nil, err
}
return x509.ParseCertificate(certDERBytes)
}
// NewSignedCert creates a signed certificate using the given CA certificate and key
func NewSignedCert(cfg Config, key *rsa.PrivateKey, caCert *x509.Certificate, caKey *rsa.PrivateKey) (*x509.Certificate, error) {
serial, err := cryptorand.Int(cryptorand.Reader, new(big.Int).SetInt64(math.MaxInt64))
if err != nil {
return nil, err
}
if len(cfg.CommonName) == 0 {
return nil, errors.New("must specify a CommonName")
}
if len(cfg.Usages) == 0 {
return nil, errors.New("must specify at least one ExtKeyUsage")
}
certTmpl := x509.Certificate{
Subject: pkix.Name{
CommonName: cfg.CommonName,
Organization: cfg.Organization,
},
DNSNames: cfg.AltNames.DNSNames,
IPAddresses: cfg.AltNames.IPs,
SerialNumber: serial,
NotBefore: caCert.NotBefore,
NotAfter: time.Now().Add(duration365d).UTC(),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: cfg.Usages,
}
certDERBytes, err := x509.CreateCertificate(cryptorand.Reader, &certTmpl, caCert, key.Public(), caKey)
if err != nil {
return nil, err
}
return x509.ParseCertificate(certDERBytes)
}
// MakeEllipticPrivateKeyPEM creates an ECDSA private key
func MakeEllipticPrivateKeyPEM() ([]byte, error) {
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), cryptorand.Reader)
if err != nil {
return nil, err
}
derBytes, err := x509.MarshalECPrivateKey(privateKey)
if err != nil {
return nil, err
}
privateKeyPemBlock := &pem.Block{
Type: ECPrivateKeyBlockType,
Bytes: derBytes,
}
return pem.EncodeToMemory(privateKeyPemBlock), nil
}
// GenerateSelfSignedCertKey creates a self-signed certificate and key for the given host.
// Host may be an IP or a DNS name
// You may also specify additional subject alt names (either ip or dns names) for the certificate
func GenerateSelfSignedCertKey(host string, alternateIPs []net.IP, alternateDNS []string) ([]byte, []byte, error) {
priv, err := rsa.GenerateKey(cryptorand.Reader, 2048)
if err != nil {
return nil, nil, err
}
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
CommonName: fmt.Sprintf("%s@%d", host, time.Now().Unix()),
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour * 24 * 365),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IsCA: true,
}
if ip := net.ParseIP(host); ip != nil {
template.IPAddresses = append(template.IPAddresses, ip)
} else {
template.DNSNames = append(template.DNSNames, host)
}
template.IPAddresses = append(template.IPAddresses, alternateIPs...)
template.DNSNames = append(template.DNSNames, alternateDNS...)
derBytes, err := x509.CreateCertificate(cryptorand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil {
return nil, nil, err
}
// Generate cert
certBuffer := bytes.Buffer{}
if err := pem.Encode(&certBuffer, &pem.Block{Type: CertificateBlockType, Bytes: derBytes}); err != nil {
return nil, nil, err
}
// Generate key
keyBuffer := bytes.Buffer{}
if err := pem.Encode(&keyBuffer, &pem.Block{Type: RSAPrivateKeyBlockType, Bytes: x509.MarshalPKCS1PrivateKey(priv)}); err != nil {
return nil, nil, err
}
return certBuffer.Bytes(), keyBuffer.Bytes(), nil
}
// FormatBytesCert receives byte array certificate and formats in human-readable format
func FormatBytesCert(cert []byte) (string, error) {
block, _ := pem.Decode(cert)
c, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return "", fmt.Errorf("failed to parse certificate [%v]", err)
}
return FormatCert(c), nil
}
// FormatCert receives certificate and formats in human-readable format
func FormatCert(c *x509.Certificate) string {
var ips []string
for _, ip := range c.IPAddresses {
ips = append(ips, ip.String())
}
altNames := append(ips, c.DNSNames...)
res := fmt.Sprintf(
"Issuer: CN=%s | Subject: CN=%s | CA: %t\n",
c.Issuer.CommonName, c.Subject.CommonName, c.IsCA,
)
res += fmt.Sprintf("Not before: %s Not After: %s", c.NotBefore, c.NotAfter)
if len(altNames) > 0 {
res += fmt.Sprintf("\nAlternate Names: %v", altNames)
}
return res
}

View File

@ -1,75 +0,0 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cert
import (
cryptorand "crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"net"
)
// MakeCSR generates a PEM-encoded CSR using the supplied private key, subject, and SANs.
// All key types that are implemented via crypto.Signer are supported (This includes *rsa.PrivateKey and *ecdsa.PrivateKey.)
func MakeCSR(privateKey interface{}, subject *pkix.Name, dnsSANs []string, ipSANs []net.IP) (csr []byte, err error) {
template := &x509.CertificateRequest{
Subject: *subject,
DNSNames: dnsSANs,
IPAddresses: ipSANs,
}
return MakeCSRFromTemplate(privateKey, template)
}
// MakeCSRFromTemplate generates a PEM-encoded CSR using the supplied private
// key and certificate request as a template. All key types that are
// implemented via crypto.Signer are supported (This includes *rsa.PrivateKey
// and *ecdsa.PrivateKey.)
func MakeCSRFromTemplate(privateKey interface{}, template *x509.CertificateRequest) ([]byte, error) {
t := *template
t.SignatureAlgorithm = sigType(privateKey)
csrDER, err := x509.CreateCertificateRequest(cryptorand.Reader, &t, privateKey)
if err != nil {
return nil, err
}
csrPemBlock := &pem.Block{
Type: CertificateRequestBlockType,
Bytes: csrDER,
}
return pem.EncodeToMemory(csrPemBlock), nil
}
func sigType(privateKey interface{}) x509.SignatureAlgorithm {
// Customize the signature for RSA keys, depending on the key size
if privateKey, ok := privateKey.(*rsa.PrivateKey); ok {
keySize := privateKey.N.BitLen()
switch {
case keySize >= 4096:
return x509.SHA512WithRSA
case keySize >= 3072:
return x509.SHA384WithRSA
default:
return x509.SHA256WithRSA
}
}
return x509.UnknownSignatureAlgorithm
}

View File

@ -1,75 +0,0 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cert
import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"io/ioutil"
"net"
"testing"
)
func TestMakeCSR(t *testing.T) {
keyFile := "testdata/dontUseThisKey.pem"
subject := &pkix.Name{
CommonName: "kube-worker",
}
dnsSANs := []string{"localhost"}
ipSANs := []net.IP{net.ParseIP("127.0.0.1")}
keyData, err := ioutil.ReadFile(keyFile)
if err != nil {
t.Fatal(err)
}
key, err := ParsePrivateKeyPEM(keyData)
if err != nil {
t.Fatal(err)
}
csrPEM, err := MakeCSR(key, subject, dnsSANs, ipSANs)
if err != nil {
t.Error(err)
}
csrBlock, rest := pem.Decode(csrPEM)
if csrBlock == nil {
t.Error("Unable to decode MakeCSR result.")
}
if len(rest) != 0 {
t.Error("Found more than one PEM encoded block in the result.")
}
if csrBlock.Type != CertificateRequestBlockType {
t.Errorf("Found block type %q, wanted 'CERTIFICATE REQUEST'", csrBlock.Type)
}
csr, err := x509.ParseCertificateRequest(csrBlock.Bytes)
if err != nil {
t.Errorf("Found %v parsing MakeCSR result as a CertificateRequest.", err)
}
if csr.Subject.CommonName != subject.CommonName {
t.Errorf("Wanted %v, got %v", subject, csr.Subject)
}
if len(csr.DNSNames) != 1 {
t.Errorf("Wanted 1 DNS name in the result, got %d", len(csr.DNSNames))
} else if csr.DNSNames[0] != dnsSANs[0] {
t.Errorf("Wanted %v, got %v", dnsSANs[0], csr.DNSNames[0])
}
if len(csr.IPAddresses) != 1 {
t.Errorf("Wanted 1 IP address in the result, got %d", len(csr.IPAddresses))
} else if csr.IPAddresses[0].String() != ipSANs[0].String() {
t.Errorf("Wanted %v, got %v", ipSANs[0], csr.IPAddresses[0])
}
}

View File

@ -1,158 +0,0 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cert
import (
"crypto/x509"
"fmt"
"io/ioutil"
"os"
"path/filepath"
)
// CanReadCertAndKey returns true if the certificate and key files already exists,
// otherwise returns false. If lost one of cert and key, returns error.
func CanReadCertAndKey(certPath, keyPath string) (bool, error) {
certReadable := canReadFile(certPath)
keyReadable := canReadFile(keyPath)
if certReadable == false && keyReadable == false {
return false, nil
}
if certReadable == false {
return false, fmt.Errorf("error reading %s, certificate and key must be supplied as a pair", certPath)
}
if keyReadable == false {
return false, fmt.Errorf("error reading %s, certificate and key must be supplied as a pair", keyPath)
}
return true, nil
}
// If the file represented by path exists and
// readable, returns true otherwise returns false.
func canReadFile(path string) bool {
f, err := os.Open(path)
if err != nil {
return false
}
defer f.Close()
return true
}
// WriteCert writes the pem-encoded certificate data to certPath.
// The certificate file will be created with file mode 0644.
// If the certificate file already exists, it will be overwritten.
// The parent directory of the certPath will be created as needed with file mode 0755.
func WriteCert(certPath string, data []byte) error {
if err := os.MkdirAll(filepath.Dir(certPath), os.FileMode(0755)); err != nil {
return err
}
return ioutil.WriteFile(certPath, data, os.FileMode(0644))
}
// WriteKey writes the pem-encoded key data to keyPath.
// The key file will be created with file mode 0600.
// If the key file already exists, it will be overwritten.
// The parent directory of the keyPath will be created as needed with file mode 0755.
func WriteKey(keyPath string, data []byte) error {
if err := os.MkdirAll(filepath.Dir(keyPath), os.FileMode(0755)); err != nil {
return err
}
return ioutil.WriteFile(keyPath, data, os.FileMode(0600))
}
// LoadOrGenerateKeyFile looks for a key in the file at the given path. If it
// can't find one, it will generate a new key and store it there.
func LoadOrGenerateKeyFile(keyPath string) (data []byte, wasGenerated bool, err error) {
loadedData, err := ioutil.ReadFile(keyPath)
if err == nil {
return loadedData, false, err
}
if !os.IsNotExist(err) {
return nil, false, fmt.Errorf("error loading key from %s: %v", keyPath, err)
}
generatedData, err := MakeEllipticPrivateKeyPEM()
if err != nil {
return nil, false, fmt.Errorf("error generating key: %v", err)
}
if err := WriteKey(keyPath, generatedData); err != nil {
return nil, false, fmt.Errorf("error writing key to %s: %v", keyPath, err)
}
return generatedData, true, nil
}
// NewPool returns an x509.CertPool containing the certificates in the given PEM-encoded file.
// Returns an error if the file could not be read, a certificate could not be parsed, or if the file does not contain any certificates
func NewPool(filename string) (*x509.CertPool, error) {
certs, err := CertsFromFile(filename)
if err != nil {
return nil, err
}
pool := x509.NewCertPool()
for _, cert := range certs {
pool.AddCert(cert)
}
return pool, nil
}
// CertsFromFile returns the x509.Certificates contained in the given PEM-encoded file.
// Returns an error if the file could not be read, a certificate could not be parsed, or if the file does not contain any certificates
func CertsFromFile(file string) ([]*x509.Certificate, error) {
pemBlock, err := ioutil.ReadFile(file)
if err != nil {
return nil, err
}
certs, err := ParseCertsPEM(pemBlock)
if err != nil {
return nil, fmt.Errorf("error reading %s: %s", file, err)
}
return certs, nil
}
// PrivateKeyFromFile returns the private key in rsa.PrivateKey or ecdsa.PrivateKey format from a given PEM-encoded file.
// Returns an error if the file could not be read or if the private key could not be parsed.
func PrivateKeyFromFile(file string) (interface{}, error) {
data, err := ioutil.ReadFile(file)
if err != nil {
return nil, err
}
key, err := ParsePrivateKeyPEM(data)
if err != nil {
return nil, fmt.Errorf("error reading private key file %s: %v", file, err)
}
return key, nil
}
// PublicKeysFromFile returns the public keys in rsa.PublicKey or ecdsa.PublicKey format from a given PEM-encoded file.
// Reads public keys from both public and private key files.
func PublicKeysFromFile(file string) ([]interface{}, error) {
data, err := ioutil.ReadFile(file)
if err != nil {
return nil, err
}
keys, err := ParsePublicKeysPEM(data)
if err != nil {
return nil, fmt.Errorf("error reading public key file %s: %v", file, err)
}
return keys, nil
}

View File

@ -1,269 +0,0 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cert
import (
"crypto/ecdsa"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
)
const (
// ECPrivateKeyBlockType is a possible value for pem.Block.Type.
ECPrivateKeyBlockType = "EC PRIVATE KEY"
// RSAPrivateKeyBlockType is a possible value for pem.Block.Type.
RSAPrivateKeyBlockType = "RSA PRIVATE KEY"
// PrivateKeyBlockType is a possible value for pem.Block.Type.
PrivateKeyBlockType = "PRIVATE KEY"
// PublicKeyBlockType is a possible value for pem.Block.Type.
PublicKeyBlockType = "PUBLIC KEY"
// CertificateBlockType is a possible value for pem.Block.Type.
CertificateBlockType = "CERTIFICATE"
// CertificateRequestBlockType is a possible value for pem.Block.Type.
CertificateRequestBlockType = "CERTIFICATE REQUEST"
)
// EncodePublicKeyPEM returns PEM-encoded public data
func EncodePublicKeyPEM(key *rsa.PublicKey) ([]byte, error) {
der, err := x509.MarshalPKIXPublicKey(key)
if err != nil {
return []byte{}, err
}
block := pem.Block{
Type: PublicKeyBlockType,
Bytes: der,
}
return pem.EncodeToMemory(&block), nil
}
// EncodePrivateKeyPEM returns PEM-encoded private key data
func EncodePrivateKeyPEM(key *rsa.PrivateKey) []byte {
block := pem.Block{
Type: RSAPrivateKeyBlockType,
Bytes: x509.MarshalPKCS1PrivateKey(key),
}
return pem.EncodeToMemory(&block)
}
// EncodeCertPEM returns PEM-endcoded certificate data
func EncodeCertPEM(cert *x509.Certificate) []byte {
block := pem.Block{
Type: CertificateBlockType,
Bytes: cert.Raw,
}
return pem.EncodeToMemory(&block)
}
// ParsePrivateKeyPEM returns a private key parsed from a PEM block in the supplied data.
// Recognizes PEM blocks for "EC PRIVATE KEY", "RSA PRIVATE KEY", or "PRIVATE KEY"
func ParsePrivateKeyPEM(keyData []byte) (interface{}, error) {
var privateKeyPemBlock *pem.Block
for {
privateKeyPemBlock, keyData = pem.Decode(keyData)
if privateKeyPemBlock == nil {
break
}
switch privateKeyPemBlock.Type {
case ECPrivateKeyBlockType:
// ECDSA Private Key in ASN.1 format
if key, err := x509.ParseECPrivateKey(privateKeyPemBlock.Bytes); err == nil {
return key, nil
}
case RSAPrivateKeyBlockType:
// RSA Private Key in PKCS#1 format
if key, err := x509.ParsePKCS1PrivateKey(privateKeyPemBlock.Bytes); err == nil {
return key, nil
}
case PrivateKeyBlockType:
// RSA or ECDSA Private Key in unencrypted PKCS#8 format
if key, err := x509.ParsePKCS8PrivateKey(privateKeyPemBlock.Bytes); err == nil {
return key, nil
}
}
// tolerate non-key PEM blocks for compatibility with things like "EC PARAMETERS" blocks
// originally, only the first PEM block was parsed and expected to be a key block
}
// we read all the PEM blocks and didn't recognize one
return nil, fmt.Errorf("data does not contain a valid RSA or ECDSA private key")
}
// ParsePublicKeysPEM is a helper function for reading an array of rsa.PublicKey or ecdsa.PublicKey from a PEM-encoded byte array.
// Reads public keys from both public and private key files.
func ParsePublicKeysPEM(keyData []byte) ([]interface{}, error) {
var block *pem.Block
keys := []interface{}{}
for {
// read the next block
block, keyData = pem.Decode(keyData)
if block == nil {
break
}
// test block against parsing functions
if privateKey, err := parseRSAPrivateKey(block.Bytes); err == nil {
keys = append(keys, &privateKey.PublicKey)
continue
}
if publicKey, err := parseRSAPublicKey(block.Bytes); err == nil {
keys = append(keys, publicKey)
continue
}
if privateKey, err := parseECPrivateKey(block.Bytes); err == nil {
keys = append(keys, &privateKey.PublicKey)
continue
}
if publicKey, err := parseECPublicKey(block.Bytes); err == nil {
keys = append(keys, publicKey)
continue
}
// tolerate non-key PEM blocks for backwards compatibility
// originally, only the first PEM block was parsed and expected to be a key block
}
if len(keys) == 0 {
return nil, fmt.Errorf("data does not contain any valid RSA or ECDSA public keys")
}
return keys, nil
}
// ParseCertsPEM returns the x509.Certificates contained in the given PEM-encoded byte array
// Returns an error if a certificate could not be parsed, or if the data does not contain any certificates
func ParseCertsPEM(pemCerts []byte) ([]*x509.Certificate, error) {
ok := false
certs := []*x509.Certificate{}
for len(pemCerts) > 0 {
var block *pem.Block
block, pemCerts = pem.Decode(pemCerts)
if block == nil {
break
}
// Only use PEM "CERTIFICATE" blocks without extra headers
if block.Type != CertificateBlockType || len(block.Headers) != 0 {
continue
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return certs, err
}
certs = append(certs, cert)
ok = true
}
if !ok {
return certs, errors.New("data does not contain any valid RSA or ECDSA certificates")
}
return certs, nil
}
// parseRSAPublicKey parses a single RSA public key from the provided data
func parseRSAPublicKey(data []byte) (*rsa.PublicKey, error) {
var err error
// Parse the key
var parsedKey interface{}
if parsedKey, err = x509.ParsePKIXPublicKey(data); err != nil {
if cert, err := x509.ParseCertificate(data); err == nil {
parsedKey = cert.PublicKey
} else {
return nil, err
}
}
// Test if parsed key is an RSA Public Key
var pubKey *rsa.PublicKey
var ok bool
if pubKey, ok = parsedKey.(*rsa.PublicKey); !ok {
return nil, fmt.Errorf("data doesn't contain valid RSA Public Key")
}
return pubKey, nil
}
// parseRSAPrivateKey parses a single RSA private key from the provided data
func parseRSAPrivateKey(data []byte) (*rsa.PrivateKey, error) {
var err error
// Parse the key
var parsedKey interface{}
if parsedKey, err = x509.ParsePKCS1PrivateKey(data); err != nil {
if parsedKey, err = x509.ParsePKCS8PrivateKey(data); err != nil {
return nil, err
}
}
// Test if parsed key is an RSA Private Key
var privKey *rsa.PrivateKey
var ok bool
if privKey, ok = parsedKey.(*rsa.PrivateKey); !ok {
return nil, fmt.Errorf("data doesn't contain valid RSA Private Key")
}
return privKey, nil
}
// parseECPublicKey parses a single ECDSA public key from the provided data
func parseECPublicKey(data []byte) (*ecdsa.PublicKey, error) {
var err error
// Parse the key
var parsedKey interface{}
if parsedKey, err = x509.ParsePKIXPublicKey(data); err != nil {
if cert, err := x509.ParseCertificate(data); err == nil {
parsedKey = cert.PublicKey
} else {
return nil, err
}
}
// Test if parsed key is an ECDSA Public Key
var pubKey *ecdsa.PublicKey
var ok bool
if pubKey, ok = parsedKey.(*ecdsa.PublicKey); !ok {
return nil, fmt.Errorf("data doesn't contain valid ECDSA Public Key")
}
return pubKey, nil
}
// parseECPrivateKey parses a single ECDSA private key from the provided data
func parseECPrivateKey(data []byte) (*ecdsa.PrivateKey, error) {
var err error
// Parse the key
var parsedKey interface{}
if parsedKey, err = x509.ParseECPrivateKey(data); err != nil {
return nil, err
}
// Test if parsed key is an ECDSA Private Key
var privKey *ecdsa.PrivateKey
var ok bool
if privKey, ok = parsedKey.(*ecdsa.PrivateKey); !ok {
return nil, fmt.Errorf("data doesn't contain valid ECDSA Private Key")
}
return privKey, nil
}

View File

@ -1,197 +0,0 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cert
import (
"io/ioutil"
"os"
"testing"
)
const (
// rsaPrivateKey is a RSA Private Key in PKCS#1 format
// openssl genrsa -out rsa2048.pem 2048
rsaPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEA92mVjhBKOFsdxFzb/Pjq+7b5TJlODAdY5hK+WxLZTIrfhDPq
FWrGKdjSNiHbXrdEtwJh9V+RqPZVSN3aWy1224RgkyNdMJsXhJKuCC24ZKY8SXtW
xuTYmMRaMnCsv6QBGRTIbZ2EFbAObVM7lDyv1VqY3amZIWFQMlZ9CNpxDSPa5yi4
3gopbXkne0oGNmey9X0qtpk7NMZIgAL6Zz4rZ30bcfC2ag6RLOFI2E/c4n8c38R8
9MfXfLkj8/Cxo4JfI9NvRCpPOpFO8d/ZtWVUuIrBQN+Y7tkN2T60Qq/TkKXUrhDe
fwlTlktZVJ/GztLYU41b2GcWsh/XO+PH831rmwIDAQABAoIBAQCC9c6GDjVbM0/E
WurPMusfJjE7zII1d8YkspM0HfwLug6qKdikUYpnKC/NG4rEzfl/bbFwco/lgc6O
7W/hh2U8uQttlvCDA/Uk5YddKOZL0Hpk4vaB/SxxYK3luSKXpjY2knutGg2KdVCN
qdsFkkH4iyYTXuyBcMNEgedZQldI/kEujIH/L7FE+DF5TMzT4lHhozDoG+fy564q
qVGUZXJn0ubc3GaPn2QOLNNM44sfYA4UJCpKBXPu85bvNObjxVQO4WqwwxU1vRnL
UUsaGaelhSVJCo0dVPRvrfPPKZ09HTwpy40EkgQo6VriFc1EBoQDjENLbAJv9OfQ
aCc9wiZhAoGBAP/8oEy48Zbb0P8Vdy4djf5tfBW8yXFLWzXewJ4l3itKS1r42nbX
9q3cJsgRTQm8uRcMIpWxsc3n6zG+lREvTkoTB3ViI7+uQPiqA+BtWyNy7jzufFke
ONKZfg7QxxmYRWZBRnoNGNbMpNeERuLmhvQuom9D1WbhzAYJbfs/O4WTAoGBAPds
2FNDU0gaesFDdkIUGq1nIJqRQDW485LXZm4pFqBFxdOpbdWRuYT2XZjd3fD0XY98
Nhkpb7NTMCuK3BdKcqIptt+cK+quQgYid0hhhgZbpCQ5AL6c6KgyjgpYlh2enzU9
Zo3yg8ej1zbbA11sBlhX+5iO2P1u5DG+JHLwUUbZAoGAUwaU102EzfEtsA4+QW7E
hyjrfgFlNKHES4yb3K9bh57pIfBkqvcQwwMMcQdrfSUAw0DkVrjzel0mI1Q09QXq
1ould6UFAz55RC2gZEITtUOpkYmoOx9aPrQZ9qQwb1S77ZZuTVfCHqjxLhVxCFbM
npYhiQTvShciHTMhwMOZgpECgYAVV5EtVXBYltgh1YTc3EkUzgF087R7LdHsx6Gx
POATwRD4WfP8aQ58lpeqOPEM+LcdSlSMRRO6fyF3kAm+BJDwxfJdRWZQXumZB94M
I0VhRQRaj4Qt7PDwmTPBVrTUJzuKZxpyggm17b8Bn1Ch/VBqzGQKW8AB1E/grosM
UwhfuQKBgQC2JO/iqTQScHClf0qlItCJsBuVukFmSAVCkpOD8YdbdlPdOOwSk1wQ
C0eAlsC3BCMvkpidKQmra6IqIrvTGI6EFgkrb3aknWdup2w8j2udYCNqyE3W+fVe
p8FdYQ1FkACQ+daO5VlClL/9l0sGjKXlNKbpmJ2H4ngZmXj5uGmxuQ==
-----END RSA PRIVATE KEY-----`
// rsaPublicKey is a RSA Public Key in PEM encoded format
// openssl rsa -in rsa2048.pem -pubout -out rsa2048pub.pem
rsaPublicKey = `-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA92mVjhBKOFsdxFzb/Pjq
+7b5TJlODAdY5hK+WxLZTIrfhDPqFWrGKdjSNiHbXrdEtwJh9V+RqPZVSN3aWy12
24RgkyNdMJsXhJKuCC24ZKY8SXtWxuTYmMRaMnCsv6QBGRTIbZ2EFbAObVM7lDyv
1VqY3amZIWFQMlZ9CNpxDSPa5yi43gopbXkne0oGNmey9X0qtpk7NMZIgAL6Zz4r
Z30bcfC2ag6RLOFI2E/c4n8c38R89MfXfLkj8/Cxo4JfI9NvRCpPOpFO8d/ZtWVU
uIrBQN+Y7tkN2T60Qq/TkKXUrhDefwlTlktZVJ/GztLYU41b2GcWsh/XO+PH831r
mwIDAQAB
-----END PUBLIC KEY-----`
// certificate is an x509 certificate in PEM encoded format
// openssl req -new -key rsa2048.pem -sha256 -nodes -x509 -days 1826 -out x509certificate.pem -subj "/C=US/CN=not-valid"
certificate = `-----BEGIN CERTIFICATE-----
MIIDFTCCAf2gAwIBAgIJAN8B8NOwtiUCMA0GCSqGSIb3DQEBCwUAMCExCzAJBgNV
BAYTAlVTMRIwEAYDVQQDDAlub3QtdmFsaWQwHhcNMTcwMzIyMDI1NjM2WhcNMjIw
MzIyMDI1NjM2WjAhMQswCQYDVQQGEwJVUzESMBAGA1UEAwwJbm90LXZhbGlkMIIB
IjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA92mVjhBKOFsdxFzb/Pjq+7b5
TJlODAdY5hK+WxLZTIrfhDPqFWrGKdjSNiHbXrdEtwJh9V+RqPZVSN3aWy1224Rg
kyNdMJsXhJKuCC24ZKY8SXtWxuTYmMRaMnCsv6QBGRTIbZ2EFbAObVM7lDyv1VqY
3amZIWFQMlZ9CNpxDSPa5yi43gopbXkne0oGNmey9X0qtpk7NMZIgAL6Zz4rZ30b
cfC2ag6RLOFI2E/c4n8c38R89MfXfLkj8/Cxo4JfI9NvRCpPOpFO8d/ZtWVUuIrB
QN+Y7tkN2T60Qq/TkKXUrhDefwlTlktZVJ/GztLYU41b2GcWsh/XO+PH831rmwID
AQABo1AwTjAdBgNVHQ4EFgQU1I5GfinLF7ta+dBJ6UWcrYaexLswHwYDVR0jBBgw
FoAU1I5GfinLF7ta+dBJ6UWcrYaexLswDAYDVR0TBAUwAwEB/zANBgkqhkiG9w0B
AQsFAAOCAQEAUl0wUD4y41juHFOVMYiziPYr1ShSpQXdwp8FfaHrzI5hsr8UMe8D
dzb9QzZ4bx3yZhiG3ahrSBh956thMTHrKTEwAfJIEXI4cuSVWQAaOJ4Em5SDFxQe
d0E6Ui2nGh1SFGF7oyuEXyzqgRMWFNDFw9HLUNgXaO18Zfouw8+K0BgbfEWEcSi1
JLQbyhCjz088gltrliQGPWDFAg9cHBKtJhuTzZkvuqK1CLEmBhtzP1zFiGBfOJc8
v+aKjAwrPUNX11cXOCPxBv2qXMetxaovBem6AI2hvypCInXaVQfP+yOLubzlTDjS
Y708SlY38hmS1uTwDpyLOn8AKkZ8jtx75g==
-----END CERTIFICATE-----`
// ecdsaPrivateKeyWithParams is a ECDSA Private Key with included EC Parameters block
// openssl ecparam -name prime256v1 -genkey -out ecdsa256params.pem
ecdsaPrivateKeyWithParams = `-----BEGIN EC PARAMETERS-----
BggqhkjOPQMBBw==
-----END EC PARAMETERS-----
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIAwSOWQqlMTZNqNF7tgua812Jxib1DVOgb2pHHyIEyNNoAoGCCqGSM49
AwEHoUQDQgAEyxYNrs6a6tsNCFNYn+l+JDUZ0PnUZbcsDgJn2O62D1se8M5iQ5rY
iIv6RpxE3VHvlHEIvYgCZkG0jHszTUopBg==
-----END EC PRIVATE KEY-----`
// ecdsaPrivateKey is a ECDSA Private Key in ASN.1 format
// openssl ecparam -name prime256v1 -genkey -noout -out ecdsa256.pem
ecdsaPrivateKey = `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIP6Qw6dHDiLsSnLXUhQVTPE0fTQQrj3XSbiQAZPXnk5+oAoGCCqGSM49
AwEHoUQDQgAEZZzi1u5f2/AEGFI/HYUhU+u6cTK1q2bbtE7r1JMK+/sQA5sNAp+7
Vdc3psr1OaNzyTyuhTECyRdFKXm63cMnGg==
-----END EC PRIVATE KEY-----`
// ecdsaPublicKey is a ECDSA Public Key in PEM encoded format
// openssl ec -in ecdsa256.pem -pubout -out ecdsa256pub.pem
ecdsaPublicKey = `-----BEGIN PUBLIC KEY-----
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEZZzi1u5f2/AEGFI/HYUhU+u6cTK1
q2bbtE7r1JMK+/sQA5sNAp+7Vdc3psr1OaNzyTyuhTECyRdFKXm63cMnGg==
-----END PUBLIC KEY-----`
)
func TestReadPrivateKey(t *testing.T) {
f, err := ioutil.TempFile("", "")
if err != nil {
t.Fatalf("error creating tmpfile: %v", err)
}
defer os.Remove(f.Name())
if _, err := PrivateKeyFromFile(f.Name()); err == nil {
t.Fatalf("Expected error reading key from empty file, got none")
}
if err := ioutil.WriteFile(f.Name(), []byte(rsaPrivateKey), os.FileMode(0600)); err != nil {
t.Fatalf("error writing private key to tmpfile: %v", err)
}
if _, err := PrivateKeyFromFile(f.Name()); err != nil {
t.Fatalf("error reading private RSA key: %v", err)
}
if err := ioutil.WriteFile(f.Name(), []byte(ecdsaPrivateKey), os.FileMode(0600)); err != nil {
t.Fatalf("error writing private key to tmpfile: %v", err)
}
if _, err := PrivateKeyFromFile(f.Name()); err != nil {
t.Fatalf("error reading private ECDSA key: %v", err)
}
if err := ioutil.WriteFile(f.Name(), []byte(ecdsaPrivateKeyWithParams), os.FileMode(0600)); err != nil {
t.Fatalf("error writing private key to tmpfile: %v", err)
}
if _, err := PrivateKeyFromFile(f.Name()); err != nil {
t.Fatalf("error reading private ECDSA key with params: %v", err)
}
}
func TestReadPublicKeys(t *testing.T) {
f, err := ioutil.TempFile("", "")
if err != nil {
t.Fatalf("error creating tmpfile: %v", err)
}
defer os.Remove(f.Name())
if _, err := PublicKeysFromFile(f.Name()); err == nil {
t.Fatalf("Expected error reading keys from empty file, got none")
}
if err := ioutil.WriteFile(f.Name(), []byte(rsaPublicKey), os.FileMode(0600)); err != nil {
t.Fatalf("error writing public key to tmpfile: %v", err)
}
if keys, err := PublicKeysFromFile(f.Name()); err != nil {
t.Fatalf("error reading RSA public key: %v", err)
} else if len(keys) != 1 {
t.Fatalf("expected 1 key, got %d", len(keys))
}
if err := ioutil.WriteFile(f.Name(), []byte(ecdsaPublicKey), os.FileMode(0600)); err != nil {
t.Fatalf("error writing public key to tmpfile: %v", err)
}
if keys, err := PublicKeysFromFile(f.Name()); err != nil {
t.Fatalf("error reading ECDSA public key: %v", err)
} else if len(keys) != 1 {
t.Fatalf("expected 1 key, got %d", len(keys))
}
if err := ioutil.WriteFile(f.Name(), []byte(rsaPublicKey+"\n"+ecdsaPublicKey), os.FileMode(0600)); err != nil {
t.Fatalf("error writing public key to tmpfile: %v", err)
}
if keys, err := PublicKeysFromFile(f.Name()); err != nil {
t.Fatalf("error reading combined RSA/ECDSA public key file: %v", err)
} else if len(keys) != 2 {
t.Fatalf("expected 2 keys, got %d", len(keys))
}
if err := ioutil.WriteFile(f.Name(), []byte(certificate), os.FileMode(0600)); err != nil {
t.Fatalf("error writing certificate to tmpfile: %v", err)
}
if keys, err := PublicKeysFromFile(f.Name()); err != nil {
t.Fatalf("error reading public key from certificate file: %v", err)
} else if len(keys) != 1 {
t.Fatalf("expected 1 keys, got %d", len(keys))
}
}

View File

@ -1,6 +0,0 @@
-----BEGIN EC PRIVATE KEY-----
MIGkAgEBBDAPEbSXwyDfWf0+61Oofd7aHkmdX69mrzD2Xb1CHF5syfsoRIhnG0dJ
ozBulPZCDDWgBwYFK4EEACKhZANiAATjlMJAtKhEPqU/i7MsrgKcK/RmXHC6He7W
0p69+9qFXg2raJ9zvvbKxkiu2ELOYRDAz0utcFTBOIgoUJEzBVmsjZQ7dvFa1BKP
Ym7MFAKG3O2espBqXn+audgdHGh5B0I=
-----END EC PRIVATE KEY-----

View File

@ -1,26 +0,0 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
)
go_library(
name = "go_default_library",
srcs = ["triple.go"],
importpath = "k8s.io/client-go/util/cert/triple",
deps = ["//vendor/k8s.io/client-go/util/cert:go_default_library"],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
)

View File

@ -1,116 +0,0 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package triple generates key-certificate pairs for the
// triple (CA, Server, Client).
package triple
import (
"crypto/rsa"
"crypto/x509"
"fmt"
"net"
certutil "k8s.io/client-go/util/cert"
)
type KeyPair struct {
Key *rsa.PrivateKey
Cert *x509.Certificate
}
func NewCA(name string) (*KeyPair, error) {
key, err := certutil.NewPrivateKey()
if err != nil {
return nil, fmt.Errorf("unable to create a private key for a new CA: %v", err)
}
config := certutil.Config{
CommonName: name,
}
cert, err := certutil.NewSelfSignedCACert(config, key)
if err != nil {
return nil, fmt.Errorf("unable to create a self-signed certificate for a new CA: %v", err)
}
return &KeyPair{
Key: key,
Cert: cert,
}, nil
}
func NewServerKeyPair(ca *KeyPair, commonName, svcName, svcNamespace, dnsDomain string, ips, hostnames []string) (*KeyPair, error) {
key, err := certutil.NewPrivateKey()
if err != nil {
return nil, fmt.Errorf("unable to create a server private key: %v", err)
}
namespacedName := fmt.Sprintf("%s.%s", svcName, svcNamespace)
internalAPIServerFQDN := []string{
svcName,
namespacedName,
fmt.Sprintf("%s.svc", namespacedName),
fmt.Sprintf("%s.svc.%s", namespacedName, dnsDomain),
}
altNames := certutil.AltNames{}
for _, ipStr := range ips {
ip := net.ParseIP(ipStr)
if ip != nil {
altNames.IPs = append(altNames.IPs, ip)
}
}
altNames.DNSNames = append(altNames.DNSNames, hostnames...)
altNames.DNSNames = append(altNames.DNSNames, internalAPIServerFQDN...)
config := certutil.Config{
CommonName: commonName,
AltNames: altNames,
Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
cert, err := certutil.NewSignedCert(config, key, ca.Cert, ca.Key)
if err != nil {
return nil, fmt.Errorf("unable to sign the server certificate: %v", err)
}
return &KeyPair{
Key: key,
Cert: cert,
}, nil
}
func NewClientKeyPair(ca *KeyPair, commonName string, organizations []string) (*KeyPair, error) {
key, err := certutil.NewPrivateKey()
if err != nil {
return nil, fmt.Errorf("unable to create a client private key: %v", err)
}
config := certutil.Config{
CommonName: commonName,
Organization: organizations,
Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
}
cert, err := certutil.NewSignedCert(config, key, ca.Cert, ca.Key)
if err != nil {
return nil, fmt.Errorf("unable to sign the client certificate: %v", err)
}
return &KeyPair{
Key: key,
Cert: cert,
}, nil
}

View File

@ -1,67 +0,0 @@
package(default_visibility = ["//visibility:public"])
licenses(["notice"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
"go_test",
)
go_test(
name = "go_default_test",
srcs = [
"certificate_manager_test.go",
"certificate_store_test.go",
],
importpath = "k8s.io/client-go/util/certificate",
library = ":go_default_library",
tags = ["automanaged"],
deps = [
"//vendor/k8s.io/api/certificates/v1beta1:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/api/errors:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/apis/meta/v1:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/runtime/schema:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/util/wait:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/watch:go_default_library",
"//vendor/k8s.io/client-go/kubernetes/typed/certificates/v1beta1:go_default_library",
"//vendor/k8s.io/client-go/util/cert:go_default_library",
],
)
go_library(
name = "go_default_library",
srcs = [
"certificate_manager.go",
"certificate_store.go",
],
importpath = "k8s.io/client-go/util/certificate",
tags = ["automanaged"],
deps = [
"//vendor/github.com/golang/glog:go_default_library",
"//vendor/k8s.io/api/certificates/v1beta1:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/api/errors:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/util/runtime:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/util/wait:go_default_library",
"//vendor/k8s.io/client-go/kubernetes/typed/certificates/v1beta1:go_default_library",
"//vendor/k8s.io/client-go/util/cert:go_default_library",
"//vendor/k8s.io/client-go/util/certificate/csr:go_default_library",
],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [
":package-srcs",
"//staging/src/k8s.io/client-go/util/certificate/csr:all-srcs",
],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)

View File

@ -1,8 +0,0 @@
reviewers:
- mikedanese
- liggit
- smarterclayton
approvers:
- mikedanese
- liggit
- smarterclayton

View File

@ -1,440 +0,0 @@
/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package certificate
import (
"crypto/ecdsa"
"crypto/elliptic"
cryptorand "crypto/rand"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"sync"
"time"
"github.com/golang/glog"
certificates "k8s.io/api/certificates/v1beta1"
"k8s.io/apimachinery/pkg/api/errors"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/apimachinery/pkg/util/wait"
certificatesclient "k8s.io/client-go/kubernetes/typed/certificates/v1beta1"
"k8s.io/client-go/util/cert"
"k8s.io/client-go/util/certificate/csr"
)
// certificateWaitBackoff controls the amount and timing of retries when the
// watch for certificate approval is interrupted.
var certificateWaitBackoff = wait.Backoff{Duration: 30 * time.Second, Steps: 4, Factor: 1.5, Jitter: 0.1}
// Manager maintains and updates the certificates in use by this certificate
// manager. In the background it communicates with the API server to get new
// certificates for certificates about to expire.
type Manager interface {
// CertificateSigningRequestClient sets the client interface that is used for
// signing new certificates generated as part of rotation.
SetCertificateSigningRequestClient(certificatesclient.CertificateSigningRequestInterface) error
// Start the API server status sync loop.
Start()
// Current returns the currently selected certificate from the
// certificate manager, as well as the associated certificate and key data
// in PEM format.
Current() *tls.Certificate
// ServerHealthy returns true if the manager is able to communicate with
// the server. This allows a caller to determine whether the cert manager
// thinks it can potentially talk to the API server. The cert manager may
// be very conservative and only return true if recent communication has
// occurred with the server.
ServerHealthy() bool
}
// Config is the set of configuration parameters available for a new Manager.
type Config struct {
// CertificateSigningRequestClient will be used for signing new certificate
// requests generated when a key rotation occurs. It must be set either at
// initialization or by using CertificateSigningRequestClient before
// Manager.Start() is called.
CertificateSigningRequestClient certificatesclient.CertificateSigningRequestInterface
// Template is the CertificateRequest that will be used as a template for
// generating certificate signing requests for all new keys generated as
// part of rotation. It follows the same rules as the template parameter of
// crypto.x509.CreateCertificateRequest in the Go standard libraries.
Template *x509.CertificateRequest
// Usages is the types of usages that certificates generated by the manager
// can be used for.
Usages []certificates.KeyUsage
// CertificateStore is a persistent store where the current cert/key is
// kept and future cert/key pairs will be persisted after they are
// generated.
CertificateStore Store
// BootstrapCertificatePEM is the certificate data that will be returned
// from the Manager if the CertificateStore doesn't have any cert/key pairs
// currently available and has not yet had a chance to get a new cert/key
// pair from the API. If the CertificateStore does have a cert/key pair,
// this will be ignored. If there is no cert/key pair available in the
// CertificateStore, as soon as Start is called, it will request a new
// cert/key pair from the CertificateSigningRequestClient. This is intended
// to allow the first boot of a component to be initialized using a
// generic, multi-use cert/key pair which will be quickly replaced with a
// unique cert/key pair.
BootstrapCertificatePEM []byte
// BootstrapKeyPEM is the key data that will be returned from the Manager
// if the CertificateStore doesn't have any cert/key pairs currently
// available. If the CertificateStore does have a cert/key pair, this will
// be ignored. If the bootstrap cert/key pair are used, they will be
// rotated at the first opportunity, possibly well in advance of expiring.
// This is intended to allow the first boot of a component to be
// initialized using a generic, multi-use cert/key pair which will be
// quickly replaced with a unique cert/key pair.
BootstrapKeyPEM []byte
// CertificateExpiration will record a metric that shows the remaining
// lifetime of the certificate.
CertificateExpiration Gauge
}
// Store is responsible for getting and updating the current certificate.
// Depending on the concrete implementation, the backing store for this
// behavior may vary.
type Store interface {
// Current returns the currently selected certificate, as well as the
// associated certificate and key data in PEM format. If the Store doesn't
// have a cert/key pair currently, it should return a NoCertKeyError so
// that the Manager can recover by using bootstrap certificates to request
// a new cert/key pair.
Current() (*tls.Certificate, error)
// Update accepts the PEM data for the cert/key pair and makes the new
// cert/key pair the 'current' pair, that will be returned by future calls
// to Current().
Update(cert, key []byte) (*tls.Certificate, error)
}
// Gauge will record the remaining lifetime of the certificate each time it is
// updated.
type Gauge interface {
Set(float64)
}
// NoCertKeyError indicates there is no cert/key currently available.
type NoCertKeyError string
func (e *NoCertKeyError) Error() string { return string(*e) }
type manager struct {
certSigningRequestClient certificatesclient.CertificateSigningRequestInterface
template *x509.CertificateRequest
usages []certificates.KeyUsage
certStore Store
certAccessLock sync.RWMutex
cert *tls.Certificate
rotationDeadline time.Time
forceRotation bool
certificateExpiration Gauge
serverHealth bool
}
// NewManager returns a new certificate manager. A certificate manager is
// responsible for being the authoritative source of certificates in the
// Kubelet and handling updates due to rotation.
func NewManager(config *Config) (Manager, error) {
cert, forceRotation, err := getCurrentCertificateOrBootstrap(
config.CertificateStore,
config.BootstrapCertificatePEM,
config.BootstrapKeyPEM)
if err != nil {
return nil, err
}
m := manager{
certSigningRequestClient: config.CertificateSigningRequestClient,
template: config.Template,
usages: config.Usages,
certStore: config.CertificateStore,
cert: cert,
forceRotation: forceRotation,
certificateExpiration: config.CertificateExpiration,
}
return &m, nil
}
// Current returns the currently selected certificate from the certificate
// manager. This can be nil if the manager was initialized without a
// certificate and has not yet received one from the
// CertificateSigningRequestClient.
func (m *manager) Current() *tls.Certificate {
m.certAccessLock.RLock()
defer m.certAccessLock.RUnlock()
return m.cert
}
// ServerHealthy returns true if the cert manager believes the server
// is currently alive.
func (m *manager) ServerHealthy() bool {
m.certAccessLock.RLock()
defer m.certAccessLock.RUnlock()
return m.serverHealth
}
// SetCertificateSigningRequestClient sets the client interface that is used
// for signing new certificates generated as part of rotation. It must be
// called before Start() and can not be used to change the
// CertificateSigningRequestClient that has already been set. This method is to
// support the one specific scenario where the CertificateSigningRequestClient
// uses the CertificateManager.
func (m *manager) SetCertificateSigningRequestClient(certSigningRequestClient certificatesclient.CertificateSigningRequestInterface) error {
if m.certSigningRequestClient == nil {
m.certSigningRequestClient = certSigningRequestClient
return nil
}
return fmt.Errorf("property CertificateSigningRequestClient is already set")
}
// Start will start the background work of rotating the certificates.
func (m *manager) Start() {
// Certificate rotation depends on access to the API server certificate
// signing API, so don't start the certificate manager if we don't have a
// client. This will happen on the cluster master, where the kubelet is
// responsible for bootstrapping the pods of the master components.
if m.certSigningRequestClient == nil {
glog.V(2).Infof("Certificate rotation is not enabled, no connection to the apiserver.")
return
}
glog.V(2).Infof("Certificate rotation is enabled.")
m.setRotationDeadline()
// Synchronously request a certificate before entering the background
// loop to allow bootstrap scenarios, where the certificate manager
// doesn't have a certificate at all yet.
if m.shouldRotate() {
glog.V(1).Infof("shouldRotate() is true, forcing immediate rotation")
if _, err := m.rotateCerts(); err != nil {
utilruntime.HandleError(fmt.Errorf("Could not rotate certificates: %v", err))
}
}
backoff := wait.Backoff{
Duration: 2 * time.Second,
Factor: 2,
Jitter: 0.1,
Steps: 5,
}
go wait.Forever(func() {
sleepInterval := m.rotationDeadline.Sub(time.Now())
glog.V(2).Infof("Waiting %v for next certificate rotation", sleepInterval)
time.Sleep(sleepInterval)
if err := wait.ExponentialBackoff(backoff, m.rotateCerts); err != nil {
utilruntime.HandleError(fmt.Errorf("Reached backoff limit, still unable to rotate certs: %v", err))
wait.PollInfinite(32*time.Second, m.rotateCerts)
}
}, 0)
}
func getCurrentCertificateOrBootstrap(
store Store,
bootstrapCertificatePEM []byte,
bootstrapKeyPEM []byte) (cert *tls.Certificate, shouldRotate bool, errResult error) {
currentCert, err := store.Current()
if err == nil {
return currentCert, false, nil
}
if _, ok := err.(*NoCertKeyError); !ok {
return nil, false, err
}
if bootstrapCertificatePEM == nil || bootstrapKeyPEM == nil {
return nil, true, nil
}
bootstrapCert, err := tls.X509KeyPair(bootstrapCertificatePEM, bootstrapKeyPEM)
if err != nil {
return nil, false, err
}
if len(bootstrapCert.Certificate) < 1 {
return nil, false, fmt.Errorf("no cert/key data found")
}
certs, err := x509.ParseCertificates(bootstrapCert.Certificate[0])
if err != nil {
return nil, false, fmt.Errorf("unable to parse certificate data: %v", err)
}
bootstrapCert.Leaf = certs[0]
return &bootstrapCert, true, nil
}
// shouldRotate looks at how close the current certificate is to expiring and
// decides if it is time to rotate or not.
func (m *manager) shouldRotate() bool {
m.certAccessLock.RLock()
defer m.certAccessLock.RUnlock()
if m.cert == nil {
return true
}
if m.forceRotation {
return true
}
return time.Now().After(m.rotationDeadline)
}
// rotateCerts attempts to request a client cert from the server, wait a reasonable
// period of time for it to be signed, and then update the cert on disk. If it cannot
// retrieve a cert, it will return false. It will only return error in exceptional cases.
// This method also keeps track of "server health" by interpreting the responses it gets
// from the server on the various calls it makes.
func (m *manager) rotateCerts() (bool, error) {
glog.V(2).Infof("Rotating certificates")
csrPEM, keyPEM, privateKey, err := m.generateCSR()
if err != nil {
utilruntime.HandleError(fmt.Errorf("Unable to generate a certificate signing request: %v", err))
return false, nil
}
// Call the Certificate Signing Request API to get a certificate for the
// new private key.
req, err := csr.RequestCertificate(m.certSigningRequestClient, csrPEM, "", m.usages, privateKey)
if err != nil {
utilruntime.HandleError(fmt.Errorf("Failed while requesting a signed certificate from the master: %v", err))
return false, m.updateServerError(err)
}
// Wait for the certificate to be signed. Instead of one long watch, we retry with slighly longer
// intervals each time in order to tolerate failures from the server AND to preserve the liveliness
// of the cert manager loop. This creates slightly more traffic against the API server in return
// for bounding the amount of time we wait when a certificate expires.
var crtPEM []byte
watchDuration := time.Minute
if err := wait.ExponentialBackoff(certificateWaitBackoff, func() (bool, error) {
data, err := csr.WaitForCertificate(m.certSigningRequestClient, req, watchDuration)
switch {
case err == nil:
crtPEM = data
return true, nil
case err == wait.ErrWaitTimeout:
watchDuration += time.Minute
if watchDuration > 5*time.Minute {
watchDuration = 5 * time.Minute
}
return false, nil
default:
utilruntime.HandleError(fmt.Errorf("Unable to check certificate signing status: %v", err))
return false, m.updateServerError(err)
}
}); err != nil {
utilruntime.HandleError(fmt.Errorf("Certificate request was not signed: %v", err))
return false, nil
}
cert, err := m.certStore.Update(crtPEM, keyPEM)
if err != nil {
utilruntime.HandleError(fmt.Errorf("Unable to store the new cert/key pair: %v", err))
return false, nil
}
m.updateCached(cert)
m.setRotationDeadline()
m.forceRotation = false
return true, nil
}
// setRotationDeadline sets a cached value for the threshold at which the
// current certificate should be rotated, 80%+/-10% of the expiration of the
// certificate.
func (m *manager) setRotationDeadline() {
m.certAccessLock.RLock()
defer m.certAccessLock.RUnlock()
if m.cert == nil {
m.rotationDeadline = time.Now()
return
}
notAfter := m.cert.Leaf.NotAfter
totalDuration := float64(notAfter.Sub(m.cert.Leaf.NotBefore))
m.rotationDeadline = m.cert.Leaf.NotBefore.Add(jitteryDuration(totalDuration))
glog.V(2).Infof("Certificate expiration is %v, rotation deadline is %v", notAfter, m.rotationDeadline)
if m.certificateExpiration != nil {
m.certificateExpiration.Set(float64(notAfter.Unix()))
}
}
// jitteryDuration uses some jitter to set the rotation threshold so each node
// will rotate at approximately 70-90% of the total lifetime of the
// certificate. With jitter, if a number of nodes are added to a cluster at
// approximately the same time (such as cluster creation time), they won't all
// try to rotate certificates at the same time for the rest of the life of the
// cluster.
//
// This function is represented as a variable to allow replacement during testing.
var jitteryDuration = func(totalDuration float64) time.Duration {
return wait.Jitter(time.Duration(totalDuration), 0.2) - time.Duration(totalDuration*0.3)
}
// updateCached sets the most recent retrieved cert. It also sets the server
// as assumed healthy.
func (m *manager) updateCached(cert *tls.Certificate) {
m.certAccessLock.Lock()
defer m.certAccessLock.Unlock()
m.serverHealth = true
m.cert = cert
}
// updateServerError takes an error returned by the server and infers
// the health of the server based on the error. It will return nil if
// the error does not require immediate termination of any wait loops,
// and otherwise it will return the error.
func (m *manager) updateServerError(err error) error {
m.certAccessLock.Lock()
defer m.certAccessLock.Unlock()
switch {
case errors.IsUnauthorized(err):
// SSL terminating proxies may report this error instead of the master
m.serverHealth = true
case errors.IsUnexpectedServerError(err):
// generally indicates a proxy or other load balancer problem, rather than a problem coming
// from the master
m.serverHealth = false
default:
// Identify known errors that could be expected for a cert request that
// indicate everything is working normally
m.serverHealth = errors.IsNotFound(err) || errors.IsForbidden(err)
}
return nil
}
func (m *manager) generateCSR() (csrPEM []byte, keyPEM []byte, key interface{}, err error) {
// Generate a new private key.
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), cryptorand.Reader)
if err != nil {
return nil, nil, nil, fmt.Errorf("unable to generate a new private key: %v", err)
}
der, err := x509.MarshalECPrivateKey(privateKey)
if err != nil {
return nil, nil, nil, fmt.Errorf("unable to marshal the new key to DER: %v", err)
}
keyPEM = pem.EncodeToMemory(&pem.Block{Type: cert.ECPrivateKeyBlockType, Bytes: der})
csrPEM, err = cert.MakeCSRFromTemplate(privateKey, m.template)
if err != nil {
return nil, nil, nil, fmt.Errorf("unable to create a csr from the private key: %v", err)
}
return csrPEM, keyPEM, privateKey, nil
}

View File

@ -1,953 +0,0 @@
/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package certificate
import (
"bytes"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"strings"
"testing"
"time"
certificates "k8s.io/api/certificates/v1beta1"
"k8s.io/apimachinery/pkg/api/errors"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/wait"
watch "k8s.io/apimachinery/pkg/watch"
certificatesclient "k8s.io/client-go/kubernetes/typed/certificates/v1beta1"
)
var storeCertData = newCertificateData(`-----BEGIN CERTIFICATE-----
MIICRzCCAfGgAwIBAgIJALMb7ecMIk3MMA0GCSqGSIb3DQEBCwUAMH4xCzAJBgNV
BAYTAkdCMQ8wDQYDVQQIDAZMb25kb24xDzANBgNVBAcMBkxvbmRvbjEYMBYGA1UE
CgwPR2xvYmFsIFNlY3VyaXR5MRYwFAYDVQQLDA1JVCBEZXBhcnRtZW50MRswGQYD
VQQDDBJ0ZXN0LWNlcnRpZmljYXRlLTAwIBcNMTcwNDI2MjMyNjUyWhgPMjExNzA0
MDIyMzI2NTJaMH4xCzAJBgNVBAYTAkdCMQ8wDQYDVQQIDAZMb25kb24xDzANBgNV
BAcMBkxvbmRvbjEYMBYGA1UECgwPR2xvYmFsIFNlY3VyaXR5MRYwFAYDVQQLDA1J
VCBEZXBhcnRtZW50MRswGQYDVQQDDBJ0ZXN0LWNlcnRpZmljYXRlLTAwXDANBgkq
hkiG9w0BAQEFAANLADBIAkEAtBMa7NWpv3BVlKTCPGO/LEsguKqWHBtKzweMY2CV
tAL1rQm913huhxF9w+ai76KQ3MHK5IVnLJjYYA5MzP2H5QIDAQABo1AwTjAdBgNV
HQ4EFgQU22iy8aWkNSxv0nBxFxerfsvnZVMwHwYDVR0jBBgwFoAU22iy8aWkNSxv
0nBxFxerfsvnZVMwDAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAANBAEOefGbV
NcHxklaW06w6OBYJPwpIhCVozC1qdxGX1dg8VkEKzjOzjgqVD30m59OFmSlBmHsl
nkVA6wyOSDYBf3o=
-----END CERTIFICATE-----`, `-----BEGIN RSA PRIVATE KEY-----
MIIBUwIBADANBgkqhkiG9w0BAQEFAASCAT0wggE5AgEAAkEAtBMa7NWpv3BVlKTC
PGO/LEsguKqWHBtKzweMY2CVtAL1rQm913huhxF9w+ai76KQ3MHK5IVnLJjYYA5M
zP2H5QIDAQABAkAS9BfXab3OKpK3bIgNNyp+DQJKrZnTJ4Q+OjsqkpXvNltPJosf
G8GsiKu/vAt4HGqI3eU77NvRI+mL4MnHRmXBAiEA3qM4FAtKSRBbcJzPxxLEUSwg
XSCcosCktbkXvpYrS30CIQDPDxgqlwDEJQ0uKuHkZI38/SPWWqfUmkecwlbpXABK
iQIgZX08DA8VfvcA5/Xj1Zjdey9FVY6POLXen6RPiabE97UCICp6eUW7ht+2jjar
e35EltCRCjoejRHTuN9TC0uCoVipAiAXaJIx/Q47vGwiw6Y8KXsNU6y54gTbOSxX
54LzHNk/+Q==
-----END RSA PRIVATE KEY-----`)
var bootstrapCertData = newCertificateData(
`-----BEGIN CERTIFICATE-----
MIICRzCCAfGgAwIBAgIJANXr+UzRFq4TMA0GCSqGSIb3DQEBCwUAMH4xCzAJBgNV
BAYTAkdCMQ8wDQYDVQQIDAZMb25kb24xDzANBgNVBAcMBkxvbmRvbjEYMBYGA1UE
CgwPR2xvYmFsIFNlY3VyaXR5MRYwFAYDVQQLDA1JVCBEZXBhcnRtZW50MRswGQYD
VQQDDBJ0ZXN0LWNlcnRpZmljYXRlLTEwIBcNMTcwNDI2MjMyNzMyWhgPMjExNzA0
MDIyMzI3MzJaMH4xCzAJBgNVBAYTAkdCMQ8wDQYDVQQIDAZMb25kb24xDzANBgNV
BAcMBkxvbmRvbjEYMBYGA1UECgwPR2xvYmFsIFNlY3VyaXR5MRYwFAYDVQQLDA1J
VCBEZXBhcnRtZW50MRswGQYDVQQDDBJ0ZXN0LWNlcnRpZmljYXRlLTEwXDANBgkq
hkiG9w0BAQEFAANLADBIAkEAqvbkN4RShH1rL37JFp4fZPnn0JUhVWWsrP8NOomJ
pXdBDUMGWuEQIsZ1Gf9JrCQLu6ooRyHSKRFpAVbMQ3ABJwIDAQABo1AwTjAdBgNV
HQ4EFgQUEGBc6YYheEZ/5MhwqSUYYPYRj2MwHwYDVR0jBBgwFoAUEGBc6YYheEZ/
5MhwqSUYYPYRj2MwDAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAANBAIyNmznk
5dgJY52FppEEcfQRdS5k4XFPc22SHPcz77AHf5oWZ1WG9VezOZZPp8NCiFDDlDL8
yma33a5eMyTjLD8=
-----END CERTIFICATE-----`, `-----BEGIN RSA PRIVATE KEY-----
MIIBVAIBADANBgkqhkiG9w0BAQEFAASCAT4wggE6AgEAAkEAqvbkN4RShH1rL37J
Fp4fZPnn0JUhVWWsrP8NOomJpXdBDUMGWuEQIsZ1Gf9JrCQLu6ooRyHSKRFpAVbM
Q3ABJwIDAQABAkBC2OBpGLMPHN8BJijIUDFkURakBvuOoX+/8MYiYk7QxEmfLCk6
L6r+GLNFMfXwXcBmXtMKfZKAIKutKf098JaBAiEA10azfqt3G/5owrNA00plSyT6
ZmHPzY9Uq1p/QTR/uOcCIQDLTkfBkLHm0UKeobbO/fSm6ZflhyBRDINy4FvwmZMt
wQIgYV/tmQJeIh91q3wBepFQOClFykG8CTMoDUol/YyNqUkCIHfp6Rr7fGL3JIMq
QQgf9DCK8SPZqq8DYXjdan0kKBJBAiEAyDb+07o2gpggo8BYUKSaiRCiyXfaq87f
eVqgpBq/QN4=
-----END RSA PRIVATE KEY-----`)
var apiServerCertData = newCertificateData(
`-----BEGIN CERTIFICATE-----
MIICRzCCAfGgAwIBAgIJAIydTIADd+yqMA0GCSqGSIb3DQEBCwUAMH4xCzAJBgNV
BAYTAkdCMQ8wDQYDVQQIDAZMb25kb24xDzANBgNVBAcMBkxvbmRvbjEYMBYGA1UE
CgwPR2xvYmFsIFNlY3VyaXR5MRYwFAYDVQQLDA1JVCBEZXBhcnRtZW50MRswGQYD
VQQDDBJ0ZXN0LWNlcnRpZmljYXRlLTIwIBcNMTcwNDI2MjMyNDU4WhgPMjExNzA0
MDIyMzI0NThaMH4xCzAJBgNVBAYTAkdCMQ8wDQYDVQQIDAZMb25kb24xDzANBgNV
BAcMBkxvbmRvbjEYMBYGA1UECgwPR2xvYmFsIFNlY3VyaXR5MRYwFAYDVQQLDA1J
VCBEZXBhcnRtZW50MRswGQYDVQQDDBJ0ZXN0LWNlcnRpZmljYXRlLTIwXDANBgkq
hkiG9w0BAQEFAANLADBIAkEAuiRet28DV68Dk4A8eqCaqgXmymamUEjW/DxvIQqH
3lbhtm8BwSnS9wUAajSLSWiq3fci2RbRgaSPjUrnbOHCLQIDAQABo1AwTjAdBgNV
HQ4EFgQU0vhI4OPGEOqT+VAWwxdhVvcmgdIwHwYDVR0jBBgwFoAU0vhI4OPGEOqT
+VAWwxdhVvcmgdIwDAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAANBALNeJGDe
nV5cXbp9W1bC12Tc8nnNXn4ypLE2JTQAvyp51zoZ8hQoSnRVx/VCY55Yu+br8gQZ
+tW+O/PoE7B3tuY=
-----END CERTIFICATE-----`, `-----BEGIN RSA PRIVATE KEY-----
MIIBVgIBADANBgkqhkiG9w0BAQEFAASCAUAwggE8AgEAAkEAuiRet28DV68Dk4A8
eqCaqgXmymamUEjW/DxvIQqH3lbhtm8BwSnS9wUAajSLSWiq3fci2RbRgaSPjUrn
bOHCLQIDAQABAkEArDR1g9IqD3aUImNikDgAngbzqpAokOGyMoxeavzpEaFOgCzi
gi7HF7yHRmZkUt8CzdEvnHSqRjFuaaB0gGA+AQIhAOc8Z1h8ElLRSqaZGgI3jCTp
Izx9HNY//U5NGrXD2+ttAiEAzhOqkqI4+nDab7FpiD7MXI6fO549mEXeVBPvPtsS
OcECIQCIfkpOm+ZBBpO3JXaJynoqK4gGI6ALA/ik6LSUiIlfPQIhAISjd9hlfZME
bDQT1r8Q3Gx+h9LRqQeHgPBQ3F5ylqqBAiBaJ0hkYvrIdWxNlcLqD3065bJpHQ4S
WQkuZUQN1M/Xvg==
-----END RSA PRIVATE KEY-----`)
type certificateData struct {
keyPEM []byte
certificatePEM []byte
certificate *tls.Certificate
}
func newCertificateData(certificatePEM string, keyPEM string) *certificateData {
certificate, err := tls.X509KeyPair([]byte(certificatePEM), []byte(keyPEM))
if err != nil {
panic(fmt.Sprintf("Unable to initialize certificate: %v", err))
}
certs, err := x509.ParseCertificates(certificate.Certificate[0])
if err != nil {
panic(fmt.Sprintf("Unable to initialize certificate leaf: %v", err))
}
certificate.Leaf = certs[0]
return &certificateData{
keyPEM: []byte(keyPEM),
certificatePEM: []byte(certificatePEM),
certificate: &certificate,
}
}
func TestNewManagerNoRotation(t *testing.T) {
store := &fakeStore{
cert: storeCertData.certificate,
}
if _, err := NewManager(&Config{
Template: &x509.CertificateRequest{},
Usages: []certificates.KeyUsage{},
CertificateStore: store,
}); err != nil {
t.Fatalf("Failed to initialize the certificate manager: %v", err)
}
}
func TestShouldRotate(t *testing.T) {
now := time.Now()
tests := []struct {
name string
notBefore time.Time
notAfter time.Time
shouldRotate bool
}{
{"just issued, still good", now.Add(-1 * time.Hour), now.Add(99 * time.Hour), false},
{"half way expired, still good", now.Add(-24 * time.Hour), now.Add(24 * time.Hour), false},
{"mostly expired, still good", now.Add(-69 * time.Hour), now.Add(31 * time.Hour), false},
{"just about expired, should rotate", now.Add(-91 * time.Hour), now.Add(9 * time.Hour), true},
{"nearly expired, should rotate", now.Add(-99 * time.Hour), now.Add(1 * time.Hour), true},
{"already expired, should rotate", now.Add(-10 * time.Hour), now.Add(-1 * time.Hour), true},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
m := manager{
cert: &tls.Certificate{
Leaf: &x509.Certificate{
NotBefore: test.notBefore,
NotAfter: test.notAfter,
},
},
template: &x509.CertificateRequest{},
usages: []certificates.KeyUsage{},
}
m.setRotationDeadline()
if m.shouldRotate() != test.shouldRotate {
t.Errorf("Time %v, a certificate issued for (%v, %v) should rotate should be %t.",
now,
m.cert.Leaf.NotBefore,
m.cert.Leaf.NotAfter,
test.shouldRotate)
}
})
}
}
type gaugeMock struct {
calls int
lastValue float64
}
func (g *gaugeMock) Set(v float64) {
g.calls++
g.lastValue = v
}
func TestSetRotationDeadline(t *testing.T) {
defer func(original func(float64) time.Duration) { jitteryDuration = original }(jitteryDuration)
now := time.Now()
testCases := []struct {
name string
notBefore time.Time
notAfter time.Time
shouldRotate bool
}{
{"just issued, still good", now.Add(-1 * time.Hour), now.Add(99 * time.Hour), false},
{"half way expired, still good", now.Add(-24 * time.Hour), now.Add(24 * time.Hour), false},
{"mostly expired, still good", now.Add(-69 * time.Hour), now.Add(31 * time.Hour), false},
{"just about expired, should rotate", now.Add(-91 * time.Hour), now.Add(9 * time.Hour), true},
{"nearly expired, should rotate", now.Add(-99 * time.Hour), now.Add(1 * time.Hour), true},
{"already expired, should rotate", now.Add(-10 * time.Hour), now.Add(-1 * time.Hour), true},
{"long duration", now.Add(-6 * 30 * 24 * time.Hour), now.Add(6 * 30 * 24 * time.Hour), true},
{"short duration", now.Add(-30 * time.Second), now.Add(30 * time.Second), true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
g := gaugeMock{}
m := manager{
cert: &tls.Certificate{
Leaf: &x509.Certificate{
NotBefore: tc.notBefore,
NotAfter: tc.notAfter,
},
},
template: &x509.CertificateRequest{},
usages: []certificates.KeyUsage{},
certificateExpiration: &g,
}
jitteryDuration = func(float64) time.Duration { return time.Duration(float64(tc.notAfter.Sub(tc.notBefore)) * 0.7) }
lowerBound := tc.notBefore.Add(time.Duration(float64(tc.notAfter.Sub(tc.notBefore)) * 0.7))
m.setRotationDeadline()
if !m.rotationDeadline.Equal(lowerBound) {
t.Errorf("For notBefore %v, notAfter %v, the rotationDeadline %v should be %v.",
tc.notBefore,
tc.notAfter,
m.rotationDeadline,
lowerBound)
}
if g.calls != 1 {
t.Errorf("%d metrics were recorded, wanted %d", g.calls, 1)
}
if g.lastValue != float64(tc.notAfter.Unix()) {
t.Errorf("%d value for metric was recorded, wanted %d", g.lastValue, tc.notAfter.Unix())
}
})
}
}
func TestRotateCertCreateCSRError(t *testing.T) {
now := time.Now()
m := manager{
cert: &tls.Certificate{
Leaf: &x509.Certificate{
NotBefore: now.Add(-2 * time.Hour),
NotAfter: now.Add(-1 * time.Hour),
},
},
template: &x509.CertificateRequest{},
usages: []certificates.KeyUsage{},
certSigningRequestClient: fakeClient{
failureType: createError,
},
}
if success, err := m.rotateCerts(); success {
t.Errorf("Got success from 'rotateCerts', wanted failure")
} else if err != nil {
t.Errorf("Got error %v from 'rotateCerts', wanted no error.", err)
}
}
func TestRotateCertWaitingForResultError(t *testing.T) {
now := time.Now()
m := manager{
cert: &tls.Certificate{
Leaf: &x509.Certificate{
NotBefore: now.Add(-2 * time.Hour),
NotAfter: now.Add(-1 * time.Hour),
},
},
template: &x509.CertificateRequest{},
usages: []certificates.KeyUsage{},
certSigningRequestClient: fakeClient{
failureType: watchError,
},
}
certificateWaitBackoff = wait.Backoff{Steps: 1}
if success, err := m.rotateCerts(); success {
t.Errorf("Got success from 'rotateCerts', wanted failure.")
} else if err != nil {
t.Errorf("Got error %v from 'rotateCerts', wanted no error.", err)
}
}
func TestNewManagerBootstrap(t *testing.T) {
store := &fakeStore{}
var cm Manager
cm, err := NewManager(&Config{
Template: &x509.CertificateRequest{},
Usages: []certificates.KeyUsage{},
CertificateStore: store,
BootstrapCertificatePEM: bootstrapCertData.certificatePEM,
BootstrapKeyPEM: bootstrapCertData.keyPEM,
})
if err != nil {
t.Fatalf("Failed to initialize the certificate manager: %v", err)
}
cert := cm.Current()
if cert == nil {
t.Errorf("Certificate was nil, expected something.")
}
if m, ok := cm.(*manager); !ok {
t.Errorf("Expected a '*manager' from 'NewManager'")
} else if !m.shouldRotate() {
t.Errorf("Expected rotation should happen during bootstrap, but it won't.")
}
}
func TestNewManagerNoBootstrap(t *testing.T) {
now := time.Now()
cert, err := tls.X509KeyPair(storeCertData.certificatePEM, storeCertData.keyPEM)
if err != nil {
t.Fatalf("Unable to initialize a certificate: %v", err)
}
cert.Leaf = &x509.Certificate{
NotBefore: now.Add(-24 * time.Hour),
NotAfter: now.Add(24 * time.Hour),
}
store := &fakeStore{
cert: &cert,
}
cm, err := NewManager(&Config{
Template: &x509.CertificateRequest{},
Usages: []certificates.KeyUsage{},
CertificateStore: store,
BootstrapCertificatePEM: bootstrapCertData.certificatePEM,
BootstrapKeyPEM: bootstrapCertData.keyPEM,
})
if err != nil {
t.Fatalf("Failed to initialize the certificate manager: %v", err)
}
currentCert := cm.Current()
if currentCert == nil {
t.Errorf("Certificate was nil, expected something.")
}
if m, ok := cm.(*manager); !ok {
t.Errorf("Expected a '*manager' from 'NewManager'")
} else {
m.setRotationDeadline()
if m.shouldRotate() {
t.Errorf("Expected rotation should happen during bootstrap, but it won't.")
}
}
}
func TestGetCurrentCertificateOrBootstrap(t *testing.T) {
testCases := []struct {
description string
storeCert *tls.Certificate
bootstrapCertData []byte
bootstrapKeyData []byte
expectedCert *tls.Certificate
expectedShouldRotate bool
expectedErrMsg string
}{
{
"return cert from store",
storeCertData.certificate,
nil,
nil,
storeCertData.certificate,
false,
"",
},
{
"no cert in store and no bootstrap cert",
nil,
nil,
nil,
nil,
true,
"",
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
store := &fakeStore{
cert: tc.storeCert,
}
certResult, shouldRotate, err := getCurrentCertificateOrBootstrap(
store,
tc.bootstrapCertData,
tc.bootstrapKeyData)
if certResult == nil || certResult.Certificate == nil || tc.expectedCert == nil {
if certResult != nil && tc.expectedCert != nil {
t.Errorf("Got certificate %v, wanted %v", certResult, tc.expectedCert)
}
} else {
if !certificatesEqual(certResult, tc.expectedCert) {
t.Errorf("Got certificate %v, wanted %v", certResult, tc.expectedCert)
}
}
if shouldRotate != tc.expectedShouldRotate {
t.Errorf("Got shouldRotate %t, wanted %t", shouldRotate, tc.expectedShouldRotate)
}
if err == nil {
if tc.expectedErrMsg != "" {
t.Errorf("Got err %v, wanted %q", err, tc.expectedErrMsg)
}
} else {
if tc.expectedErrMsg == "" || !strings.Contains(err.Error(), tc.expectedErrMsg) {
t.Errorf("Got err %v, wanted %q", err, tc.expectedErrMsg)
}
}
})
}
}
func TestInitializeCertificateSigningRequestClient(t *testing.T) {
var nilCertificate = &certificateData{}
testCases := []struct {
description string
storeCert *certificateData
bootstrapCert *certificateData
apiCert *certificateData
expectedCertBeforeStart *certificateData
expectedCertAfterStart *certificateData
}{
{
description: "No current certificate, no bootstrap certificate",
storeCert: nilCertificate,
bootstrapCert: nilCertificate,
apiCert: apiServerCertData,
expectedCertBeforeStart: nilCertificate,
expectedCertAfterStart: apiServerCertData,
},
{
description: "No current certificate, bootstrap certificate",
storeCert: nilCertificate,
bootstrapCert: bootstrapCertData,
apiCert: apiServerCertData,
expectedCertBeforeStart: bootstrapCertData,
expectedCertAfterStart: apiServerCertData,
},
{
description: "Current certificate, no bootstrap certificate",
storeCert: storeCertData,
bootstrapCert: nilCertificate,
apiCert: apiServerCertData,
expectedCertBeforeStart: storeCertData,
expectedCertAfterStart: storeCertData,
},
{
description: "Current certificate, bootstrap certificate",
storeCert: storeCertData,
bootstrapCert: bootstrapCertData,
apiCert: apiServerCertData,
expectedCertBeforeStart: storeCertData,
expectedCertAfterStart: storeCertData,
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
certificateStore := &fakeStore{
cert: tc.storeCert.certificate,
}
certificateManager, err := NewManager(&Config{
Template: &x509.CertificateRequest{
Subject: pkix.Name{
Organization: []string{"system:nodes"},
CommonName: "system:node:fake-node-name",
},
},
Usages: []certificates.KeyUsage{
certificates.UsageDigitalSignature,
certificates.UsageKeyEncipherment,
certificates.UsageClientAuth,
},
CertificateStore: certificateStore,
BootstrapCertificatePEM: tc.bootstrapCert.certificatePEM,
BootstrapKeyPEM: tc.bootstrapCert.keyPEM,
})
if err != nil {
t.Errorf("Got %v, wanted no error.", err)
}
certificate := certificateManager.Current()
if !certificatesEqual(certificate, tc.expectedCertBeforeStart.certificate) {
t.Errorf("Got %v, wanted %v", certificateString(certificate), certificateString(tc.expectedCertBeforeStart.certificate))
}
if err := certificateManager.SetCertificateSigningRequestClient(&fakeClient{
certificatePEM: tc.apiCert.certificatePEM,
}); err != nil {
t.Errorf("Got error %v, expected none.", err)
}
if m, ok := certificateManager.(*manager); !ok {
t.Errorf("Expected a '*manager' from 'NewManager'")
} else {
m.setRotationDeadline()
if m.shouldRotate() {
if success, err := m.rotateCerts(); !success {
t.Errorf("Got failure from 'rotateCerts', wanted success.")
} else if err != nil {
t.Errorf("Got error %v, expected none.", err)
}
}
}
certificate = certificateManager.Current()
if !certificatesEqual(certificate, tc.expectedCertAfterStart.certificate) {
t.Errorf("Got %v, wanted %v", certificateString(certificate), certificateString(tc.expectedCertAfterStart.certificate))
}
})
}
}
func TestInitializeOtherRESTClients(t *testing.T) {
var nilCertificate = &certificateData{}
testCases := []struct {
description string
storeCert *certificateData
bootstrapCert *certificateData
apiCert *certificateData
expectedCertBeforeStart *certificateData
expectedCertAfterStart *certificateData
}{
{
description: "No current certificate, no bootstrap certificate",
storeCert: nilCertificate,
bootstrapCert: nilCertificate,
apiCert: apiServerCertData,
expectedCertBeforeStart: nilCertificate,
expectedCertAfterStart: apiServerCertData,
},
{
description: "No current certificate, bootstrap certificate",
storeCert: nilCertificate,
bootstrapCert: bootstrapCertData,
apiCert: apiServerCertData,
expectedCertBeforeStart: bootstrapCertData,
expectedCertAfterStart: apiServerCertData,
},
{
description: "Current certificate, no bootstrap certificate",
storeCert: storeCertData,
bootstrapCert: nilCertificate,
apiCert: apiServerCertData,
expectedCertBeforeStart: storeCertData,
expectedCertAfterStart: storeCertData,
},
{
description: "Current certificate, bootstrap certificate",
storeCert: storeCertData,
bootstrapCert: bootstrapCertData,
apiCert: apiServerCertData,
expectedCertBeforeStart: storeCertData,
expectedCertAfterStart: storeCertData,
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
certificateStore := &fakeStore{
cert: tc.storeCert.certificate,
}
certificateManager, err := NewManager(&Config{
Template: &x509.CertificateRequest{
Subject: pkix.Name{
Organization: []string{"system:nodes"},
CommonName: "system:node:fake-node-name",
},
},
Usages: []certificates.KeyUsage{
certificates.UsageDigitalSignature,
certificates.UsageKeyEncipherment,
certificates.UsageClientAuth,
},
CertificateStore: certificateStore,
BootstrapCertificatePEM: tc.bootstrapCert.certificatePEM,
BootstrapKeyPEM: tc.bootstrapCert.keyPEM,
CertificateSigningRequestClient: &fakeClient{
certificatePEM: tc.apiCert.certificatePEM,
},
})
if err != nil {
t.Errorf("Got %v, wanted no error.", err)
}
certificate := certificateManager.Current()
if !certificatesEqual(certificate, tc.expectedCertBeforeStart.certificate) {
t.Errorf("Got %v, wanted %v", certificateString(certificate), certificateString(tc.expectedCertBeforeStart.certificate))
}
if m, ok := certificateManager.(*manager); !ok {
t.Errorf("Expected a '*manager' from 'NewManager'")
} else {
m.setRotationDeadline()
if m.shouldRotate() {
success, err := certificateManager.(*manager).rotateCerts()
if err != nil {
t.Errorf("Got error %v, expected none.", err)
return
}
if !success {
t.Errorf("Unexpected response 'rotateCerts': %t", success)
return
}
}
}
certificate = certificateManager.Current()
if !certificatesEqual(certificate, tc.expectedCertAfterStart.certificate) {
t.Errorf("Got %v, wanted %v", certificateString(certificate), certificateString(tc.expectedCertAfterStart.certificate))
}
})
}
}
func TestServerHealth(t *testing.T) {
type certs struct {
storeCert *certificateData
bootstrapCert *certificateData
apiCert *certificateData
expectedCertBeforeStart *certificateData
expectedCertAfterStart *certificateData
}
updatedCerts := certs{
storeCert: storeCertData,
bootstrapCert: bootstrapCertData,
apiCert: apiServerCertData,
expectedCertBeforeStart: storeCertData,
expectedCertAfterStart: apiServerCertData,
}
currentCerts := certs{
storeCert: storeCertData,
bootstrapCert: bootstrapCertData,
apiCert: apiServerCertData,
expectedCertBeforeStart: storeCertData,
expectedCertAfterStart: storeCertData,
}
testCases := []struct {
description string
certs
failureType fakeClientFailureType
clientErr error
expectRotateFail bool
expectHealthy bool
}{
{
description: "Current certificate, bootstrap certificate",
certs: updatedCerts,
expectHealthy: true,
},
{
description: "Generic error on create",
certs: currentCerts,
failureType: createError,
expectRotateFail: true,
},
{
description: "Unauthorized error on create",
certs: currentCerts,
failureType: createError,
clientErr: errors.NewUnauthorized("unauthorized"),
expectRotateFail: true,
expectHealthy: true,
},
{
description: "Generic unauthorized error on create",
certs: currentCerts,
failureType: createError,
clientErr: errors.NewGenericServerResponse(401, "POST", schema.GroupResource{}, "", "", 0, true),
expectRotateFail: true,
expectHealthy: true,
},
{
description: "Generic not found error on create",
certs: currentCerts,
failureType: createError,
clientErr: errors.NewGenericServerResponse(404, "POST", schema.GroupResource{}, "", "", 0, true),
expectRotateFail: true,
expectHealthy: false,
},
{
description: "Not found error on create",
certs: currentCerts,
failureType: createError,
clientErr: errors.NewGenericServerResponse(404, "POST", schema.GroupResource{}, "", "", 0, false),
expectRotateFail: true,
expectHealthy: true,
},
{
description: "Conflict error on watch",
certs: currentCerts,
failureType: watchError,
clientErr: errors.NewGenericServerResponse(409, "POST", schema.GroupResource{}, "", "", 0, false),
expectRotateFail: true,
expectHealthy: false,
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
certificateStore := &fakeStore{
cert: tc.storeCert.certificate,
}
certificateManager, err := NewManager(&Config{
Template: &x509.CertificateRequest{
Subject: pkix.Name{
Organization: []string{"system:nodes"},
CommonName: "system:node:fake-node-name",
},
},
Usages: []certificates.KeyUsage{
certificates.UsageDigitalSignature,
certificates.UsageKeyEncipherment,
certificates.UsageClientAuth,
},
CertificateStore: certificateStore,
BootstrapCertificatePEM: tc.bootstrapCert.certificatePEM,
BootstrapKeyPEM: tc.bootstrapCert.keyPEM,
CertificateSigningRequestClient: &fakeClient{
certificatePEM: tc.apiCert.certificatePEM,
failureType: tc.failureType,
err: tc.clientErr,
},
})
if err != nil {
t.Errorf("Got %v, wanted no error.", err)
}
certificate := certificateManager.Current()
if !certificatesEqual(certificate, tc.expectedCertBeforeStart.certificate) {
t.Errorf("Got %v, wanted %v", certificateString(certificate), certificateString(tc.expectedCertBeforeStart.certificate))
}
if _, ok := certificateManager.(*manager); !ok {
t.Errorf("Expected a '*manager' from 'NewManager'")
} else {
success, err := certificateManager.(*manager).rotateCerts()
if err != nil {
t.Errorf("Got error %v, expected none.", err)
return
}
if !success != tc.expectRotateFail {
t.Errorf("Unexpected response 'rotateCerts': %t", success)
return
}
if actual := certificateManager.(*manager).ServerHealthy(); actual != tc.expectHealthy {
t.Errorf("Unexpected manager server health: %t", actual)
}
}
certificate = certificateManager.Current()
if !certificatesEqual(certificate, tc.expectedCertAfterStart.certificate) {
t.Errorf("Got %v, wanted %v", certificateString(certificate), certificateString(tc.expectedCertAfterStart.certificate))
}
})
}
}
type fakeClientFailureType int
const (
none fakeClientFailureType = iota
createError
watchError
certificateSigningRequestDenied
)
type fakeClient struct {
certificatesclient.CertificateSigningRequestInterface
failureType fakeClientFailureType
certificatePEM []byte
err error
}
func (c fakeClient) List(opts v1.ListOptions) (*certificates.CertificateSigningRequestList, error) {
if c.failureType == watchError {
if c.err != nil {
return nil, c.err
}
return nil, fmt.Errorf("Watch error")
}
csrReply := certificates.CertificateSigningRequestList{
Items: []certificates.CertificateSigningRequest{
{ObjectMeta: v1.ObjectMeta{UID: "fake-uid"}},
},
}
return &csrReply, nil
}
func (c fakeClient) Create(*certificates.CertificateSigningRequest) (*certificates.CertificateSigningRequest, error) {
if c.failureType == createError {
if c.err != nil {
return nil, c.err
}
return nil, fmt.Errorf("Create error")
}
csrReply := certificates.CertificateSigningRequest{}
csrReply.UID = "fake-uid"
return &csrReply, nil
}
func (c fakeClient) Watch(opts v1.ListOptions) (watch.Interface, error) {
if c.failureType == watchError {
if c.err != nil {
return nil, c.err
}
return nil, fmt.Errorf("Watch error")
}
return &fakeWatch{
failureType: c.failureType,
certificatePEM: c.certificatePEM,
}, nil
}
type fakeWatch struct {
failureType fakeClientFailureType
certificatePEM []byte
}
func (w *fakeWatch) Stop() {
}
func (w *fakeWatch) ResultChan() <-chan watch.Event {
var condition certificates.CertificateSigningRequestCondition
if w.failureType == certificateSigningRequestDenied {
condition = certificates.CertificateSigningRequestCondition{
Type: certificates.CertificateDenied,
}
} else {
condition = certificates.CertificateSigningRequestCondition{
Type: certificates.CertificateApproved,
}
}
csr := certificates.CertificateSigningRequest{
Status: certificates.CertificateSigningRequestStatus{
Conditions: []certificates.CertificateSigningRequestCondition{
condition,
},
Certificate: []byte(w.certificatePEM),
},
}
csr.UID = "fake-uid"
c := make(chan watch.Event, 1)
c <- watch.Event{
Type: watch.Added,
Object: &csr,
}
return c
}
type fakeStore struct {
cert *tls.Certificate
}
func (s *fakeStore) Current() (*tls.Certificate, error) {
if s.cert == nil {
noKeyErr := NoCertKeyError("")
return nil, &noKeyErr
}
return s.cert, nil
}
// Accepts the PEM data for the cert/key pair and makes the new cert/key
// pair the 'current' pair, that will be returned by future calls to
// Current().
func (s *fakeStore) Update(certPEM, keyPEM []byte) (*tls.Certificate, error) {
// In order to make the mocking work, whenever a cert/key pair is passed in
// to be updated in the mock store, assume that the certificate manager
// generated the key, and then asked the mock CertificateSigningRequest API
// to sign it, then the faked API returned a canned response. The canned
// signing response will not match the generated key. In order to make
// things work out, search here for the correct matching key and use that
// instead of the passed in key. That way this file of test code doesn't
// have to implement an actual certificate signing process.
for _, tc := range []*certificateData{storeCertData, bootstrapCertData, apiServerCertData} {
if bytes.Equal(tc.certificatePEM, certPEM) {
keyPEM = tc.keyPEM
}
}
cert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
return nil, err
}
now := time.Now()
s.cert = &cert
s.cert.Leaf = &x509.Certificate{
NotBefore: now.Add(-24 * time.Hour),
NotAfter: now.Add(24 * time.Hour),
}
return s.cert, nil
}
func certificatesEqual(c1 *tls.Certificate, c2 *tls.Certificate) bool {
if c1 == nil || c2 == nil {
return c1 == c2
}
if len(c1.Certificate) != len(c2.Certificate) {
return false
}
for i := 0; i < len(c1.Certificate); i++ {
if !bytes.Equal(c1.Certificate[i], c2.Certificate[i]) {
return false
}
}
return true
}
func certificateString(c *tls.Certificate) string {
if c == nil {
return "certificate == nil"
}
if c.Leaf == nil {
return "certificate.Leaf == nil"
}
return c.Leaf.Subject.CommonName
}

View File

@ -1,324 +0,0 @@
/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package certificate
import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"strings"
"time"
"github.com/golang/glog"
)
const (
keyExtension = ".key"
certExtension = ".crt"
pemExtension = ".pem"
currentPair = "current"
updatedPair = "updated"
)
type fileStore struct {
pairNamePrefix string
certDirectory string
keyDirectory string
certFile string
keyFile string
}
// NewFileStore returns a concrete implementation of a Store that is based on
// storing the cert/key pairs in a single file per pair on disk in the
// designated directory. When starting up it will look for the currently
// selected cert/key pair in:
//
// 1. ${certDirectory}/${pairNamePrefix}-current.pem - both cert and key are in the same file.
// 2. ${certFile}, ${keyFile}
// 3. ${certDirectory}/${pairNamePrefix}.crt, ${keyDirectory}/${pairNamePrefix}.key
//
// The first one found will be used. If rotation is enabled, future cert/key
// updates will be written to the ${certDirectory} directory and
// ${certDirectory}/${pairNamePrefix}-current.pem will be created as a soft
// link to the currently selected cert/key pair.
func NewFileStore(
pairNamePrefix string,
certDirectory string,
keyDirectory string,
certFile string,
keyFile string) (Store, error) {
s := fileStore{
pairNamePrefix: pairNamePrefix,
certDirectory: certDirectory,
keyDirectory: keyDirectory,
certFile: certFile,
keyFile: keyFile,
}
if err := s.recover(); err != nil {
return nil, err
}
return &s, nil
}
// recover checks if there is a certificate rotation that was interrupted while
// progress, and if so, attempts to recover to a good state.
func (s *fileStore) recover() error {
// If the 'current' file doesn't exist, continue on with the recovery process.
currentPath := filepath.Join(s.certDirectory, s.filename(currentPair))
if exists, err := fileExists(currentPath); err != nil {
return err
} else if exists {
return nil
}
// If the 'updated' file exists, and it is a symbolic link, continue on
// with the recovery process.
updatedPath := filepath.Join(s.certDirectory, s.filename(updatedPair))
if fi, err := os.Lstat(updatedPath); err != nil {
if os.IsNotExist(err) {
return nil
}
return err
} else if fi.Mode()&os.ModeSymlink != os.ModeSymlink {
return fmt.Errorf("expected %q to be a symlink but it is a file", updatedPath)
}
// Move the 'updated' symlink to 'current'.
if err := os.Rename(updatedPath, currentPath); err != nil {
return fmt.Errorf("unable to rename %q to %q: %v", updatedPath, currentPath, err)
}
return nil
}
func (s *fileStore) Current() (*tls.Certificate, error) {
pairFile := filepath.Join(s.certDirectory, s.filename(currentPair))
if pairFileExists, err := fileExists(pairFile); err != nil {
return nil, err
} else if pairFileExists {
glog.Infof("Loading cert/key pair from %q.", pairFile)
return loadFile(pairFile)
}
certFileExists, err := fileExists(s.certFile)
if err != nil {
return nil, err
}
keyFileExists, err := fileExists(s.keyFile)
if err != nil {
return nil, err
}
if certFileExists && keyFileExists {
glog.Infof("Loading cert/key pair from (%q, %q).", s.certFile, s.keyFile)
return loadX509KeyPair(s.certFile, s.keyFile)
}
c := filepath.Join(s.certDirectory, s.pairNamePrefix+certExtension)
k := filepath.Join(s.keyDirectory, s.pairNamePrefix+keyExtension)
certFileExists, err = fileExists(c)
if err != nil {
return nil, err
}
keyFileExists, err = fileExists(k)
if err != nil {
return nil, err
}
if certFileExists && keyFileExists {
glog.Infof("Loading cert/key pair from (%q, %q).", c, k)
return loadX509KeyPair(c, k)
}
noKeyErr := NoCertKeyError(
fmt.Sprintf("no cert/key files read at %q, (%q, %q) or (%q, %q)",
pairFile,
s.certFile,
s.keyFile,
s.certDirectory,
s.keyDirectory))
return nil, &noKeyErr
}
func loadFile(pairFile string) (*tls.Certificate, error) {
certBlock, keyBlock, err := loadCertKeyBlocks(pairFile)
if err != nil {
return nil, err
}
cert, err := tls.X509KeyPair(pem.EncodeToMemory(certBlock), pem.EncodeToMemory(keyBlock))
if err != nil {
return nil, fmt.Errorf("could not convert data from %q into cert/key pair: %v", pairFile, err)
}
certs, err := x509.ParseCertificates(cert.Certificate[0])
if err != nil {
return nil, fmt.Errorf("unable to parse certificate data: %v", err)
}
cert.Leaf = certs[0]
return &cert, nil
}
func loadCertKeyBlocks(pairFile string) (cert *pem.Block, key *pem.Block, err error) {
data, err := ioutil.ReadFile(pairFile)
if err != nil {
return nil, nil, fmt.Errorf("could not load cert/key pair from %q: %v", pairFile, err)
}
certBlock, rest := pem.Decode(data)
if certBlock == nil {
return nil, nil, fmt.Errorf("could not decode the first block from %q from expected PEM format", pairFile)
}
keyBlock, _ := pem.Decode(rest)
if keyBlock == nil {
return nil, nil, fmt.Errorf("could not decode the second block from %q from expected PEM format", pairFile)
}
return certBlock, keyBlock, nil
}
func (s *fileStore) Update(certData, keyData []byte) (*tls.Certificate, error) {
ts := time.Now().Format("2006-01-02-15-04-05")
pemFilename := s.filename(ts)
if err := os.MkdirAll(s.certDirectory, 0755); err != nil {
return nil, fmt.Errorf("could not create directory %q to store certificates: %v", s.certDirectory, err)
}
certPath := filepath.Join(s.certDirectory, pemFilename)
f, err := os.OpenFile(certPath, os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0600)
if err != nil {
return nil, fmt.Errorf("could not open %q: %v", certPath, err)
}
defer f.Close()
certBlock, _ := pem.Decode(certData)
if certBlock == nil {
return nil, fmt.Errorf("invalid certificate data")
}
pem.Encode(f, certBlock)
keyBlock, _ := pem.Decode(keyData)
if keyBlock == nil {
return nil, fmt.Errorf("invalid key data")
}
pem.Encode(f, keyBlock)
cert, err := loadFile(certPath)
if err != nil {
return nil, err
}
if err := s.updateSymlink(certPath); err != nil {
return nil, err
}
return cert, nil
}
// updateSymLink updates the current symlink to point to the file that is
// passed it. It will fail if there is a non-symlink file exists where the
// symlink is expected to be.
func (s *fileStore) updateSymlink(filename string) error {
// If the 'current' file either doesn't exist, or is already a symlink,
// proceed. Otherwise, this is an unrecoverable error.
currentPath := filepath.Join(s.certDirectory, s.filename(currentPair))
currentPathExists := false
if fi, err := os.Lstat(currentPath); err != nil {
if !os.IsNotExist(err) {
return err
}
} else if fi.Mode()&os.ModeSymlink != os.ModeSymlink {
return fmt.Errorf("expected %q to be a symlink but it is a file", currentPath)
} else {
currentPathExists = true
}
// If the 'updated' file doesn't exist, proceed. If it exists but it is a
// symlink, delete it. Otherwise, this is an unrecoverable error.
updatedPath := filepath.Join(s.certDirectory, s.filename(updatedPair))
if fi, err := os.Lstat(updatedPath); err != nil {
if !os.IsNotExist(err) {
return err
}
} else if fi.Mode()&os.ModeSymlink != os.ModeSymlink {
return fmt.Errorf("expected %q to be a symlink but it is a file", updatedPath)
} else {
if err := os.Remove(updatedPath); err != nil {
return fmt.Errorf("unable to remove %q: %v", updatedPath, err)
}
}
// Check that the new cert/key pair file exists to avoid rotating to an
// invalid cert/key.
if filenameExists, err := fileExists(filename); err != nil {
return err
} else if !filenameExists {
return fmt.Errorf("file %q does not exist so it can not be used as the currently selected cert/key", filename)
}
// Ensure the source path is absolute to ensure the symlink target is
// correct when certDirectory is a relative path.
filename, err := filepath.Abs(filename)
if err != nil {
return err
}
// Create the 'updated' symlink pointing to the requested file name.
if err := os.Symlink(filename, updatedPath); err != nil {
return fmt.Errorf("unable to create a symlink from %q to %q: %v", updatedPath, filename, err)
}
// Replace the 'current' symlink.
if currentPathExists {
if err := os.Remove(currentPath); err != nil {
return fmt.Errorf("unable to remove %q: %v", currentPath, err)
}
}
if err := os.Rename(updatedPath, currentPath); err != nil {
return fmt.Errorf("unable to rename %q to %q: %v", updatedPath, currentPath, err)
}
return nil
}
func (s *fileStore) filename(qualifier string) string {
return s.pairNamePrefix + "-" + qualifier + pemExtension
}
// withoutExt returns the given filename after removing the extension. The
// extension to remove will be the result of filepath.Ext().
func withoutExt(filename string) string {
return strings.TrimSuffix(filename, filepath.Ext(filename))
}
func loadX509KeyPair(certFile, keyFile string) (*tls.Certificate, error) {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
certs, err := x509.ParseCertificates(cert.Certificate[0])
if err != nil {
return nil, fmt.Errorf("unable to parse certificate data: %v", err)
}
cert.Leaf = certs[0]
return &cert, nil
}
// FileExists checks if specified file exists.
func fileExists(filename string) (bool, error) {
if _, err := os.Stat(filename); os.IsNotExist(err) {
return false, nil
} else if err != nil {
return false, err
}
return true, nil
}

View File

@ -1,505 +0,0 @@
/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package certificate
import (
"io/ioutil"
"os"
"path/filepath"
"testing"
"k8s.io/client-go/util/cert"
)
func TestUpdateSymlinkExistingFileError(t *testing.T) {
dir, err := ioutil.TempDir("", "k8s-test-update-symlink")
if err != nil {
t.Fatalf("Unable to create the test directory %q: %v", dir, err)
}
defer func() {
if err := os.RemoveAll(dir); err != nil {
t.Errorf("Unable to clean up test directory %q: %v", dir, err)
}
}()
pairFile := filepath.Join(dir, "kubelet-current.pem")
if err := ioutil.WriteFile(pairFile, nil, 0600); err != nil {
t.Fatalf("Unable to create the file %q: %v", pairFile, err)
}
s := fileStore{
certDirectory: dir,
pairNamePrefix: "kubelet",
}
if err := s.updateSymlink(pairFile); err == nil {
t.Errorf("Got no error, wanted to fail updating the symlink because there is a file there.")
}
}
func TestUpdateSymlinkNewFileNotExist(t *testing.T) {
dir, err := ioutil.TempDir("", "k8s-test-update-symlink")
if err != nil {
t.Fatalf("Unable to create the test directory %q: %v", dir, err)
}
defer func() {
if err := os.RemoveAll(dir); err != nil {
t.Errorf("Unable to clean up test directory %q: %v", dir, err)
}
}()
oldPairFile := filepath.Join(dir, "kubelet-oldpair.pem")
if err := ioutil.WriteFile(oldPairFile, nil, 0600); err != nil {
t.Fatalf("Unable to create the file %q: %v", oldPairFile, err)
}
s := fileStore{
certDirectory: dir,
pairNamePrefix: "kubelet",
}
if err := s.updateSymlink(oldPairFile); err != nil {
t.Errorf("Got %v, wanted successful update of the symlink to point to %q", err, oldPairFile)
}
if _, err := os.Stat(oldPairFile); err != nil {
t.Errorf("Got %v, wanted file %q to be there.", oldPairFile, err)
}
currentPairFile := filepath.Join(dir, "kubelet-current.pem")
if fi, err := os.Lstat(currentPairFile); err != nil {
t.Errorf("Got %v, wanted file %q to be there", currentPairFile, err)
} else if fi.Mode()&os.ModeSymlink != os.ModeSymlink {
t.Errorf("Got %q not a symlink.", currentPairFile)
}
newPairFile := filepath.Join(dir, "kubelet-newpair.pem")
if err := s.updateSymlink(newPairFile); err == nil {
t.Errorf("Got no error, wanted to fail updating the symlink the file %q does not exist.", newPairFile)
}
}
func TestUpdateSymlinkNoSymlink(t *testing.T) {
dir, err := ioutil.TempDir("", "k8s-test-update-symlink")
if err != nil {
t.Fatalf("Unable to create the test directory %q: %v", dir, err)
}
defer func() {
if err := os.RemoveAll(dir); err != nil {
t.Errorf("Unable to clean up test directory %q: %v", dir, err)
}
}()
pairFile := filepath.Join(dir, "kubelet-newfile.pem")
if err := ioutil.WriteFile(pairFile, nil, 0600); err != nil {
t.Fatalf("Unable to create the file %q: %v", pairFile, err)
}
s := fileStore{
certDirectory: dir,
pairNamePrefix: "kubelet",
}
if err := s.updateSymlink(pairFile); err != nil {
t.Errorf("Got error %v, wanted a new symlink to be created", err)
}
if _, err := os.Stat(pairFile); err != nil {
t.Errorf("Got error %v, wanted file %q to be there", pairFile, err)
}
currentPairFile := filepath.Join(dir, "kubelet-current.pem")
if fi, err := os.Lstat(currentPairFile); err != nil {
t.Errorf("Got %v, wanted %q to be there", currentPairFile, err)
} else if fi.Mode()&os.ModeSymlink != os.ModeSymlink {
t.Errorf("%q not a symlink, wanted a symlink.", currentPairFile)
}
}
func TestUpdateSymlinkReplaceExistingSymlink(t *testing.T) {
prefix := "kubelet"
dir, err := ioutil.TempDir("", "k8s-test-update-symlink")
if err != nil {
t.Fatalf("Unable to create the test directory %q: %v", dir, err)
}
defer func() {
if err := os.RemoveAll(dir); err != nil {
t.Errorf("Unable to clean up test directory %q: %v", dir, err)
}
}()
oldPairFile := filepath.Join(dir, prefix+"-oldfile.pem")
if err := ioutil.WriteFile(oldPairFile, nil, 0600); err != nil {
t.Fatalf("Unable to create the file %q: %v", oldPairFile, err)
}
newPairFile := filepath.Join(dir, prefix+"-newfile.pem")
if err := ioutil.WriteFile(newPairFile, nil, 0600); err != nil {
t.Fatalf("Unable to create the file %q: %v", newPairFile, err)
}
currentPairFile := filepath.Join(dir, prefix+"-current.pem")
if err := os.Symlink(oldPairFile, currentPairFile); err != nil {
t.Fatalf("unable to create a symlink from %q to %q: %v", currentPairFile, oldPairFile, err)
}
if resolved, err := os.Readlink(currentPairFile); err != nil {
t.Fatalf("Got %v when attempting to resolve symlink %q", err, currentPairFile)
} else if resolved != oldPairFile {
t.Fatalf("Got %q as resolution of symlink %q, wanted %q", resolved, currentPairFile, oldPairFile)
}
s := fileStore{
certDirectory: dir,
pairNamePrefix: prefix,
}
if err := s.updateSymlink(newPairFile); err != nil {
t.Errorf("Got error %v, wanted a new symlink to be created", err)
}
if _, err := os.Stat(oldPairFile); err != nil {
t.Errorf("Got error %v, wanted file %q to be there", oldPairFile, err)
}
if _, err := os.Stat(newPairFile); err != nil {
t.Errorf("Got error %v, wanted file %q to be there", newPairFile, err)
}
if fi, err := os.Lstat(currentPairFile); err != nil {
t.Errorf("Got %v, wanted %q to be there", currentPairFile, err)
} else if fi.Mode()&os.ModeSymlink != os.ModeSymlink {
t.Errorf("%q not a symlink, wanted a symlink.", currentPairFile)
}
if resolved, err := os.Readlink(currentPairFile); err != nil {
t.Fatalf("Got %v when attempting to resolve symlink %q", err, currentPairFile)
} else if resolved != newPairFile {
t.Fatalf("Got %q as resolution of symlink %q, wanted %q", resolved, currentPairFile, newPairFile)
}
}
func TestLoadCertKeyBlocksNoFile(t *testing.T) {
dir, err := ioutil.TempDir("", "k8s-test-load-cert-key-blocks")
if err != nil {
t.Fatalf("Unable to create the test directory %q: %v", dir, err)
}
defer func() {
if err := os.RemoveAll(dir); err != nil {
t.Errorf("Unable to clean up test directory %q: %v", dir, err)
}
}()
pairFile := filepath.Join(dir, "kubelet-pair.pem")
if _, _, err := loadCertKeyBlocks(pairFile); err == nil {
t.Errorf("Got no error, but expected %q not found.", pairFile)
}
}
func TestLoadCertKeyBlocksEmptyFile(t *testing.T) {
dir, err := ioutil.TempDir("", "k8s-test-load-cert-key-blocks")
if err != nil {
t.Fatalf("Unable to create the test directory %q: %v", dir, err)
}
defer func() {
if err := os.RemoveAll(dir); err != nil {
t.Errorf("Unable to clean up test directory %q: %v", dir, err)
}
}()
pairFile := filepath.Join(dir, "kubelet-pair.pem")
if err := ioutil.WriteFile(pairFile, nil, 0600); err != nil {
t.Fatalf("Unable to create the file %q: %v", pairFile, err)
}
if _, _, err := loadCertKeyBlocks(pairFile); err == nil {
t.Errorf("Got no error, but expected %q not found.", pairFile)
}
}
func TestLoadCertKeyBlocksPartialFile(t *testing.T) {
dir, err := ioutil.TempDir("", "k8s-test-load-cert-key-blocks")
if err != nil {
t.Fatalf("Unable to create the test directory %q: %v", dir, err)
}
defer func() {
if err := os.RemoveAll(dir); err != nil {
t.Errorf("Unable to clean up test directory %q: %v", dir, err)
}
}()
pairFile := filepath.Join(dir, "kubelet-pair.pem")
if err := ioutil.WriteFile(pairFile, storeCertData.certificatePEM, 0600); err != nil {
t.Fatalf("Unable to create the file %q: %v", pairFile, err)
}
if _, _, err := loadCertKeyBlocks(pairFile); err == nil {
t.Errorf("Got no error, but expected %q invalid.", pairFile)
}
}
func TestLoadCertKeyBlocks(t *testing.T) {
dir, err := ioutil.TempDir("", "k8s-test-load-cert-key-blocks")
if err != nil {
t.Fatalf("Unable to create the test directory %q: %v", dir, err)
}
defer func() {
if err := os.RemoveAll(dir); err != nil {
t.Errorf("Unable to clean up test directory %q: %v", dir, err)
}
}()
pairFile := filepath.Join(dir, "kubelet-pair.pem")
data := append(storeCertData.certificatePEM, []byte("\n")...)
data = append(data, storeCertData.keyPEM...)
if err := ioutil.WriteFile(pairFile, data, 0600); err != nil {
t.Fatalf("Unable to create the file %q: %v", pairFile, err)
}
certBlock, keyBlock, err := loadCertKeyBlocks(pairFile)
if err != nil {
t.Errorf("Got %v, but expected no error.", pairFile)
}
if certBlock.Type != cert.CertificateBlockType {
t.Errorf("Got %q loaded from the pair file, expected a %q.", certBlock.Type, cert.CertificateBlockType)
}
if keyBlock.Type != cert.RSAPrivateKeyBlockType {
t.Errorf("Got %q loaded from the pair file, expected a %q.", keyBlock.Type, cert.RSAPrivateKeyBlockType)
}
}
func TestLoadFile(t *testing.T) {
dir, err := ioutil.TempDir("", "k8s-test-load-cert-key-blocks")
if err != nil {
t.Fatalf("Unable to create the test directory %q: %v", dir, err)
}
defer func() {
if err := os.RemoveAll(dir); err != nil {
t.Errorf("Unable to clean up test directory %q: %v", dir, err)
}
}()
pairFile := filepath.Join(dir, "kubelet-pair.pem")
data := append(storeCertData.certificatePEM, []byte("\n")...)
data = append(data, storeCertData.keyPEM...)
if err := ioutil.WriteFile(pairFile, data, 0600); err != nil {
t.Fatalf("Unable to create the file %q: %v", pairFile, err)
}
cert, err := loadFile(pairFile)
if err != nil {
t.Fatalf("Could not load certificate from disk: %v", err)
}
if cert == nil {
t.Fatalf("There was no error, but no certificate data was returned.")
}
if cert.Leaf == nil {
t.Fatalf("Got an empty leaf, expected private data.")
}
}
func TestUpdateNoRotation(t *testing.T) {
prefix := "kubelet-server"
dir, err := ioutil.TempDir("", "k8s-test-certstore-current")
if err != nil {
t.Fatalf("Unable to create the test directory %q: %v", dir, err)
}
defer func() {
if err := os.RemoveAll(dir); err != nil {
t.Errorf("Unable to clean up test directory %q: %v", dir, err)
}
}()
keyFile := filepath.Join(dir, "kubelet.key")
if err := ioutil.WriteFile(keyFile, storeCertData.keyPEM, 0600); err != nil {
t.Fatalf("Unable to create the file %q: %v", keyFile, err)
}
certFile := filepath.Join(dir, "kubelet.crt")
if err := ioutil.WriteFile(certFile, storeCertData.certificatePEM, 0600); err != nil {
t.Fatalf("Unable to create the file %q: %v", certFile, err)
}
s, err := NewFileStore(prefix, dir, dir, certFile, keyFile)
if err != nil {
t.Fatalf("Got %v while creating a new store.", err)
}
cert, err := s.Update(storeCertData.certificatePEM, storeCertData.keyPEM)
if err != nil {
t.Errorf("Got %v while updating certificate store.", err)
}
if cert == nil {
t.Errorf("Got nil certificate, expected something real.")
}
}
func TestUpdateRotation(t *testing.T) {
prefix := "kubelet-server"
dir, err := ioutil.TempDir("", "k8s-test-certstore-current")
if err != nil {
t.Fatalf("Unable to create the test directory %q: %v", dir, err)
}
defer func() {
if err := os.RemoveAll(dir); err != nil {
t.Errorf("Unable to clean up test directory %q: %v", dir, err)
}
}()
keyFile := filepath.Join(dir, "kubelet.key")
if err := ioutil.WriteFile(keyFile, storeCertData.keyPEM, 0600); err != nil {
t.Fatalf("Unable to create the file %q: %v", keyFile, err)
}
certFile := filepath.Join(dir, "kubelet.crt")
if err := ioutil.WriteFile(certFile, storeCertData.certificatePEM, 0600); err != nil {
t.Fatalf("Unable to create the file %q: %v", certFile, err)
}
s, err := NewFileStore(prefix, dir, dir, certFile, keyFile)
if err != nil {
t.Fatalf("Got %v while creating a new store.", err)
}
cert, err := s.Update(storeCertData.certificatePEM, storeCertData.keyPEM)
if err != nil {
t.Fatalf("Got %v while updating certificate store.", err)
}
if cert == nil {
t.Fatalf("Got nil certificate, expected something real.")
}
}
func TestUpdateWithBadCertKeyData(t *testing.T) {
prefix := "kubelet-server"
dir, err := ioutil.TempDir("", "k8s-test-certstore-current")
if err != nil {
t.Fatalf("Unable to create the test directory %q: %v", dir, err)
}
defer func() {
if err := os.RemoveAll(dir); err != nil {
t.Errorf("Unable to clean up test directory %q: %v", dir, err)
}
}()
keyFile := filepath.Join(dir, "kubelet.key")
if err := ioutil.WriteFile(keyFile, storeCertData.keyPEM, 0600); err != nil {
t.Fatalf("Unable to create the file %q: %v", keyFile, err)
}
certFile := filepath.Join(dir, "kubelet.crt")
if err := ioutil.WriteFile(certFile, storeCertData.certificatePEM, 0600); err != nil {
t.Fatalf("Unable to create the file %q: %v", certFile, err)
}
s, err := NewFileStore(prefix, dir, dir, certFile, keyFile)
if err != nil {
t.Fatalf("Got %v while creating a new store.", err)
}
cert, err := s.Update([]byte{0, 0}, storeCertData.keyPEM)
if err == nil {
t.Fatalf("Got no error while updating certificate store with invalid data.")
}
if cert != nil {
t.Fatalf("Got %v certificate returned from the update, expected nil.", cert)
}
}
func TestCurrentPairFile(t *testing.T) {
prefix := "kubelet-server"
dir, err := ioutil.TempDir("", "k8s-test-certstore-current")
if err != nil {
t.Fatalf("Unable to create the test directory %q: %v", dir, err)
}
defer func() {
if err := os.RemoveAll(dir); err != nil {
t.Errorf("Unable to clean up test directory %q: %v", dir, err)
}
}()
pairFile := filepath.Join(dir, prefix+"-pair.pem")
data := append(storeCertData.certificatePEM, []byte("\n")...)
data = append(data, storeCertData.keyPEM...)
if err := ioutil.WriteFile(pairFile, data, 0600); err != nil {
t.Fatalf("Unable to create the file %q: %v", pairFile, err)
}
currentFile := filepath.Join(dir, prefix+"-current.pem")
if err := os.Symlink(pairFile, currentFile); err != nil {
t.Fatalf("unable to create a symlink from %q to %q: %v", currentFile, pairFile, err)
}
store, err := NewFileStore("kubelet-server", dir, dir, "", "")
if err != nil {
t.Fatalf("Failed to initialize certificate store: %v", err)
}
cert, err := store.Current()
if err != nil {
t.Fatalf("Could not load certificate from disk: %v", err)
}
if cert == nil {
t.Fatalf("There was no error, but no certificate data was returned.")
}
if cert.Leaf == nil {
t.Fatalf("Got an empty leaf, expected private data.")
}
}
func TestCurrentCertKeyFiles(t *testing.T) {
prefix := "kubelet-server"
dir, err := ioutil.TempDir("", "k8s-test-certstore-current")
if err != nil {
t.Fatalf("Unable to create the test directory %q: %v", dir, err)
}
defer func() {
if err := os.RemoveAll(dir); err != nil {
t.Errorf("Unable to clean up test directory %q: %v", dir, err)
}
}()
certFile := filepath.Join(dir, "kubelet.crt")
if err := ioutil.WriteFile(certFile, storeCertData.certificatePEM, 0600); err != nil {
t.Fatalf("Unable to create the file %q: %v", certFile, err)
}
keyFile := filepath.Join(dir, "kubelet.key")
if err := ioutil.WriteFile(keyFile, storeCertData.keyPEM, 0600); err != nil {
t.Fatalf("Unable to create the file %q: %v", keyFile, err)
}
store, err := NewFileStore(prefix, dir, dir, certFile, keyFile)
if err != nil {
t.Fatalf("Failed to initialize certificate store: %v", err)
}
cert, err := store.Current()
if err != nil {
t.Fatalf("Could not load certificate from disk: %v", err)
}
if cert == nil {
t.Fatalf("There was no error, but no certificate data was returned.")
}
if cert.Leaf == nil {
t.Fatalf("Got an empty leaf, expected private data.")
}
}
func TestCurrentNoFiles(t *testing.T) {
dir, err := ioutil.TempDir("", "k8s-test-certstore-current")
if err != nil {
t.Fatalf("Unable to create the test directory %q: %v", dir, err)
}
defer func() {
if err := os.RemoveAll(dir); err != nil {
t.Errorf("Unable to clean up test directory %q: %v", dir, err)
}
}()
store, err := NewFileStore("kubelet-server", dir, dir, "", "")
if err != nil {
t.Fatalf("Failed to initialize certificate store: %v", err)
}
cert, err := store.Current()
if err == nil {
t.Fatalf("Got no error, expected an error because the cert/key files don't exist.")
}
if _, ok := err.(*NoCertKeyError); !ok {
t.Fatalf("Got error %v, expected NoCertKeyError.", err)
}
if cert != nil {
t.Fatalf("Got certificate, expected no certificate because the cert/key files don't exist.")
}
}

View File

@ -1,54 +0,0 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
"go_test",
)
go_library(
name = "go_default_library",
srcs = ["csr.go"],
importpath = "k8s.io/client-go/util/certificate/csr",
deps = [
"//vendor/github.com/golang/glog:go_default_library",
"//vendor/k8s.io/api/certificates/v1beta1:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/api/errors:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/apis/meta/v1:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/fields:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/runtime:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/types:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/util/wait:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/watch:go_default_library",
"//vendor/k8s.io/client-go/kubernetes/typed/certificates/v1beta1:go_default_library",
"//vendor/k8s.io/client-go/tools/cache:go_default_library",
"//vendor/k8s.io/client-go/util/cert:go_default_library",
],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
)
go_test(
name = "go_default_test",
srcs = ["csr_test.go"],
importpath = "k8s.io/client-go/util/certificate/csr",
library = ":go_default_library",
deps = [
"//vendor/k8s.io/api/certificates/v1beta1:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/apis/meta/v1:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/watch:go_default_library",
"//vendor/k8s.io/client-go/kubernetes/typed/certificates/v1beta1:go_default_library",
"//vendor/k8s.io/client-go/util/cert:go_default_library",
],
)

View File

@ -1,261 +0,0 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package csr
import (
"crypto"
"crypto/sha512"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/pem"
"fmt"
"github.com/golang/glog"
"reflect"
"time"
certificates "k8s.io/api/certificates/v1beta1"
"k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/fields"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/apimachinery/pkg/watch"
certificatesclient "k8s.io/client-go/kubernetes/typed/certificates/v1beta1"
"k8s.io/client-go/tools/cache"
certutil "k8s.io/client-go/util/cert"
)
// RequestNodeCertificate will create a certificate signing request for a node
// (Organization and CommonName for the CSR will be set as expected for node
// certificates) and send it to API server, then it will watch the object's
// status, once approved by API server, it will return the API server's issued
// certificate (pem-encoded). If there is any errors, or the watch timeouts, it
// will return an error. This is intended for use on nodes (kubelet and
// kubeadm).
func RequestNodeCertificate(client certificatesclient.CertificateSigningRequestInterface, privateKeyData []byte, nodeName types.NodeName) (certData []byte, err error) {
subject := &pkix.Name{
Organization: []string{"system:nodes"},
CommonName: "system:node:" + string(nodeName),
}
privateKey, err := certutil.ParsePrivateKeyPEM(privateKeyData)
if err != nil {
return nil, fmt.Errorf("invalid private key for certificate request: %v", err)
}
csrData, err := certutil.MakeCSR(privateKey, subject, nil, nil)
if err != nil {
return nil, fmt.Errorf("unable to generate certificate request: %v", err)
}
usages := []certificates.KeyUsage{
certificates.UsageDigitalSignature,
certificates.UsageKeyEncipherment,
certificates.UsageClientAuth,
}
name := digestedName(privateKeyData, subject, usages)
req, err := RequestCertificate(client, csrData, name, usages, privateKey)
if err != nil {
return nil, err
}
return WaitForCertificate(client, req, 3600*time.Second)
}
// RequestCertificate will either use an existing (if this process has run
// before but not to completion) or create a certificate signing request using the
// PEM encoded CSR and send it to API server, then it will watch the object's
// status, once approved by API server, it will return the API server's issued
// certificate (pem-encoded). If there is any errors, or the watch timeouts, it
// will return an error.
func RequestCertificate(client certificatesclient.CertificateSigningRequestInterface, csrData []byte, name string, usages []certificates.KeyUsage, privateKey interface{}) (req *certificates.CertificateSigningRequest, err error) {
csr := &certificates.CertificateSigningRequest{
// Username, UID, Groups will be injected by API server.
TypeMeta: metav1.TypeMeta{Kind: "CertificateSigningRequest"},
ObjectMeta: metav1.ObjectMeta{
Name: name,
},
Spec: certificates.CertificateSigningRequestSpec{
Request: csrData,
Usages: usages,
},
}
if len(csr.Name) == 0 {
csr.GenerateName = "csr-"
}
req, err = client.Create(csr)
switch {
case err == nil:
case errors.IsAlreadyExists(err) && len(name) > 0:
glog.Infof("csr for this node already exists, reusing")
req, err = client.Get(name, metav1.GetOptions{})
if err != nil {
return nil, formatError("cannot retrieve certificate signing request: %v", err)
}
if err := ensureCompatible(req, csr, privateKey); err != nil {
return nil, fmt.Errorf("retrieved csr is not compatible: %v", err)
}
glog.Infof("csr for this node is still valid")
default:
return nil, formatError("cannot create certificate signing request: %v", err)
}
return req, nil
}
// WaitForCertificate waits for a certificate to be issued until timeout, or returns an error.
func WaitForCertificate(client certificatesclient.CertificateSigningRequestInterface, req *certificates.CertificateSigningRequest, timeout time.Duration) (certData []byte, err error) {
fieldSelector := fields.OneTermEqualSelector("metadata.name", req.Name).String()
event, err := cache.ListWatchUntil(
timeout,
&cache.ListWatch{
ListFunc: func(options metav1.ListOptions) (runtime.Object, error) {
options.FieldSelector = fieldSelector
return client.List(options)
},
WatchFunc: func(options metav1.ListOptions) (watch.Interface, error) {
options.FieldSelector = fieldSelector
return client.Watch(options)
},
},
func(event watch.Event) (bool, error) {
switch event.Type {
case watch.Modified, watch.Added:
case watch.Deleted:
return false, fmt.Errorf("csr %q was deleted", req.Name)
default:
return false, nil
}
csr := event.Object.(*certificates.CertificateSigningRequest)
if csr.UID != req.UID {
return false, fmt.Errorf("csr %q changed UIDs", csr.Name)
}
for _, c := range csr.Status.Conditions {
if c.Type == certificates.CertificateDenied {
return false, fmt.Errorf("certificate signing request is not approved, reason: %v, message: %v", c.Reason, c.Message)
}
if c.Type == certificates.CertificateApproved && csr.Status.Certificate != nil {
return true, nil
}
}
return false, nil
},
)
if err == wait.ErrWaitTimeout {
return nil, wait.ErrWaitTimeout
}
if err != nil {
return nil, formatError("cannot watch on the certificate signing request: %v", err)
}
return event.Object.(*certificates.CertificateSigningRequest).Status.Certificate, nil
}
// This digest should include all the relevant pieces of the CSR we care about.
// We can't direcly hash the serialized CSR because of random padding that we
// regenerate every loop and we include usages which are not contained in the
// CSR. This needs to be kept up to date as we add new fields to the node
// certificates and with ensureCompatible.
func digestedName(privateKeyData []byte, subject *pkix.Name, usages []certificates.KeyUsage) string {
hash := sha512.New512_256()
// Here we make sure two different inputs can't write the same stream
// to the hash. This delimiter is not in the base64.URLEncoding
// alphabet so there is no way to have spill over collisions. Without
// it 'CN:foo,ORG:bar' hashes to the same value as 'CN:foob,ORG:ar'
const delimiter = '|'
encode := base64.RawURLEncoding.EncodeToString
write := func(data []byte) {
hash.Write([]byte(encode(data)))
hash.Write([]byte{delimiter})
}
write(privateKeyData)
write([]byte(subject.CommonName))
for _, v := range subject.Organization {
write([]byte(v))
}
for _, v := range usages {
write([]byte(v))
}
return "node-csr-" + encode(hash.Sum(nil))
}
// ensureCompatible ensures that a CSR object is compatible with an original CSR
func ensureCompatible(new, orig *certificates.CertificateSigningRequest, privateKey interface{}) error {
newCsr, err := ParseCSR(new)
if err != nil {
return fmt.Errorf("unable to parse new csr: %v", err)
}
origCsr, err := ParseCSR(orig)
if err != nil {
return fmt.Errorf("unable to parse original csr: %v", err)
}
if !reflect.DeepEqual(newCsr.Subject, origCsr.Subject) {
return fmt.Errorf("csr subjects differ: new: %#v, orig: %#v", newCsr.Subject, origCsr.Subject)
}
signer, ok := privateKey.(crypto.Signer)
if !ok {
return fmt.Errorf("privateKey is not a signer")
}
newCsr.PublicKey = signer.Public()
if err := newCsr.CheckSignature(); err != nil {
return fmt.Errorf("error validating signature new CSR against old key: %v", err)
}
if len(new.Status.Certificate) > 0 {
certs, err := certutil.ParseCertsPEM(new.Status.Certificate)
if err != nil {
return fmt.Errorf("error parsing signed certificate for CSR: %v", err)
}
now := time.Now()
for _, cert := range certs {
if now.After(cert.NotAfter) {
return fmt.Errorf("one of the certificates for the CSR has expired: %s", cert.NotAfter)
}
}
}
return nil
}
// formatError preserves the type of an API message but alters the message. Expects
// a single argument format string, and returns the wrapped error.
func formatError(format string, err error) error {
if s, ok := err.(errors.APIStatus); ok {
se := &errors.StatusError{ErrStatus: s.Status()}
se.ErrStatus.Message = fmt.Sprintf(format, se.ErrStatus.Message)
return se
}
return fmt.Errorf(format, err)
}
// ParseCSR extracts the CSR from the API object and decodes it.
func ParseCSR(obj *certificates.CertificateSigningRequest) (*x509.CertificateRequest, error) {
// extract PEM from request object
pemBytes := obj.Spec.Request
block, _ := pem.Decode(pemBytes)
if block == nil || block.Type != "CERTIFICATE REQUEST" {
return nil, fmt.Errorf("PEM block type must be CERTIFICATE REQUEST")
}
csr, err := x509.ParseCertificateRequest(block.Bytes)
if err != nil {
return nil, err
}
return csr, nil
}

View File

@ -1,135 +0,0 @@
/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package csr
import (
"fmt"
"testing"
certificates "k8s.io/api/certificates/v1beta1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
watch "k8s.io/apimachinery/pkg/watch"
certificatesclient "k8s.io/client-go/kubernetes/typed/certificates/v1beta1"
certutil "k8s.io/client-go/util/cert"
)
func TestRequestNodeCertificateNoKeyData(t *testing.T) {
certData, err := RequestNodeCertificate(&fakeClient{}, []byte{}, "fake-node-name")
if err == nil {
t.Errorf("Got no error, wanted error an error because there was an empty private key passed in.")
}
if certData != nil {
t.Errorf("Got cert data, wanted nothing as there should have been an error.")
}
}
func TestRequestNodeCertificateErrorCreatingCSR(t *testing.T) {
client := &fakeClient{
failureType: createError,
}
privateKeyData, err := certutil.MakeEllipticPrivateKeyPEM()
if err != nil {
t.Fatalf("Unable to generate a new private key: %v", err)
}
certData, err := RequestNodeCertificate(client, privateKeyData, "fake-node-name")
if err == nil {
t.Errorf("Got no error, wanted error an error because client.Create failed.")
}
if certData != nil {
t.Errorf("Got cert data, wanted nothing as there should have been an error.")
}
}
func TestRequestNodeCertificate(t *testing.T) {
privateKeyData, err := certutil.MakeEllipticPrivateKeyPEM()
if err != nil {
t.Fatalf("Unable to generate a new private key: %v", err)
}
certData, err := RequestNodeCertificate(&fakeClient{}, privateKeyData, "fake-node-name")
if err != nil {
t.Errorf("Got %v, wanted no error.", err)
}
if certData == nil {
t.Errorf("Got nothing, expected a CSR.")
}
}
type FailureType int
const (
noError FailureType = iota
createError
certificateSigningRequestDenied
)
type fakeClient struct {
certificatesclient.CertificateSigningRequestInterface
watch *watch.FakeWatcher
failureType FailureType
}
func (c *fakeClient) Create(*certificates.CertificateSigningRequest) (*certificates.CertificateSigningRequest, error) {
if c.failureType == createError {
return nil, fmt.Errorf("fakeClient failed creating request")
}
csr := certificates.CertificateSigningRequest{
ObjectMeta: metav1.ObjectMeta{
UID: "fake-uid",
Name: "fake-certificate-signing-request-name",
},
}
return &csr, nil
}
func (c *fakeClient) List(opts metav1.ListOptions) (*certificates.CertificateSigningRequestList, error) {
return &certificates.CertificateSigningRequestList{}, nil
}
func (c *fakeClient) Watch(opts metav1.ListOptions) (watch.Interface, error) {
c.watch = watch.NewFakeWithChanSize(1, false)
c.watch.Add(c.generateCSR())
c.watch.Stop()
return c.watch, nil
}
func (c *fakeClient) generateCSR() *certificates.CertificateSigningRequest {
var condition certificates.CertificateSigningRequestCondition
if c.failureType == certificateSigningRequestDenied {
condition = certificates.CertificateSigningRequestCondition{
Type: certificates.CertificateDenied,
}
} else {
condition = certificates.CertificateSigningRequestCondition{
Type: certificates.CertificateApproved,
}
}
csr := certificates.CertificateSigningRequest{
ObjectMeta: metav1.ObjectMeta{
UID: "fake-uid",
},
Status: certificates.CertificateSigningRequestStatus{
Conditions: []certificates.CertificateSigningRequestCondition{
condition,
},
Certificate: []byte{},
},
}
return &csr
}

View File

@ -1,25 +0,0 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
)
go_library(
name = "go_default_library",
srcs = ["exec.go"],
importpath = "k8s.io/client-go/util/exec",
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
)

View File

@ -1,52 +0,0 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package exec
// ExitError is an interface that presents an API similar to os.ProcessState, which is
// what ExitError from os/exec is. This is designed to make testing a bit easier and
// probably loses some of the cross-platform properties of the underlying library.
type ExitError interface {
String() string
Error() string
Exited() bool
ExitStatus() int
}
// CodeExitError is an implementation of ExitError consisting of an error object
// and an exit code (the upper bits of os.exec.ExitStatus).
type CodeExitError struct {
Err error
Code int
}
var _ ExitError = CodeExitError{}
func (e CodeExitError) Error() string {
return e.Err.Error()
}
func (e CodeExitError) String() string {
return e.Err.Error()
}
func (e CodeExitError) Exited() bool {
return true
}
func (e CodeExitError) ExitStatus() int {
return e.Code
}

View File

@ -1,45 +0,0 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
"go_test",
)
go_test(
name = "go_default_test",
srcs = [
"backoff_test.go",
"throttle_test.go",
],
importpath = "k8s.io/client-go/util/flowcontrol",
library = ":go_default_library",
deps = ["//vendor/k8s.io/apimachinery/pkg/util/clock:go_default_library"],
)
go_library(
name = "go_default_library",
srcs = [
"backoff.go",
"throttle.go",
],
importpath = "k8s.io/client-go/util/flowcontrol",
deps = [
"//vendor/github.com/juju/ratelimit:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/util/clock:go_default_library",
"//vendor/k8s.io/client-go/util/integer:go_default_library",
],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
)

View File

@ -1,149 +0,0 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package flowcontrol
import (
"sync"
"time"
"k8s.io/apimachinery/pkg/util/clock"
"k8s.io/client-go/util/integer"
)
type backoffEntry struct {
backoff time.Duration
lastUpdate time.Time
}
type Backoff struct {
sync.Mutex
Clock clock.Clock
defaultDuration time.Duration
maxDuration time.Duration
perItemBackoff map[string]*backoffEntry
}
func NewFakeBackOff(initial, max time.Duration, tc *clock.FakeClock) *Backoff {
return &Backoff{
perItemBackoff: map[string]*backoffEntry{},
Clock: tc,
defaultDuration: initial,
maxDuration: max,
}
}
func NewBackOff(initial, max time.Duration) *Backoff {
return &Backoff{
perItemBackoff: map[string]*backoffEntry{},
Clock: clock.RealClock{},
defaultDuration: initial,
maxDuration: max,
}
}
// Get the current backoff Duration
func (p *Backoff) Get(id string) time.Duration {
p.Lock()
defer p.Unlock()
var delay time.Duration
entry, ok := p.perItemBackoff[id]
if ok {
delay = entry.backoff
}
return delay
}
// move backoff to the next mark, capping at maxDuration
func (p *Backoff) Next(id string, eventTime time.Time) {
p.Lock()
defer p.Unlock()
entry, ok := p.perItemBackoff[id]
if !ok || hasExpired(eventTime, entry.lastUpdate, p.maxDuration) {
entry = p.initEntryUnsafe(id)
} else {
delay := entry.backoff * 2 // exponential
entry.backoff = time.Duration(integer.Int64Min(int64(delay), int64(p.maxDuration)))
}
entry.lastUpdate = p.Clock.Now()
}
// Reset forces clearing of all backoff data for a given key.
func (p *Backoff) Reset(id string) {
p.Lock()
defer p.Unlock()
delete(p.perItemBackoff, id)
}
// Returns True if the elapsed time since eventTime is smaller than the current backoff window
func (p *Backoff) IsInBackOffSince(id string, eventTime time.Time) bool {
p.Lock()
defer p.Unlock()
entry, ok := p.perItemBackoff[id]
if !ok {
return false
}
if hasExpired(eventTime, entry.lastUpdate, p.maxDuration) {
return false
}
return p.Clock.Now().Sub(eventTime) < entry.backoff
}
// Returns True if time since lastupdate is less than the current backoff window.
func (p *Backoff) IsInBackOffSinceUpdate(id string, eventTime time.Time) bool {
p.Lock()
defer p.Unlock()
entry, ok := p.perItemBackoff[id]
if !ok {
return false
}
if hasExpired(eventTime, entry.lastUpdate, p.maxDuration) {
return false
}
return eventTime.Sub(entry.lastUpdate) < entry.backoff
}
// Garbage collect records that have aged past maxDuration. Backoff users are expected
// to invoke this periodically.
func (p *Backoff) GC() {
p.Lock()
defer p.Unlock()
now := p.Clock.Now()
for id, entry := range p.perItemBackoff {
if now.Sub(entry.lastUpdate) > p.maxDuration*2 {
// GC when entry has not been updated for 2*maxDuration
delete(p.perItemBackoff, id)
}
}
}
func (p *Backoff) DeleteEntry(id string) {
p.Lock()
defer p.Unlock()
delete(p.perItemBackoff, id)
}
// Take a lock on *Backoff, before calling initEntryUnsafe
func (p *Backoff) initEntryUnsafe(id string) *backoffEntry {
entry := &backoffEntry{backoff: p.defaultDuration}
p.perItemBackoff[id] = entry
return entry
}
// After 2*maxDuration we restart the backoff factor to the beginning
func hasExpired(eventTime time.Time, lastUpdate time.Time, maxDuration time.Duration) bool {
return eventTime.Sub(lastUpdate) > maxDuration*2 // consider stable if it's ok for twice the maxDuration
}

View File

@ -1,195 +0,0 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package flowcontrol
import (
"testing"
"time"
"k8s.io/apimachinery/pkg/util/clock"
)
func TestSlowBackoff(t *testing.T) {
id := "_idSlow"
tc := clock.NewFakeClock(time.Now())
step := time.Second
maxDuration := 50 * step
b := NewFakeBackOff(step, maxDuration, tc)
cases := []time.Duration{0, 1, 2, 4, 8, 16, 32, 50, 50, 50}
for ix, c := range cases {
tc.Step(step)
w := b.Get(id)
if w != c*step {
t.Errorf("input: '%d': expected %s, got %s", ix, c*step, w)
}
b.Next(id, tc.Now())
}
//Now confirm that the Reset cancels backoff.
b.Next(id, tc.Now())
b.Reset(id)
if b.Get(id) != 0 {
t.Errorf("Reset didn't clear the backoff.")
}
}
func TestBackoffReset(t *testing.T) {
id := "_idReset"
tc := clock.NewFakeClock(time.Now())
step := time.Second
maxDuration := step * 5
b := NewFakeBackOff(step, maxDuration, tc)
startTime := tc.Now()
// get to backoff = maxDuration
for i := 0; i <= int(maxDuration/step); i++ {
tc.Step(step)
b.Next(id, tc.Now())
}
// backoff should be capped at maxDuration
if !b.IsInBackOffSince(id, tc.Now()) {
t.Errorf("expected to be in Backoff got %s", b.Get(id))
}
lastUpdate := tc.Now()
tc.Step(2*maxDuration + step) // time += 11s, 11 > 2*maxDuration
if b.IsInBackOffSince(id, lastUpdate) {
t.Errorf("expected to not be in Backoff after reset (start=%s, now=%s, lastUpdate=%s), got %s", startTime, tc.Now(), lastUpdate, b.Get(id))
}
}
func TestBackoffHightWaterMark(t *testing.T) {
id := "_idHiWaterMark"
tc := clock.NewFakeClock(time.Now())
step := time.Second
maxDuration := 5 * step
b := NewFakeBackOff(step, maxDuration, tc)
// get to backoff = maxDuration
for i := 0; i <= int(maxDuration/step); i++ {
tc.Step(step)
b.Next(id, tc.Now())
}
// backoff high watermark expires after 2*maxDuration
tc.Step(maxDuration + step)
b.Next(id, tc.Now())
if b.Get(id) != maxDuration {
t.Errorf("expected Backoff to stay at high watermark %s got %s", maxDuration, b.Get(id))
}
}
func TestBackoffGC(t *testing.T) {
id := "_idGC"
tc := clock.NewFakeClock(time.Now())
step := time.Second
maxDuration := 5 * step
b := NewFakeBackOff(step, maxDuration, tc)
for i := 0; i <= int(maxDuration/step); i++ {
tc.Step(step)
b.Next(id, tc.Now())
}
lastUpdate := tc.Now()
tc.Step(maxDuration + step)
b.GC()
_, found := b.perItemBackoff[id]
if !found {
t.Errorf("expected GC to skip entry, elapsed time=%s maxDuration=%s", tc.Now().Sub(lastUpdate), maxDuration)
}
tc.Step(maxDuration + step)
b.GC()
r, found := b.perItemBackoff[id]
if found {
t.Errorf("expected GC of entry after %s got entry %v", tc.Now().Sub(lastUpdate), r)
}
}
func TestIsInBackOffSinceUpdate(t *testing.T) {
id := "_idIsInBackOffSinceUpdate"
tc := clock.NewFakeClock(time.Now())
step := time.Second
maxDuration := 10 * step
b := NewFakeBackOff(step, maxDuration, tc)
startTime := tc.Now()
cases := []struct {
tick time.Duration
inBackOff bool
value int
}{
{tick: 0, inBackOff: false, value: 0},
{tick: 1, inBackOff: false, value: 1},
{tick: 2, inBackOff: true, value: 2},
{tick: 3, inBackOff: false, value: 2},
{tick: 4, inBackOff: true, value: 4},
{tick: 5, inBackOff: true, value: 4},
{tick: 6, inBackOff: true, value: 4},
{tick: 7, inBackOff: false, value: 4},
{tick: 8, inBackOff: true, value: 8},
{tick: 9, inBackOff: true, value: 8},
{tick: 10, inBackOff: true, value: 8},
{tick: 11, inBackOff: true, value: 8},
{tick: 12, inBackOff: true, value: 8},
{tick: 13, inBackOff: true, value: 8},
{tick: 14, inBackOff: true, value: 8},
{tick: 15, inBackOff: false, value: 8},
{tick: 16, inBackOff: true, value: 10},
{tick: 17, inBackOff: true, value: 10},
{tick: 18, inBackOff: true, value: 10},
{tick: 19, inBackOff: true, value: 10},
{tick: 20, inBackOff: true, value: 10},
{tick: 21, inBackOff: true, value: 10},
{tick: 22, inBackOff: true, value: 10},
{tick: 23, inBackOff: true, value: 10},
{tick: 24, inBackOff: true, value: 10},
{tick: 25, inBackOff: false, value: 10},
{tick: 26, inBackOff: true, value: 10},
{tick: 27, inBackOff: true, value: 10},
{tick: 28, inBackOff: true, value: 10},
{tick: 29, inBackOff: true, value: 10},
{tick: 30, inBackOff: true, value: 10},
{tick: 31, inBackOff: true, value: 10},
{tick: 32, inBackOff: true, value: 10},
{tick: 33, inBackOff: true, value: 10},
{tick: 34, inBackOff: true, value: 10},
{tick: 35, inBackOff: false, value: 10},
{tick: 56, inBackOff: false, value: 0},
{tick: 57, inBackOff: false, value: 1},
}
for _, c := range cases {
tc.SetTime(startTime.Add(c.tick * step))
if c.inBackOff != b.IsInBackOffSinceUpdate(id, tc.Now()) {
t.Errorf("expected IsInBackOffSinceUpdate %v got %v at tick %s", c.inBackOff, b.IsInBackOffSinceUpdate(id, tc.Now()), c.tick*step)
}
if c.inBackOff && (time.Duration(c.value)*step != b.Get(id)) {
t.Errorf("expected backoff value=%s got %s at tick %s", time.Duration(c.value)*step, b.Get(id), c.tick*step)
}
if !c.inBackOff {
b.Next(id, tc.Now())
}
}
}

View File

@ -1,148 +0,0 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package flowcontrol
import (
"sync"
"github.com/juju/ratelimit"
)
type RateLimiter interface {
// TryAccept returns true if a token is taken immediately. Otherwise,
// it returns false.
TryAccept() bool
// Accept returns once a token becomes available.
Accept()
// Stop stops the rate limiter, subsequent calls to CanAccept will return false
Stop()
// Saturation returns a percentage number which describes how saturated
// this rate limiter is.
// Usually we use token bucket rate limiter. In that case,
// 1.0 means no tokens are available; 0.0 means we have a full bucket of tokens to use.
Saturation() float64
// QPS returns QPS of this rate limiter
QPS() float32
}
type tokenBucketRateLimiter struct {
limiter *ratelimit.Bucket
qps float32
}
// NewTokenBucketRateLimiter creates a rate limiter which implements a token bucket approach.
// The rate limiter allows bursts of up to 'burst' to exceed the QPS, while still maintaining a
// smoothed qps rate of 'qps'.
// The bucket is initially filled with 'burst' tokens, and refills at a rate of 'qps'.
// The maximum number of tokens in the bucket is capped at 'burst'.
func NewTokenBucketRateLimiter(qps float32, burst int) RateLimiter {
limiter := ratelimit.NewBucketWithRate(float64(qps), int64(burst))
return newTokenBucketRateLimiter(limiter, qps)
}
// An injectable, mockable clock interface.
type Clock interface {
ratelimit.Clock
}
// NewTokenBucketRateLimiterWithClock is identical to NewTokenBucketRateLimiter
// but allows an injectable clock, for testing.
func NewTokenBucketRateLimiterWithClock(qps float32, burst int, clock Clock) RateLimiter {
limiter := ratelimit.NewBucketWithRateAndClock(float64(qps), int64(burst), clock)
return newTokenBucketRateLimiter(limiter, qps)
}
func newTokenBucketRateLimiter(limiter *ratelimit.Bucket, qps float32) RateLimiter {
return &tokenBucketRateLimiter{
limiter: limiter,
qps: qps,
}
}
func (t *tokenBucketRateLimiter) TryAccept() bool {
return t.limiter.TakeAvailable(1) == 1
}
func (t *tokenBucketRateLimiter) Saturation() float64 {
capacity := t.limiter.Capacity()
avail := t.limiter.Available()
return float64(capacity-avail) / float64(capacity)
}
// Accept will block until a token becomes available
func (t *tokenBucketRateLimiter) Accept() {
t.limiter.Wait(1)
}
func (t *tokenBucketRateLimiter) Stop() {
}
func (t *tokenBucketRateLimiter) QPS() float32 {
return t.qps
}
type fakeAlwaysRateLimiter struct{}
func NewFakeAlwaysRateLimiter() RateLimiter {
return &fakeAlwaysRateLimiter{}
}
func (t *fakeAlwaysRateLimiter) TryAccept() bool {
return true
}
func (t *fakeAlwaysRateLimiter) Saturation() float64 {
return 0
}
func (t *fakeAlwaysRateLimiter) Stop() {}
func (t *fakeAlwaysRateLimiter) Accept() {}
func (t *fakeAlwaysRateLimiter) QPS() float32 {
return 1
}
type fakeNeverRateLimiter struct {
wg sync.WaitGroup
}
func NewFakeNeverRateLimiter() RateLimiter {
rl := fakeNeverRateLimiter{}
rl.wg.Add(1)
return &rl
}
func (t *fakeNeverRateLimiter) TryAccept() bool {
return false
}
func (t *fakeNeverRateLimiter) Saturation() float64 {
return 1
}
func (t *fakeNeverRateLimiter) Stop() {
t.wg.Done()
}
func (t *fakeNeverRateLimiter) Accept() {
t.wg.Wait()
}
func (t *fakeNeverRateLimiter) QPS() float32 {
return 1
}

View File

@ -1,177 +0,0 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package flowcontrol
import (
"math"
"sync"
"testing"
"time"
)
func TestMultithreadedThrottling(t *testing.T) {
// Bucket with 100QPS and no burst
r := NewTokenBucketRateLimiter(100, 1)
// channel to collect 100 tokens
taken := make(chan bool, 100)
// Set up goroutines to hammer the throttler
startCh := make(chan bool)
endCh := make(chan bool)
for i := 0; i < 10; i++ {
go func() {
// wait for the starting signal
<-startCh
for {
// get a token
r.Accept()
select {
// try to add it to the taken channel
case taken <- true:
continue
// if taken is full, notify and return
default:
endCh <- true
return
}
}
}()
}
// record wall time
startTime := time.Now()
// take the initial capacity so all tokens are the result of refill
r.Accept()
// start the thundering herd
close(startCh)
// wait for the first signal that we collected 100 tokens
<-endCh
// record wall time
endTime := time.Now()
// tolerate a 1% clock change because these things happen
if duration := endTime.Sub(startTime); duration < (time.Second * 99 / 100) {
// We shouldn't be able to get 100 tokens out of the bucket in less than 1 second of wall clock time, no matter what
t.Errorf("Expected it to take at least 1 second to get 100 tokens, took %v", duration)
} else {
t.Logf("Took %v to get 100 tokens", duration)
}
}
func TestBasicThrottle(t *testing.T) {
r := NewTokenBucketRateLimiter(1, 3)
for i := 0; i < 3; i++ {
if !r.TryAccept() {
t.Error("unexpected false accept")
}
}
if r.TryAccept() {
t.Error("unexpected true accept")
}
}
func TestIncrementThrottle(t *testing.T) {
r := NewTokenBucketRateLimiter(1, 1)
if !r.TryAccept() {
t.Error("unexpected false accept")
}
if r.TryAccept() {
t.Error("unexpected true accept")
}
// Allow to refill
time.Sleep(2 * time.Second)
if !r.TryAccept() {
t.Error("unexpected false accept")
}
}
func TestThrottle(t *testing.T) {
r := NewTokenBucketRateLimiter(10, 5)
// Should consume 5 tokens immediately, then
// the remaining 11 should take at least 1 second (0.1s each)
expectedFinish := time.Now().Add(time.Second * 1)
for i := 0; i < 16; i++ {
r.Accept()
}
if time.Now().Before(expectedFinish) {
t.Error("rate limit was not respected, finished too early")
}
}
func TestRateLimiterSaturation(t *testing.T) {
const e = 0.000001
tests := []struct {
capacity int
take int
expectedSaturation float64
}{
{1, 1, 1},
{10, 3, 0.3},
}
for i, tt := range tests {
rl := NewTokenBucketRateLimiter(1, tt.capacity)
for i := 0; i < tt.take; i++ {
rl.Accept()
}
if math.Abs(rl.Saturation()-tt.expectedSaturation) > e {
t.Fatalf("#%d: Saturation rate difference isn't within tolerable range\n want=%f, get=%f",
i, tt.expectedSaturation, rl.Saturation())
}
}
}
func TestAlwaysFake(t *testing.T) {
rl := NewFakeAlwaysRateLimiter()
if !rl.TryAccept() {
t.Error("TryAccept in AlwaysFake should return true.")
}
// If this will block the test will timeout
rl.Accept()
}
func TestNeverFake(t *testing.T) {
rl := NewFakeNeverRateLimiter()
if rl.TryAccept() {
t.Error("TryAccept in NeverFake should return false.")
}
finished := false
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
rl.Accept()
finished = true
wg.Done()
}()
// Wait some time to make sure it never finished.
time.Sleep(time.Second)
if finished {
t.Error("Accept should block forever in NeverFake.")
}
rl.Stop()
wg.Wait()
if !finished {
t.Error("Stop should make Accept unblock in NeverFake.")
}
}

View File

@ -1,25 +0,0 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
)
go_library(
name = "go_default_library",
srcs = ["homedir.go"],
importpath = "k8s.io/client-go/util/homedir",
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
)

View File

@ -1,47 +0,0 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package homedir
import (
"os"
"runtime"
)
// HomeDir returns the home directory for the current user
func HomeDir() string {
if runtime.GOOS == "windows" {
// First prefer the HOME environmental variable
if home := os.Getenv("HOME"); len(home) > 0 {
if _, err := os.Stat(home); err == nil {
return home
}
}
if homeDrive, homePath := os.Getenv("HOMEDRIVE"), os.Getenv("HOMEPATH"); len(homeDrive) > 0 && len(homePath) > 0 {
homeDir := homeDrive + homePath
if _, err := os.Stat(homeDir); err == nil {
return homeDir
}
}
if userProfile := os.Getenv("USERPROFILE"); len(userProfile) > 0 {
if _, err := os.Stat(userProfile); err == nil {
return userProfile
}
}
}
return os.Getenv("HOME")
}

View File

@ -1,33 +0,0 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
"go_test",
)
go_test(
name = "go_default_test",
srcs = ["integer_test.go"],
importpath = "k8s.io/client-go/util/integer",
library = ":go_default_library",
)
go_library(
name = "go_default_library",
srcs = ["integer.go"],
importpath = "k8s.io/client-go/util/integer",
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
)

View File

@ -1,67 +0,0 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package integer
func IntMax(a, b int) int {
if b > a {
return b
}
return a
}
func IntMin(a, b int) int {
if b < a {
return b
}
return a
}
func Int32Max(a, b int32) int32 {
if b > a {
return b
}
return a
}
func Int32Min(a, b int32) int32 {
if b < a {
return b
}
return a
}
func Int64Max(a, b int64) int64 {
if b > a {
return b
}
return a
}
func Int64Min(a, b int64) int64 {
if b < a {
return b
}
return a
}
// RoundToInt32 rounds floats into integer numbers.
func RoundToInt32(a float64) int32 {
if a < 0 {
return int32(a - 0.5)
}
return int32(a + 0.5)
}

View File

@ -1,244 +0,0 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package integer
import "testing"
func TestIntMax(t *testing.T) {
tests := []struct {
nums []int
expectedMax int
}{
{
nums: []int{-1, 0},
expectedMax: 0,
},
{
nums: []int{-1, -2},
expectedMax: -1,
},
{
nums: []int{0, 1},
expectedMax: 1,
},
{
nums: []int{1, 2},
expectedMax: 2,
},
}
for i, test := range tests {
t.Logf("executing scenario %d", i)
if max := IntMax(test.nums[0], test.nums[1]); max != test.expectedMax {
t.Errorf("expected %v, got %v", test.expectedMax, max)
}
}
}
func TestIntMin(t *testing.T) {
tests := []struct {
nums []int
expectedMin int
}{
{
nums: []int{-1, 0},
expectedMin: -1,
},
{
nums: []int{-1, -2},
expectedMin: -2,
},
{
nums: []int{0, 1},
expectedMin: 0,
},
{
nums: []int{1, 2},
expectedMin: 1,
},
}
for i, test := range tests {
t.Logf("executing scenario %d", i)
if min := IntMin(test.nums[0], test.nums[1]); min != test.expectedMin {
t.Errorf("expected %v, got %v", test.expectedMin, min)
}
}
}
func TestInt32Max(t *testing.T) {
tests := []struct {
nums []int32
expectedMax int32
}{
{
nums: []int32{-1, 0},
expectedMax: 0,
},
{
nums: []int32{-1, -2},
expectedMax: -1,
},
{
nums: []int32{0, 1},
expectedMax: 1,
},
{
nums: []int32{1, 2},
expectedMax: 2,
},
}
for i, test := range tests {
t.Logf("executing scenario %d", i)
if max := Int32Max(test.nums[0], test.nums[1]); max != test.expectedMax {
t.Errorf("expected %v, got %v", test.expectedMax, max)
}
}
}
func TestInt32Min(t *testing.T) {
tests := []struct {
nums []int32
expectedMin int32
}{
{
nums: []int32{-1, 0},
expectedMin: -1,
},
{
nums: []int32{-1, -2},
expectedMin: -2,
},
{
nums: []int32{0, 1},
expectedMin: 0,
},
{
nums: []int32{1, 2},
expectedMin: 1,
},
}
for i, test := range tests {
t.Logf("executing scenario %d", i)
if min := Int32Min(test.nums[0], test.nums[1]); min != test.expectedMin {
t.Errorf("expected %v, got %v", test.expectedMin, min)
}
}
}
func TestInt64Max(t *testing.T) {
tests := []struct {
nums []int64
expectedMax int64
}{
{
nums: []int64{-1, 0},
expectedMax: 0,
},
{
nums: []int64{-1, -2},
expectedMax: -1,
},
{
nums: []int64{0, 1},
expectedMax: 1,
},
{
nums: []int64{1, 2},
expectedMax: 2,
},
}
for i, test := range tests {
t.Logf("executing scenario %d", i)
if max := Int64Max(test.nums[0], test.nums[1]); max != test.expectedMax {
t.Errorf("expected %v, got %v", test.expectedMax, max)
}
}
}
func TestInt64Min(t *testing.T) {
tests := []struct {
nums []int64
expectedMin int64
}{
{
nums: []int64{-1, 0},
expectedMin: -1,
},
{
nums: []int64{-1, -2},
expectedMin: -2,
},
{
nums: []int64{0, 1},
expectedMin: 0,
},
{
nums: []int64{1, 2},
expectedMin: 1,
},
}
for i, test := range tests {
t.Logf("executing scenario %d", i)
if min := Int64Min(test.nums[0], test.nums[1]); min != test.expectedMin {
t.Errorf("expected %v, got %v", test.expectedMin, min)
}
}
}
func TestRoundToInt32(t *testing.T) {
tests := []struct {
num float64
exp int32
}{
{
num: 5.5,
exp: 6,
},
{
num: -3.7,
exp: -4,
},
{
num: 3.49,
exp: 3,
},
{
num: -7.9,
exp: -8,
},
{
num: -4.499999,
exp: -4,
},
{
num: 0,
exp: 0,
},
}
for i, test := range tests {
t.Logf("executing scenario %d", i)
if got := RoundToInt32(test.num); got != test.exp {
t.Errorf("expected %d, got %d", test.exp, got)
}
}
}

View File

@ -1,42 +0,0 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
"go_test",
)
go_test(
name = "go_default_test",
srcs = [
"jsonpath_test.go",
"parser_test.go",
],
importpath = "k8s.io/client-go/util/jsonpath",
library = ":go_default_library",
)
go_library(
name = "go_default_library",
srcs = [
"doc.go",
"jsonpath.go",
"node.go",
"parser.go",
],
importpath = "k8s.io/client-go/util/jsonpath",
deps = ["//vendor/k8s.io/client-go/third_party/forked/golang/template:go_default_library"],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
)

View File

@ -1,20 +0,0 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// package jsonpath is a template engine using jsonpath syntax,
// which can be seen at http://goessner.net/articles/JsonPath/.
// In addition, it has {range} {end} function to iterate list and slice.
package jsonpath // import "k8s.io/client-go/util/jsonpath"

View File

@ -1,517 +0,0 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package jsonpath
import (
"bytes"
"fmt"
"io"
"reflect"
"strings"
"k8s.io/client-go/third_party/forked/golang/template"
)
type JSONPath struct {
name string
parser *Parser
stack [][]reflect.Value // push and pop values in different scopes
cur []reflect.Value // current scope values
beginRange int
inRange int
endRange int
allowMissingKeys bool
}
// New creates a new JSONPath with the given name.
func New(name string) *JSONPath {
return &JSONPath{
name: name,
beginRange: 0,
inRange: 0,
endRange: 0,
}
}
// AllowMissingKeys allows a caller to specify whether they want an error if a field or map key
// cannot be located, or simply an empty result. The receiver is returned for chaining.
func (j *JSONPath) AllowMissingKeys(allow bool) *JSONPath {
j.allowMissingKeys = allow
return j
}
// Parse parses the given template and returns an error.
func (j *JSONPath) Parse(text string) error {
var err error
j.parser, err = Parse(j.name, text)
return err
}
// Execute bounds data into template and writes the result.
func (j *JSONPath) Execute(wr io.Writer, data interface{}) error {
fullResults, err := j.FindResults(data)
if err != nil {
return err
}
for ix := range fullResults {
if err := j.PrintResults(wr, fullResults[ix]); err != nil {
return err
}
}
return nil
}
func (j *JSONPath) FindResults(data interface{}) ([][]reflect.Value, error) {
if j.parser == nil {
return nil, fmt.Errorf("%s is an incomplete jsonpath template", j.name)
}
j.cur = []reflect.Value{reflect.ValueOf(data)}
nodes := j.parser.Root.Nodes
fullResult := [][]reflect.Value{}
for i := 0; i < len(nodes); i++ {
node := nodes[i]
results, err := j.walk(j.cur, node)
if err != nil {
return nil, err
}
// encounter an end node, break the current block
if j.endRange > 0 && j.endRange <= j.inRange {
j.endRange -= 1
break
}
// encounter a range node, start a range loop
if j.beginRange > 0 {
j.beginRange -= 1
j.inRange += 1
for k, value := range results {
j.parser.Root.Nodes = nodes[i+1:]
if k == len(results)-1 {
j.inRange -= 1
}
nextResults, err := j.FindResults(value.Interface())
if err != nil {
return nil, err
}
fullResult = append(fullResult, nextResults...)
}
break
}
fullResult = append(fullResult, results)
}
return fullResult, nil
}
// PrintResults writes the results into writer
func (j *JSONPath) PrintResults(wr io.Writer, results []reflect.Value) error {
for i, r := range results {
text, err := j.evalToText(r)
if err != nil {
return err
}
if i != len(results)-1 {
text = append(text, ' ')
}
if _, err = wr.Write(text); err != nil {
return err
}
}
return nil
}
// walk visits tree rooted at the given node in DFS order
func (j *JSONPath) walk(value []reflect.Value, node Node) ([]reflect.Value, error) {
switch node := node.(type) {
case *ListNode:
return j.evalList(value, node)
case *TextNode:
return []reflect.Value{reflect.ValueOf(node.Text)}, nil
case *FieldNode:
return j.evalField(value, node)
case *ArrayNode:
return j.evalArray(value, node)
case *FilterNode:
return j.evalFilter(value, node)
case *IntNode:
return j.evalInt(value, node)
case *BoolNode:
return j.evalBool(value, node)
case *FloatNode:
return j.evalFloat(value, node)
case *WildcardNode:
return j.evalWildcard(value, node)
case *RecursiveNode:
return j.evalRecursive(value, node)
case *UnionNode:
return j.evalUnion(value, node)
case *IdentifierNode:
return j.evalIdentifier(value, node)
default:
return value, fmt.Errorf("unexpected Node %v", node)
}
}
// evalInt evaluates IntNode
func (j *JSONPath) evalInt(input []reflect.Value, node *IntNode) ([]reflect.Value, error) {
result := make([]reflect.Value, len(input))
for i := range input {
result[i] = reflect.ValueOf(node.Value)
}
return result, nil
}
// evalFloat evaluates FloatNode
func (j *JSONPath) evalFloat(input []reflect.Value, node *FloatNode) ([]reflect.Value, error) {
result := make([]reflect.Value, len(input))
for i := range input {
result[i] = reflect.ValueOf(node.Value)
}
return result, nil
}
// evalBool evaluates BoolNode
func (j *JSONPath) evalBool(input []reflect.Value, node *BoolNode) ([]reflect.Value, error) {
result := make([]reflect.Value, len(input))
for i := range input {
result[i] = reflect.ValueOf(node.Value)
}
return result, nil
}
// evalList evaluates ListNode
func (j *JSONPath) evalList(value []reflect.Value, node *ListNode) ([]reflect.Value, error) {
var err error
curValue := value
for _, node := range node.Nodes {
curValue, err = j.walk(curValue, node)
if err != nil {
return curValue, err
}
}
return curValue, nil
}
// evalIdentifier evaluates IdentifierNode
func (j *JSONPath) evalIdentifier(input []reflect.Value, node *IdentifierNode) ([]reflect.Value, error) {
results := []reflect.Value{}
switch node.Name {
case "range":
j.stack = append(j.stack, j.cur)
j.beginRange += 1
results = input
case "end":
if j.endRange < j.inRange { // inside a loop, break the current block
j.endRange += 1
break
}
// the loop is about to end, pop value and continue the following execution
if len(j.stack) > 0 {
j.cur, j.stack = j.stack[len(j.stack)-1], j.stack[:len(j.stack)-1]
} else {
return results, fmt.Errorf("not in range, nothing to end")
}
default:
return input, fmt.Errorf("unrecognized identifier %v", node.Name)
}
return results, nil
}
// evalArray evaluates ArrayNode
func (j *JSONPath) evalArray(input []reflect.Value, node *ArrayNode) ([]reflect.Value, error) {
result := []reflect.Value{}
for _, value := range input {
value, isNil := template.Indirect(value)
if isNil {
continue
}
if value.Kind() != reflect.Array && value.Kind() != reflect.Slice {
return input, fmt.Errorf("%v is not array or slice", value.Type())
}
params := node.Params
if !params[0].Known {
params[0].Value = 0
}
if params[0].Value < 0 {
params[0].Value += value.Len()
}
if !params[1].Known {
params[1].Value = value.Len()
}
if params[1].Value < 0 {
params[1].Value += value.Len()
}
sliceLength := value.Len()
if params[1].Value != params[0].Value { // if you're requesting zero elements, allow it through.
if params[0].Value >= sliceLength || params[0].Value < 0 {
return input, fmt.Errorf("array index out of bounds: index %d, length %d", params[0].Value, sliceLength)
}
if params[1].Value > sliceLength || params[1].Value < 0 {
return input, fmt.Errorf("array index out of bounds: index %d, length %d", params[1].Value-1, sliceLength)
}
}
if !params[2].Known {
value = value.Slice(params[0].Value, params[1].Value)
} else {
value = value.Slice3(params[0].Value, params[1].Value, params[2].Value)
}
for i := 0; i < value.Len(); i++ {
result = append(result, value.Index(i))
}
}
return result, nil
}
// evalUnion evaluates UnionNode
func (j *JSONPath) evalUnion(input []reflect.Value, node *UnionNode) ([]reflect.Value, error) {
result := []reflect.Value{}
for _, listNode := range node.Nodes {
temp, err := j.evalList(input, listNode)
if err != nil {
return input, err
}
result = append(result, temp...)
}
return result, nil
}
func (j *JSONPath) findFieldInValue(value *reflect.Value, node *FieldNode) (reflect.Value, error) {
t := value.Type()
var inlineValue *reflect.Value
for ix := 0; ix < t.NumField(); ix++ {
f := t.Field(ix)
jsonTag := f.Tag.Get("json")
parts := strings.Split(jsonTag, ",")
if len(parts) == 0 {
continue
}
if parts[0] == node.Value {
return value.Field(ix), nil
}
if len(parts[0]) == 0 {
val := value.Field(ix)
inlineValue = &val
}
}
if inlineValue != nil {
if inlineValue.Kind() == reflect.Struct {
// handle 'inline'
match, err := j.findFieldInValue(inlineValue, node)
if err != nil {
return reflect.Value{}, err
}
if match.IsValid() {
return match, nil
}
}
}
return value.FieldByName(node.Value), nil
}
// evalField evaluates field of struct or key of map.
func (j *JSONPath) evalField(input []reflect.Value, node *FieldNode) ([]reflect.Value, error) {
results := []reflect.Value{}
// If there's no input, there's no output
if len(input) == 0 {
return results, nil
}
for _, value := range input {
var result reflect.Value
value, isNil := template.Indirect(value)
if isNil {
continue
}
if value.Kind() == reflect.Struct {
var err error
if result, err = j.findFieldInValue(&value, node); err != nil {
return nil, err
}
} else if value.Kind() == reflect.Map {
mapKeyType := value.Type().Key()
nodeValue := reflect.ValueOf(node.Value)
// node value type must be convertible to map key type
if !nodeValue.Type().ConvertibleTo(mapKeyType) {
return results, fmt.Errorf("%s is not convertible to %s", nodeValue, mapKeyType)
}
result = value.MapIndex(nodeValue.Convert(mapKeyType))
}
if result.IsValid() {
results = append(results, result)
}
}
if len(results) == 0 {
if j.allowMissingKeys {
return results, nil
}
return results, fmt.Errorf("%s is not found", node.Value)
}
return results, nil
}
// evalWildcard extracts all contents of the given value
func (j *JSONPath) evalWildcard(input []reflect.Value, node *WildcardNode) ([]reflect.Value, error) {
results := []reflect.Value{}
for _, value := range input {
value, isNil := template.Indirect(value)
if isNil {
continue
}
kind := value.Kind()
if kind == reflect.Struct {
for i := 0; i < value.NumField(); i++ {
results = append(results, value.Field(i))
}
} else if kind == reflect.Map {
for _, key := range value.MapKeys() {
results = append(results, value.MapIndex(key))
}
} else if kind == reflect.Array || kind == reflect.Slice || kind == reflect.String {
for i := 0; i < value.Len(); i++ {
results = append(results, value.Index(i))
}
}
}
return results, nil
}
// evalRecursive visits the given value recursively and pushes all of them to result
func (j *JSONPath) evalRecursive(input []reflect.Value, node *RecursiveNode) ([]reflect.Value, error) {
result := []reflect.Value{}
for _, value := range input {
results := []reflect.Value{}
value, isNil := template.Indirect(value)
if isNil {
continue
}
kind := value.Kind()
if kind == reflect.Struct {
for i := 0; i < value.NumField(); i++ {
results = append(results, value.Field(i))
}
} else if kind == reflect.Map {
for _, key := range value.MapKeys() {
results = append(results, value.MapIndex(key))
}
} else if kind == reflect.Array || kind == reflect.Slice || kind == reflect.String {
for i := 0; i < value.Len(); i++ {
results = append(results, value.Index(i))
}
}
if len(results) != 0 {
result = append(result, value)
output, err := j.evalRecursive(results, node)
if err != nil {
return result, err
}
result = append(result, output...)
}
}
return result, nil
}
// evalFilter filters array according to FilterNode
func (j *JSONPath) evalFilter(input []reflect.Value, node *FilterNode) ([]reflect.Value, error) {
results := []reflect.Value{}
for _, value := range input {
value, _ = template.Indirect(value)
if value.Kind() != reflect.Array && value.Kind() != reflect.Slice {
return input, fmt.Errorf("%v is not array or slice and cannot be filtered", value)
}
for i := 0; i < value.Len(); i++ {
temp := []reflect.Value{value.Index(i)}
lefts, err := j.evalList(temp, node.Left)
//case exists
if node.Operator == "exists" {
if len(lefts) > 0 {
results = append(results, value.Index(i))
}
continue
}
if err != nil {
return input, err
}
var left, right interface{}
switch {
case len(lefts) == 0:
continue
case len(lefts) > 1:
return input, fmt.Errorf("can only compare one element at a time")
}
left = lefts[0].Interface()
rights, err := j.evalList(temp, node.Right)
if err != nil {
return input, err
}
switch {
case len(rights) == 0:
continue
case len(rights) > 1:
return input, fmt.Errorf("can only compare one element at a time")
}
right = rights[0].Interface()
pass := false
switch node.Operator {
case "<":
pass, err = template.Less(left, right)
case ">":
pass, err = template.Greater(left, right)
case "==":
pass, err = template.Equal(left, right)
case "!=":
pass, err = template.NotEqual(left, right)
case "<=":
pass, err = template.LessEqual(left, right)
case ">=":
pass, err = template.GreaterEqual(left, right)
default:
return results, fmt.Errorf("unrecognized filter operator %s", node.Operator)
}
if err != nil {
return results, err
}
if pass {
results = append(results, value.Index(i))
}
}
}
return results, nil
}
// evalToText translates reflect value to corresponding text
func (j *JSONPath) evalToText(v reflect.Value) ([]byte, error) {
iface, ok := template.PrintableValue(v)
if !ok {
return nil, fmt.Errorf("can't print type %s", v.Type())
}
var buffer bytes.Buffer
fmt.Fprint(&buffer, iface)
return buffer.Bytes(), nil
}

View File

@ -1,371 +0,0 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package jsonpath
import (
"bytes"
"encoding/json"
"fmt"
"reflect"
"sort"
"strings"
"testing"
)
type jsonpathTest struct {
name string
template string
input interface{}
expect string
expectError bool
}
func testJSONPath(tests []jsonpathTest, allowMissingKeys bool, t *testing.T) {
for _, test := range tests {
j := New(test.name)
j.AllowMissingKeys(allowMissingKeys)
err := j.Parse(test.template)
if err != nil {
t.Errorf("in %s, parse %s error %v", test.name, test.template, err)
}
buf := new(bytes.Buffer)
err = j.Execute(buf, test.input)
if test.expectError {
if test.expectError && err == nil {
t.Errorf("in %s, expected execute error", test.name)
}
continue
} else if err != nil {
t.Errorf("in %s, execute error %v", test.name, err)
}
out := buf.String()
if out != test.expect {
t.Errorf(`in %s, expect to get "%s", got "%s"`, test.name, test.expect, out)
}
}
}
// testJSONPathSortOutput test cases related to map, the results may print in random order
func testJSONPathSortOutput(tests []jsonpathTest, t *testing.T) {
for _, test := range tests {
j := New(test.name)
err := j.Parse(test.template)
if err != nil {
t.Errorf("in %s, parse %s error %v", test.name, test.template, err)
}
buf := new(bytes.Buffer)
err = j.Execute(buf, test.input)
if err != nil {
t.Errorf("in %s, execute error %v", test.name, err)
}
out := buf.String()
//since map is visited in random order, we need to sort the results.
sortedOut := strings.Fields(out)
sort.Strings(sortedOut)
sortedExpect := strings.Fields(test.expect)
sort.Strings(sortedExpect)
if !reflect.DeepEqual(sortedOut, sortedExpect) {
t.Errorf(`in %s, expect to get "%s", got "%s"`, test.name, test.expect, out)
}
}
}
func testFailJSONPath(tests []jsonpathTest, t *testing.T) {
for _, test := range tests {
j := New(test.name)
err := j.Parse(test.template)
if err != nil {
t.Errorf("in %s, parse %s error %v", test.name, test.template, err)
}
buf := new(bytes.Buffer)
err = j.Execute(buf, test.input)
var out string
if err == nil {
out = "nil"
} else {
out = err.Error()
}
if out != test.expect {
t.Errorf("in %s, expect to get error %q, got %q", test.name, test.expect, out)
}
}
}
type book struct {
Category string
Author string
Title string
Price float32
}
func (b book) String() string {
return fmt.Sprintf("{Category: %s, Author: %s, Title: %s, Price: %v}", b.Category, b.Author, b.Title, b.Price)
}
type bicycle struct {
Color string
Price float32
IsNew bool
}
type empName string
type job string
type store struct {
Book []book
Bicycle []bicycle
Name string
Labels map[string]int
Employees map[empName]job
}
func TestStructInput(t *testing.T) {
storeData := store{
Name: "jsonpath",
Book: []book{
{"reference", "Nigel Rees", "Sayings of the Centurey", 8.95},
{"fiction", "Evelyn Waugh", "Sword of Honour", 12.99},
{"fiction", "Herman Melville", "Moby Dick", 8.99},
},
Bicycle: []bicycle{
{"red", 19.95, true},
{"green", 20.01, false},
},
Labels: map[string]int{
"engieer": 10,
"web/html": 15,
"k8s-app": 20,
},
Employees: map[empName]job{
"jason": "manager",
"dan": "clerk",
},
}
storeTests := []jsonpathTest{
{"plain", "hello jsonpath", nil, "hello jsonpath", false},
{"recursive", "{..}", []int{1, 2, 3}, "[1 2 3]", false},
{"filter", "{[?(@<5)]}", []int{2, 6, 3, 7}, "2 3", false},
{"quote", `{"{"}`, nil, "{", false},
{"union", "{[1,3,4]}", []int{0, 1, 2, 3, 4}, "1 3 4", false},
{"array", "{[0:2]}", []string{"Monday", "Tudesday"}, "Monday Tudesday", false},
{"variable", "hello {.Name}", storeData, "hello jsonpath", false},
{"dict/", "{$.Labels.web/html}", storeData, "15", false},
{"dict/", "{$.Employees.jason}", storeData, "manager", false},
{"dict/", "{$.Employees.dan}", storeData, "clerk", false},
{"dict-", "{.Labels.k8s-app}", storeData, "20", false},
{"nest", "{.Bicycle[*].Color}", storeData, "red green", false},
{"allarray", "{.Book[*].Author}", storeData, "Nigel Rees Evelyn Waugh Herman Melville", false},
{"allfileds", "{.Bicycle.*}", storeData, "{red 19.95 true} {green 20.01 false}", false},
{"recurfileds", "{..Price}", storeData, "8.95 12.99 8.99 19.95 20.01", false},
{"lastarray", "{.Book[-1:]}", storeData,
"{Category: fiction, Author: Herman Melville, Title: Moby Dick, Price: 8.99}", false},
{"recurarray", "{..Book[2]}", storeData,
"{Category: fiction, Author: Herman Melville, Title: Moby Dick, Price: 8.99}", false},
{"bool", "{.Bicycle[?(@.IsNew==true)]}", storeData, "{red 19.95 true}", false},
}
testJSONPath(storeTests, false, t)
missingKeyTests := []jsonpathTest{
{"nonexistent field", "{.hello}", storeData, "", false},
}
testJSONPath(missingKeyTests, true, t)
failStoreTests := []jsonpathTest{
{"invalid identifier", "{hello}", storeData, "unrecognized identifier hello", false},
{"nonexistent field", "{.hello}", storeData, "hello is not found", false},
{"invalid array", "{.Labels[0]}", storeData, "map[string]int is not array or slice", false},
{"invalid filter operator", "{.Book[?(@.Price<>10)]}", storeData, "unrecognized filter operator <>", false},
{"redundent end", "{range .Labels.*}{@}{end}{end}", storeData, "not in range, nothing to end", false},
}
testFailJSONPath(failStoreTests, t)
}
func TestJSONInput(t *testing.T) {
var pointsJSON = []byte(`[
{"id": "i1", "x":4, "y":-5},
{"id": "i2", "x":-2, "y":-5, "z":1},
{"id": "i3", "x": 8, "y": 3 },
{"id": "i4", "x": -6, "y": -1 },
{"id": "i5", "x": 0, "y": 2, "z": 1 },
{"id": "i6", "x": 1, "y": 4 }
]`)
var pointsData interface{}
err := json.Unmarshal(pointsJSON, &pointsData)
if err != nil {
t.Error(err)
}
pointsTests := []jsonpathTest{
{"exists filter", "{[?(@.z)].id}", pointsData, "i2 i5", false},
{"bracket key", "{[0]['id']}", pointsData, "i1", false},
}
testJSONPath(pointsTests, false, t)
}
// TestKubernetes tests some use cases from kubernetes
func TestKubernetes(t *testing.T) {
var input = []byte(`{
"kind": "List",
"items":[
{
"kind":"None",
"metadata":{
"name":"127.0.0.1",
"labels":{
"kubernetes.io/hostname":"127.0.0.1"
}
},
"status":{
"capacity":{"cpu":"4"},
"ready": true,
"addresses":[{"type": "LegacyHostIP", "address":"127.0.0.1"}]
}
},
{
"kind":"None",
"metadata":{
"name":"127.0.0.2",
"labels":{
"kubernetes.io/hostname":"127.0.0.2"
}
},
"status":{
"capacity":{"cpu":"8"},
"ready": false,
"addresses":[
{"type": "LegacyHostIP", "address":"127.0.0.2"},
{"type": "another", "address":"127.0.0.3"}
]
}
}
],
"users":[
{
"name": "myself",
"user": {}
},
{
"name": "e2e",
"user": {"username": "admin", "password": "secret"}
}
]
}`)
var nodesData interface{}
err := json.Unmarshal(input, &nodesData)
if err != nil {
t.Error(err)
}
nodesTests := []jsonpathTest{
{"range item", `{range .items[*]}{.metadata.name}, {end}{.kind}`, nodesData, "127.0.0.1, 127.0.0.2, List", false},
{"range item with quote", `{range .items[*]}{.metadata.name}{"\t"}{end}`, nodesData, "127.0.0.1\t127.0.0.2\t", false},
{"range addresss", `{.items[*].status.addresses[*].address}`, nodesData,
"127.0.0.1 127.0.0.2 127.0.0.3", false},
{"double range", `{range .items[*]}{range .status.addresses[*]}{.address}, {end}{end}`, nodesData,
"127.0.0.1, 127.0.0.2, 127.0.0.3, ", false},
{"item name", `{.items[*].metadata.name}`, nodesData, "127.0.0.1 127.0.0.2", false},
{"union nodes capacity", `{.items[*]['metadata.name', 'status.capacity']}`, nodesData,
"127.0.0.1 127.0.0.2 map[cpu:4] map[cpu:8]", false},
{"range nodes capacity", `{range .items[*]}[{.metadata.name}, {.status.capacity}] {end}`, nodesData,
"[127.0.0.1, map[cpu:4]] [127.0.0.2, map[cpu:8]] ", false},
{"user password", `{.users[?(@.name=="e2e")].user.password}`, &nodesData, "secret", false},
{"hostname", `{.items[0].metadata.labels.kubernetes\.io/hostname}`, &nodesData, "127.0.0.1", false},
{"hostname filter", `{.items[?(@.metadata.labels.kubernetes\.io/hostname=="127.0.0.1")].kind}`, &nodesData, "None", false},
{"bool item", `{.items[?(@..ready==true)].metadata.name}`, &nodesData, "127.0.0.1", false},
}
testJSONPath(nodesTests, false, t)
randomPrintOrderTests := []jsonpathTest{
{"recursive name", "{..name}", nodesData, `127.0.0.1 127.0.0.2 myself e2e`, false},
}
testJSONPathSortOutput(randomPrintOrderTests, t)
}
func TestFilterPartialMatchesSometimesMissingAnnotations(t *testing.T) {
// for https://issues.k8s.io/45546
var input = []byte(`{
"kind": "List",
"items": [
{
"kind": "Pod",
"metadata": {
"name": "pod1",
"annotations": {
"color": "blue"
}
}
},
{
"kind": "Pod",
"metadata": {
"name": "pod2"
}
},
{
"kind": "Pod",
"metadata": {
"name": "pod3",
"annotations": {
"color": "green"
}
}
},
{
"kind": "Pod",
"metadata": {
"name": "pod4",
"annotations": {
"color": "blue"
}
}
}
]
}`)
var data interface{}
err := json.Unmarshal(input, &data)
if err != nil {
t.Fatal(err)
}
testJSONPath(
[]jsonpathTest{
{
"filter, should only match a subset, some items don't have annotations, tolerate missing items",
`{.items[?(@.metadata.annotations.color=="blue")].metadata.name}`,
data,
"pod1 pod4",
false, // expect no error
},
},
true, // allow missing keys
t,
)
testJSONPath(
[]jsonpathTest{
{
"filter, should only match a subset, some items don't have annotations, error on missing items",
`{.items[?(@.metadata.annotations.color=="blue")].metadata.name}`,
data,
"",
true, // expect an error
},
},
false, // don't allow missing keys
t,
)
}

View File

@ -1,255 +0,0 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package jsonpath
import "fmt"
// NodeType identifies the type of a parse tree node.
type NodeType int
// Type returns itself and provides an easy default implementation
func (t NodeType) Type() NodeType {
return t
}
func (t NodeType) String() string {
return NodeTypeName[t]
}
const (
NodeText NodeType = iota
NodeArray
NodeList
NodeField
NodeIdentifier
NodeFilter
NodeInt
NodeFloat
NodeWildcard
NodeRecursive
NodeUnion
NodeBool
)
var NodeTypeName = map[NodeType]string{
NodeText: "NodeText",
NodeArray: "NodeArray",
NodeList: "NodeList",
NodeField: "NodeField",
NodeIdentifier: "NodeIdentifier",
NodeFilter: "NodeFilter",
NodeInt: "NodeInt",
NodeFloat: "NodeFloat",
NodeWildcard: "NodeWildcard",
NodeRecursive: "NodeRecursive",
NodeUnion: "NodeUnion",
NodeBool: "NodeBool",
}
type Node interface {
Type() NodeType
String() string
}
// ListNode holds a sequence of nodes.
type ListNode struct {
NodeType
Nodes []Node // The element nodes in lexical order.
}
func newList() *ListNode {
return &ListNode{NodeType: NodeList}
}
func (l *ListNode) append(n Node) {
l.Nodes = append(l.Nodes, n)
}
func (l *ListNode) String() string {
return fmt.Sprintf("%s", l.Type())
}
// TextNode holds plain text.
type TextNode struct {
NodeType
Text string // The text; may span newlines.
}
func newText(text string) *TextNode {
return &TextNode{NodeType: NodeText, Text: text}
}
func (t *TextNode) String() string {
return fmt.Sprintf("%s: %s", t.Type(), t.Text)
}
// FieldNode holds field of struct
type FieldNode struct {
NodeType
Value string
}
func newField(value string) *FieldNode {
return &FieldNode{NodeType: NodeField, Value: value}
}
func (f *FieldNode) String() string {
return fmt.Sprintf("%s: %s", f.Type(), f.Value)
}
// IdentifierNode holds an identifier
type IdentifierNode struct {
NodeType
Name string
}
func newIdentifier(value string) *IdentifierNode {
return &IdentifierNode{
NodeType: NodeIdentifier,
Name: value,
}
}
func (f *IdentifierNode) String() string {
return fmt.Sprintf("%s: %s", f.Type(), f.Name)
}
// ParamsEntry holds param information for ArrayNode
type ParamsEntry struct {
Value int
Known bool // whether the value is known when parse it
}
// ArrayNode holds start, end, step information for array index selection
type ArrayNode struct {
NodeType
Params [3]ParamsEntry // start, end, step
}
func newArray(params [3]ParamsEntry) *ArrayNode {
return &ArrayNode{
NodeType: NodeArray,
Params: params,
}
}
func (a *ArrayNode) String() string {
return fmt.Sprintf("%s: %v", a.Type(), a.Params)
}
// FilterNode holds operand and operator information for filter
type FilterNode struct {
NodeType
Left *ListNode
Right *ListNode
Operator string
}
func newFilter(left, right *ListNode, operator string) *FilterNode {
return &FilterNode{
NodeType: NodeFilter,
Left: left,
Right: right,
Operator: operator,
}
}
func (f *FilterNode) String() string {
return fmt.Sprintf("%s: %s %s %s", f.Type(), f.Left, f.Operator, f.Right)
}
// IntNode holds integer value
type IntNode struct {
NodeType
Value int
}
func newInt(num int) *IntNode {
return &IntNode{NodeType: NodeInt, Value: num}
}
func (i *IntNode) String() string {
return fmt.Sprintf("%s: %d", i.Type(), i.Value)
}
// FloatNode holds float value
type FloatNode struct {
NodeType
Value float64
}
func newFloat(num float64) *FloatNode {
return &FloatNode{NodeType: NodeFloat, Value: num}
}
func (i *FloatNode) String() string {
return fmt.Sprintf("%s: %f", i.Type(), i.Value)
}
// WildcardNode means a wildcard
type WildcardNode struct {
NodeType
}
func newWildcard() *WildcardNode {
return &WildcardNode{NodeType: NodeWildcard}
}
func (i *WildcardNode) String() string {
return fmt.Sprintf("%s", i.Type())
}
// RecursiveNode means a recursive descent operator
type RecursiveNode struct {
NodeType
}
func newRecursive() *RecursiveNode {
return &RecursiveNode{NodeType: NodeRecursive}
}
func (r *RecursiveNode) String() string {
return fmt.Sprintf("%s", r.Type())
}
// UnionNode is union of ListNode
type UnionNode struct {
NodeType
Nodes []*ListNode
}
func newUnion(nodes []*ListNode) *UnionNode {
return &UnionNode{NodeType: NodeUnion, Nodes: nodes}
}
func (u *UnionNode) String() string {
return fmt.Sprintf("%s", u.Type())
}
// BoolNode holds bool value
type BoolNode struct {
NodeType
Value bool
}
func newBool(value bool) *BoolNode {
return &BoolNode{NodeType: NodeBool, Value: value}
}
func (b *BoolNode) String() string {
return fmt.Sprintf("%s: %t", b.Type(), b.Value)
}

View File

@ -1,525 +0,0 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package jsonpath
import (
"errors"
"fmt"
"regexp"
"strconv"
"strings"
"unicode"
"unicode/utf8"
)
const eof = -1
const (
leftDelim = "{"
rightDelim = "}"
)
type Parser struct {
Name string
Root *ListNode
input string
cur *ListNode
pos int
start int
width int
}
var (
ErrSyntax = errors.New("invalid syntax")
dictKeyRex = regexp.MustCompile(`^'([^']*)'$`)
sliceOperatorRex = regexp.MustCompile(`^(-?[\d]*)(:-?[\d]*)?(:[\d]*)?$`)
)
// Parse parsed the given text and return a node Parser.
// If an error is encountered, parsing stops and an empty
// Parser is returned with the error
func Parse(name, text string) (*Parser, error) {
p := NewParser(name)
err := p.Parse(text)
if err != nil {
p = nil
}
return p, err
}
func NewParser(name string) *Parser {
return &Parser{
Name: name,
}
}
// parseAction parsed the expression inside delimiter
func parseAction(name, text string) (*Parser, error) {
p, err := Parse(name, fmt.Sprintf("%s%s%s", leftDelim, text, rightDelim))
// when error happens, p will be nil, so we need to return here
if err != nil {
return p, err
}
p.Root = p.Root.Nodes[0].(*ListNode)
return p, nil
}
func (p *Parser) Parse(text string) error {
p.input = text
p.Root = newList()
p.pos = 0
return p.parseText(p.Root)
}
// consumeText return the parsed text since last cosumeText
func (p *Parser) consumeText() string {
value := p.input[p.start:p.pos]
p.start = p.pos
return value
}
// next returns the next rune in the input.
func (p *Parser) next() rune {
if int(p.pos) >= len(p.input) {
p.width = 0
return eof
}
r, w := utf8.DecodeRuneInString(p.input[p.pos:])
p.width = w
p.pos += p.width
return r
}
// peek returns but does not consume the next rune in the input.
func (p *Parser) peek() rune {
r := p.next()
p.backup()
return r
}
// backup steps back one rune. Can only be called once per call of next.
func (p *Parser) backup() {
p.pos -= p.width
}
func (p *Parser) parseText(cur *ListNode) error {
for {
if strings.HasPrefix(p.input[p.pos:], leftDelim) {
if p.pos > p.start {
cur.append(newText(p.consumeText()))
}
return p.parseLeftDelim(cur)
}
if p.next() == eof {
break
}
}
// Correctly reached EOF.
if p.pos > p.start {
cur.append(newText(p.consumeText()))
}
return nil
}
// parseLeftDelim scans the left delimiter, which is known to be present.
func (p *Parser) parseLeftDelim(cur *ListNode) error {
p.pos += len(leftDelim)
p.consumeText()
newNode := newList()
cur.append(newNode)
cur = newNode
return p.parseInsideAction(cur)
}
func (p *Parser) parseInsideAction(cur *ListNode) error {
prefixMap := map[string]func(*ListNode) error{
rightDelim: p.parseRightDelim,
"[?(": p.parseFilter,
"..": p.parseRecursive,
}
for prefix, parseFunc := range prefixMap {
if strings.HasPrefix(p.input[p.pos:], prefix) {
return parseFunc(cur)
}
}
switch r := p.next(); {
case r == eof || isEndOfLine(r):
return fmt.Errorf("unclosed action")
case r == ' ':
p.consumeText()
case r == '@' || r == '$': //the current object, just pass it
p.consumeText()
case r == '[':
return p.parseArray(cur)
case r == '"' || r == '\'':
return p.parseQuote(cur, r)
case r == '.':
return p.parseField(cur)
case r == '+' || r == '-' || unicode.IsDigit(r):
p.backup()
return p.parseNumber(cur)
case isAlphaNumeric(r):
p.backup()
return p.parseIdentifier(cur)
default:
return fmt.Errorf("unrecognized character in action: %#U", r)
}
return p.parseInsideAction(cur)
}
// parseRightDelim scans the right delimiter, which is known to be present.
func (p *Parser) parseRightDelim(cur *ListNode) error {
p.pos += len(rightDelim)
p.consumeText()
cur = p.Root
return p.parseText(cur)
}
// parseIdentifier scans build-in keywords, like "range" "end"
func (p *Parser) parseIdentifier(cur *ListNode) error {
var r rune
for {
r = p.next()
if isTerminator(r) {
p.backup()
break
}
}
value := p.consumeText()
if isBool(value) {
v, err := strconv.ParseBool(value)
if err != nil {
return fmt.Errorf("can not parse bool '%s': %s", value, err.Error())
}
cur.append(newBool(v))
} else {
cur.append(newIdentifier(value))
}
return p.parseInsideAction(cur)
}
// parseRecursive scans the recursive desent operator ..
func (p *Parser) parseRecursive(cur *ListNode) error {
p.pos += len("..")
p.consumeText()
cur.append(newRecursive())
if r := p.peek(); isAlphaNumeric(r) {
return p.parseField(cur)
}
return p.parseInsideAction(cur)
}
// parseNumber scans number
func (p *Parser) parseNumber(cur *ListNode) error {
r := p.peek()
if r == '+' || r == '-' {
r = p.next()
}
for {
r = p.next()
if r != '.' && !unicode.IsDigit(r) {
p.backup()
break
}
}
value := p.consumeText()
i, err := strconv.Atoi(value)
if err == nil {
cur.append(newInt(i))
return p.parseInsideAction(cur)
}
d, err := strconv.ParseFloat(value, 64)
if err == nil {
cur.append(newFloat(d))
return p.parseInsideAction(cur)
}
return fmt.Errorf("cannot parse number %s", value)
}
// parseArray scans array index selection
func (p *Parser) parseArray(cur *ListNode) error {
Loop:
for {
switch p.next() {
case eof, '\n':
return fmt.Errorf("unterminated array")
case ']':
break Loop
}
}
text := p.consumeText()
text = string(text[1 : len(text)-1])
if text == "*" {
text = ":"
}
//union operator
strs := strings.Split(text, ",")
if len(strs) > 1 {
union := []*ListNode{}
for _, str := range strs {
parser, err := parseAction("union", fmt.Sprintf("[%s]", strings.Trim(str, " ")))
if err != nil {
return err
}
union = append(union, parser.Root)
}
cur.append(newUnion(union))
return p.parseInsideAction(cur)
}
// dict key
value := dictKeyRex.FindStringSubmatch(text)
if value != nil {
parser, err := parseAction("arraydict", fmt.Sprintf(".%s", value[1]))
if err != nil {
return err
}
for _, node := range parser.Root.Nodes {
cur.append(node)
}
return p.parseInsideAction(cur)
}
//slice operator
value = sliceOperatorRex.FindStringSubmatch(text)
if value == nil {
return fmt.Errorf("invalid array index %s", text)
}
value = value[1:]
params := [3]ParamsEntry{}
for i := 0; i < 3; i++ {
if value[i] != "" {
if i > 0 {
value[i] = value[i][1:]
}
if i > 0 && value[i] == "" {
params[i].Known = false
} else {
var err error
params[i].Known = true
params[i].Value, err = strconv.Atoi(value[i])
if err != nil {
return fmt.Errorf("array index %s is not a number", value[i])
}
}
} else {
if i == 1 {
params[i].Known = true
params[i].Value = params[0].Value + 1
} else {
params[i].Known = false
params[i].Value = 0
}
}
}
cur.append(newArray(params))
return p.parseInsideAction(cur)
}
// parseFilter scans filter inside array selection
func (p *Parser) parseFilter(cur *ListNode) error {
p.pos += len("[?(")
p.consumeText()
begin := false
end := false
var pair rune
Loop:
for {
r := p.next()
switch r {
case eof, '\n':
return fmt.Errorf("unterminated filter")
case '"', '\'':
if begin == false {
//save the paired rune
begin = true
pair = r
continue
}
//only add when met paired rune
if p.input[p.pos-2] != '\\' && r == pair {
end = true
}
case ')':
//in rightParser below quotes only appear zero or once
//and must be paired at the beginning and end
if begin == end {
break Loop
}
}
}
if p.next() != ']' {
return fmt.Errorf("unclosed array expect ]")
}
reg := regexp.MustCompile(`^([^!<>=]+)([!<>=]+)(.+?)$`)
text := p.consumeText()
text = string(text[:len(text)-2])
value := reg.FindStringSubmatch(text)
if value == nil {
parser, err := parseAction("text", text)
if err != nil {
return err
}
cur.append(newFilter(parser.Root, newList(), "exists"))
} else {
leftParser, err := parseAction("left", value[1])
if err != nil {
return err
}
rightParser, err := parseAction("right", value[3])
if err != nil {
return err
}
cur.append(newFilter(leftParser.Root, rightParser.Root, value[2]))
}
return p.parseInsideAction(cur)
}
// parseQuote unquotes string inside double or single quote
func (p *Parser) parseQuote(cur *ListNode, end rune) error {
Loop:
for {
switch p.next() {
case eof, '\n':
return fmt.Errorf("unterminated quoted string")
case end:
//if it's not escape break the Loop
if p.input[p.pos-2] != '\\' {
break Loop
}
}
}
value := p.consumeText()
s, err := UnquoteExtend(value)
if err != nil {
return fmt.Errorf("unquote string %s error %v", value, err)
}
cur.append(newText(s))
return p.parseInsideAction(cur)
}
// parseField scans a field until a terminator
func (p *Parser) parseField(cur *ListNode) error {
p.consumeText()
for p.advance() {
}
value := p.consumeText()
if value == "*" {
cur.append(newWildcard())
} else {
cur.append(newField(strings.Replace(value, "\\", "", -1)))
}
return p.parseInsideAction(cur)
}
// advance scans until next non-escaped terminator
func (p *Parser) advance() bool {
r := p.next()
if r == '\\' {
p.next()
} else if isTerminator(r) {
p.backup()
return false
}
return true
}
// isTerminator reports whether the input is at valid termination character to appear after an identifier.
func isTerminator(r rune) bool {
if isSpace(r) || isEndOfLine(r) {
return true
}
switch r {
case eof, '.', ',', '[', ']', '$', '@', '{', '}':
return true
}
return false
}
// isSpace reports whether r is a space character.
func isSpace(r rune) bool {
return r == ' ' || r == '\t'
}
// isEndOfLine reports whether r is an end-of-line character.
func isEndOfLine(r rune) bool {
return r == '\r' || r == '\n'
}
// isAlphaNumeric reports whether r is an alphabetic, digit, or underscore.
func isAlphaNumeric(r rune) bool {
return r == '_' || unicode.IsLetter(r) || unicode.IsDigit(r)
}
// isBool reports whether s is a boolean value.
func isBool(s string) bool {
return s == "true" || s == "false"
}
//UnquoteExtend is almost same as strconv.Unquote(), but it support parse single quotes as a string
func UnquoteExtend(s string) (string, error) {
n := len(s)
if n < 2 {
return "", ErrSyntax
}
quote := s[0]
if quote != s[n-1] {
return "", ErrSyntax
}
s = s[1 : n-1]
if quote != '"' && quote != '\'' {
return "", ErrSyntax
}
// Is it trivial? Avoid allocation.
if !contains(s, '\\') && !contains(s, quote) {
return s, nil
}
var runeTmp [utf8.UTFMax]byte
buf := make([]byte, 0, 3*len(s)/2) // Try to avoid more allocations.
for len(s) > 0 {
c, multibyte, ss, err := strconv.UnquoteChar(s, quote)
if err != nil {
return "", err
}
s = ss
if c < utf8.RuneSelf || !multibyte {
buf = append(buf, byte(c))
} else {
n := utf8.EncodeRune(runeTmp[:], c)
buf = append(buf, runeTmp[:n]...)
}
}
return string(buf), nil
}
func contains(s string, c byte) bool {
for i := 0; i < len(s); i++ {
if s[i] == c {
return true
}
}
return false
}

View File

@ -1,152 +0,0 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package jsonpath
import (
"testing"
)
type parserTest struct {
name string
text string
nodes []Node
shouldError bool
}
var parserTests = []parserTest{
{"plain", `hello jsonpath`, []Node{newText("hello jsonpath")}, false},
{"variable", `hello {.jsonpath}`,
[]Node{newText("hello "), newList(), newField("jsonpath")}, false},
{"arrayfiled", `hello {['jsonpath']}`,
[]Node{newText("hello "), newList(), newField("jsonpath")}, false},
{"quote", `{"{"}`, []Node{newList(), newText("{")}, false},
{"array", `{[1:3]}`, []Node{newList(),
newArray([3]ParamsEntry{{1, true}, {3, true}, {0, false}})}, false},
{"allarray", `{.book[*].author}`,
[]Node{newList(), newField("book"),
newArray([3]ParamsEntry{{0, false}, {0, false}, {0, false}}), newField("author")}, false},
{"wildcard", `{.bicycle.*}`,
[]Node{newList(), newField("bicycle"), newWildcard()}, false},
{"filter", `{[?(@.price<3)]}`,
[]Node{newList(), newFilter(newList(), newList(), "<"),
newList(), newField("price"), newList(), newInt(3)}, false},
{"recursive", `{..}`, []Node{newList(), newRecursive()}, false},
{"recurField", `{..price}`,
[]Node{newList(), newRecursive(), newField("price")}, false},
{"arraydict", `{['book.price']}`, []Node{newList(),
newField("book"), newField("price"),
}, false},
{"union", `{['bicycle.price', 3, 'book.price']}`, []Node{newList(), newUnion([]*ListNode{}),
newList(), newField("bicycle"), newField("price"),
newList(), newArray([3]ParamsEntry{{3, true}, {4, true}, {0, false}}),
newList(), newField("book"), newField("price"),
}, false},
{"range", `{range .items}{.name},{end}`, []Node{
newList(), newIdentifier("range"), newField("items"),
newList(), newField("name"), newText(","),
newList(), newIdentifier("end"),
}, false},
{"malformat input", `{\\\}`, []Node{}, true},
{"paired parentheses in quotes", `{[?(@.status.nodeInfo.osImage == "()")]}`,
[]Node{newList(), newFilter(newList(), newList(), "=="), newList(), newField("status"), newField("nodeInfo"), newField("osImage"), newList(), newText("()")}, false},
{"paired parentheses in double quotes and with double quotes escape", `{[?(@.status.nodeInfo.osImage == "(\"\")")]}`,
[]Node{newList(), newFilter(newList(), newList(), "=="), newList(), newField("status"), newField("nodeInfo"), newField("osImage"), newList(), newText("(\"\")")}, false},
{"unregular parentheses in double quotes", `{[?(@.test == "())(")]}`,
[]Node{newList(), newFilter(newList(), newList(), "=="), newList(), newField("test"), newList(), newText("())(")}, false},
{"plain text in single quotes", `{[?(@.status.nodeInfo.osImage == 'Linux')]}`,
[]Node{newList(), newFilter(newList(), newList(), "=="), newList(), newField("status"), newField("nodeInfo"), newField("osImage"), newList(), newText("Linux")}, false},
{"test filter suffix", `{[?(@.status.nodeInfo.osImage == "{[()]}")]}`,
[]Node{newList(), newFilter(newList(), newList(), "=="), newList(), newField("status"), newField("nodeInfo"), newField("osImage"), newList(), newText("{[()]}")}, false},
{"double inside single", `{[?(@.status.nodeInfo.osImage == "''")]}`,
[]Node{newList(), newFilter(newList(), newList(), "=="), newList(), newField("status"), newField("nodeInfo"), newField("osImage"), newList(), newText("''")}, false},
{"single inside double", `{[?(@.status.nodeInfo.osImage == '""')]}`,
[]Node{newList(), newFilter(newList(), newList(), "=="), newList(), newField("status"), newField("nodeInfo"), newField("osImage"), newList(), newText("\"\"")}, false},
{"single containing escaped single", `{[?(@.status.nodeInfo.osImage == '\\\'')]}`,
[]Node{newList(), newFilter(newList(), newList(), "=="), newList(), newField("status"), newField("nodeInfo"), newField("osImage"), newList(), newText("\\'")}, false},
}
func collectNode(nodes []Node, cur Node) []Node {
nodes = append(nodes, cur)
switch cur.Type() {
case NodeList:
for _, node := range cur.(*ListNode).Nodes {
nodes = collectNode(nodes, node)
}
case NodeFilter:
nodes = collectNode(nodes, cur.(*FilterNode).Left)
nodes = collectNode(nodes, cur.(*FilterNode).Right)
case NodeUnion:
for _, node := range cur.(*UnionNode).Nodes {
nodes = collectNode(nodes, node)
}
}
return nodes
}
func TestParser(t *testing.T) {
for _, test := range parserTests {
parser, err := Parse(test.name, test.text)
if test.shouldError {
if err == nil {
t.Errorf("unexpected non-error when parsing %s", test.name)
}
continue
}
if err != nil {
t.Errorf("parse %s error %v", test.name, err)
}
result := collectNode([]Node{}, parser.Root)[1:]
if len(result) != len(test.nodes) {
t.Errorf("in %s, expect to get %d nodes, got %d nodes", test.name, len(test.nodes), len(result))
t.Error(result)
}
for i, expect := range test.nodes {
if result[i].String() != expect.String() {
t.Errorf("in %s, %dth node, expect %v, got %v", test.name, i, expect, result[i])
}
}
}
}
type failParserTest struct {
name string
text string
err string
}
func TestFailParser(t *testing.T) {
failParserTests := []failParserTest{
{"unclosed action", "{.hello", "unclosed action"},
{"unrecognized character", "{*}", "unrecognized character in action: U+002A '*'"},
{"invalid number", "{+12.3.0}", "cannot parse number +12.3.0"},
{"unterminated array", "{[1}", "unterminated array"},
{"invalid index", "{[::-1]}", "invalid array index ::-1"},
{"unterminated filter", "{[?(.price]}", "unterminated filter"},
}
for _, test := range failParserTests {
_, err := Parse(test.name, test.text)
var out string
if err == nil {
out = "nil"
} else {
out = err.Error()
}
if out != test.err {
t.Errorf("in %s, expect to get error %v, got %v", test.name, test.err, out)
}
}
}

View File

@ -1,42 +0,0 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
"go_test",
)
go_library(
name = "go_default_library",
srcs = ["util.go"],
importpath = "k8s.io/client-go/util/retry",
deps = [
"//vendor/k8s.io/apimachinery/pkg/api/errors:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/util/wait:go_default_library",
],
)
go_test(
name = "go_default_test",
srcs = ["util_test.go"],
importpath = "k8s.io/client-go/util/retry",
library = ":go_default_library",
deps = [
"//vendor/k8s.io/apimachinery/pkg/api/errors:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/runtime/schema:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/util/wait:go_default_library",
],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
)

View File

@ -1,2 +0,0 @@
reviewers:
- caesarxuchao

View File

@ -1,79 +0,0 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package retry
import (
"time"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/util/wait"
)
// DefaultRetry is the recommended retry for a conflict where multiple clients
// are making changes to the same resource.
var DefaultRetry = wait.Backoff{
Steps: 5,
Duration: 10 * time.Millisecond,
Factor: 1.0,
Jitter: 0.1,
}
// DefaultBackoff is the recommended backoff for a conflict where a client
// may be attempting to make an unrelated modification to a resource under
// active management by one or more controllers.
var DefaultBackoff = wait.Backoff{
Steps: 4,
Duration: 10 * time.Millisecond,
Factor: 5.0,
Jitter: 0.1,
}
// RetryConflict executes the provided function repeatedly, retrying if the server returns a conflicting
// write. Callers should preserve previous executions if they wish to retry changes. It performs an
// exponential backoff.
//
// var pod *api.Pod
// err := RetryOnConflict(DefaultBackoff, func() (err error) {
// pod, err = c.Pods("mynamespace").UpdateStatus(podStatus)
// return
// })
// if err != nil {
// // may be conflict if max retries were hit
// return err
// }
// ...
//
// TODO: Make Backoff an interface?
func RetryOnConflict(backoff wait.Backoff, fn func() error) error {
var lastConflictErr error
err := wait.ExponentialBackoff(backoff, func() (bool, error) {
err := fn()
switch {
case err == nil:
return true, nil
case errors.IsConflict(err):
lastConflictErr = err
return false, nil
default:
return false, err
}
})
if err == wait.ErrWaitTimeout {
err = lastConflictErr
}
return err
}

View File

@ -1,71 +0,0 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package retry
import (
"fmt"
"testing"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/wait"
)
func TestRetryOnConflict(t *testing.T) {
opts := wait.Backoff{Factor: 1.0, Steps: 3}
conflictErr := errors.NewConflict(schema.GroupResource{Resource: "test"}, "other", nil)
// never returns
err := RetryOnConflict(opts, func() error {
return conflictErr
})
if err != conflictErr {
t.Errorf("unexpected error: %v", err)
}
// returns immediately
i := 0
err = RetryOnConflict(opts, func() error {
i++
return nil
})
if err != nil || i != 1 {
t.Errorf("unexpected error: %v", err)
}
// returns immediately on error
testErr := fmt.Errorf("some other error")
err = RetryOnConflict(opts, func() error {
return testErr
})
if err != testErr {
t.Errorf("unexpected error: %v", err)
}
// keeps retrying
i = 0
err = RetryOnConflict(opts, func() error {
if i < 2 {
i++
return errors.NewConflict(schema.GroupResource{Resource: "test"}, "other", nil)
}
return nil
})
if err != nil || i != 2 {
t.Errorf("unexpected error: %v", err)
}
}

View File

@ -1,36 +0,0 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
"go_test",
)
go_test(
name = "go_default_test",
srcs = ["fake_handler_test.go"],
importpath = "k8s.io/client-go/util/testing",
library = ":go_default_library",
)
go_library(
name = "go_default_library",
srcs = [
"fake_handler.go",
"tmpdir.go",
],
importpath = "k8s.io/client-go/util/testing",
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
)

View File

@ -1,139 +0,0 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package testing
import (
"io/ioutil"
"net/http"
"net/url"
"reflect"
"sync"
)
// TestInterface is a simple interface providing Errorf, to make injection for
// testing easier (insert 'yo dawg' meme here).
type TestInterface interface {
Errorf(format string, args ...interface{})
Logf(format string, args ...interface{})
}
// LogInterface is a simple interface to allow injection of Logf to report serving errors.
type LogInterface interface {
Logf(format string, args ...interface{})
}
// FakeHandler is to assist in testing HTTP requests. Notice that FakeHandler is
// not thread safe and you must not direct traffic to except for the request
// you want to test. You can do this by hiding it in an http.ServeMux.
type FakeHandler struct {
RequestReceived *http.Request
RequestBody string
StatusCode int
ResponseBody string
// For logging - you can use a *testing.T
// This will keep log messages associated with the test.
T LogInterface
// Enforce "only one use" constraint.
lock sync.Mutex
requestCount int
hasBeenChecked bool
SkipRequestFn func(verb string, url url.URL) bool
}
func (f *FakeHandler) SetResponseBody(responseBody string) {
f.lock.Lock()
defer f.lock.Unlock()
f.ResponseBody = responseBody
}
func (f *FakeHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) {
f.lock.Lock()
defer f.lock.Unlock()
if f.SkipRequestFn != nil && f.SkipRequestFn(request.Method, *request.URL) {
response.Header().Set("Content-Type", "application/json")
response.WriteHeader(f.StatusCode)
response.Write([]byte(f.ResponseBody))
return
}
f.requestCount++
if f.hasBeenChecked {
panic("got request after having been validated")
}
f.RequestReceived = request
response.Header().Set("Content-Type", "application/json")
response.WriteHeader(f.StatusCode)
response.Write([]byte(f.ResponseBody))
bodyReceived, err := ioutil.ReadAll(request.Body)
if err != nil && f.T != nil {
f.T.Logf("Received read error: %v", err)
}
f.RequestBody = string(bodyReceived)
if f.T != nil {
f.T.Logf("request body: %s", f.RequestBody)
}
}
func (f *FakeHandler) ValidateRequestCount(t TestInterface, count int) bool {
ok := true
f.lock.Lock()
defer f.lock.Unlock()
if f.requestCount != count {
ok = false
t.Errorf("Expected %d call, but got %d. Only the last call is recorded and checked.", count, f.requestCount)
}
f.hasBeenChecked = true
return ok
}
// ValidateRequest verifies that FakeHandler received a request with expected path, method, and body.
func (f *FakeHandler) ValidateRequest(t TestInterface, expectedPath, expectedMethod string, body *string) {
f.lock.Lock()
defer f.lock.Unlock()
if f.requestCount != 1 {
t.Logf("Expected 1 call, but got %v. Only the last call is recorded and checked.", f.requestCount)
}
f.hasBeenChecked = true
expectURL, err := url.Parse(expectedPath)
if err != nil {
t.Errorf("Couldn't parse %v as a URL.", expectedPath)
}
if f.RequestReceived == nil {
t.Errorf("Unexpected nil request received for %s", expectedPath)
return
}
if f.RequestReceived.URL.Path != expectURL.Path {
t.Errorf("Unexpected request path for request %#v, received: %q, expected: %q", f.RequestReceived, f.RequestReceived.URL.Path, expectURL.Path)
}
if e, a := expectURL.Query(), f.RequestReceived.URL.Query(); !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected query for request %#v, received: %q, expected: %q", f.RequestReceived, a, e)
}
if f.RequestReceived.Method != expectedMethod {
t.Errorf("Unexpected method: %q, expected: %q", f.RequestReceived.Method, expectedMethod)
}
if body != nil {
if *body != f.RequestBody {
t.Errorf("Received body:\n%s\n Doesn't match expected body:\n%s", f.RequestBody, *body)
}
}
}

View File

@ -1,180 +0,0 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package testing
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
)
func TestFakeHandlerPath(t *testing.T) {
handler := FakeHandler{StatusCode: http.StatusOK}
server := httptest.NewServer(&handler)
defer server.Close()
method := "GET"
path := "/foo/bar"
body := "somebody"
req, err := http.NewRequest(method, server.URL+path, bytes.NewBufferString(body))
if err != nil {
t.Errorf("unexpected error: %v", err)
}
client := http.Client{}
_, err = client.Do(req)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
handler.ValidateRequest(t, path, method, &body)
}
func TestFakeHandlerPathNoBody(t *testing.T) {
handler := FakeHandler{StatusCode: http.StatusOK}
server := httptest.NewServer(&handler)
defer server.Close()
method := "GET"
path := "/foo/bar"
req, err := http.NewRequest(method, server.URL+path, nil)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
client := http.Client{}
_, err = client.Do(req)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
handler.ValidateRequest(t, path, method, nil)
}
type fakeError struct {
errors []string
}
func (f *fakeError) Errorf(format string, args ...interface{}) {
f.errors = append(f.errors, format)
}
func (f *fakeError) Logf(format string, args ...interface{}) {}
func TestFakeHandlerWrongPath(t *testing.T) {
handler := FakeHandler{StatusCode: http.StatusOK}
server := httptest.NewServer(&handler)
defer server.Close()
method := "GET"
path := "/foo/bar"
fakeT := fakeError{}
req, err := http.NewRequest(method, server.URL+"/foo/baz", nil)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
client := http.Client{}
_, err = client.Do(req)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
handler.ValidateRequest(&fakeT, path, method, nil)
if len(fakeT.errors) != 1 {
t.Errorf("Unexpected error set: %#v", fakeT.errors)
}
}
func TestFakeHandlerWrongMethod(t *testing.T) {
handler := FakeHandler{StatusCode: http.StatusOK}
server := httptest.NewServer(&handler)
defer server.Close()
method := "GET"
path := "/foo/bar"
fakeT := fakeError{}
req, err := http.NewRequest("PUT", server.URL+path, nil)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
client := http.Client{}
_, err = client.Do(req)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
handler.ValidateRequest(&fakeT, path, method, nil)
if len(fakeT.errors) != 1 {
t.Errorf("Unexpected error set: %#v", fakeT.errors)
}
}
func TestFakeHandlerWrongBody(t *testing.T) {
handler := FakeHandler{StatusCode: http.StatusOK}
server := httptest.NewServer(&handler)
defer server.Close()
method := "GET"
path := "/foo/bar"
body := "somebody"
fakeT := fakeError{}
req, err := http.NewRequest(method, server.URL+path, bytes.NewBufferString(body))
if err != nil {
t.Errorf("unexpected error: %v", err)
}
client := http.Client{}
_, err = client.Do(req)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
otherbody := "otherbody"
handler.ValidateRequest(&fakeT, path, method, &otherbody)
if len(fakeT.errors) != 1 {
t.Errorf("Unexpected error set: %#v", fakeT.errors)
}
}
func TestFakeHandlerNilBody(t *testing.T) {
handler := FakeHandler{StatusCode: http.StatusOK}
server := httptest.NewServer(&handler)
defer server.Close()
method := "GET"
path := "/foo/bar"
body := "somebody"
fakeT := fakeError{}
req, err := http.NewRequest(method, server.URL+path, nil)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
client := http.Client{}
_, err = client.Do(req)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
handler.ValidateRequest(&fakeT, path, method, &body)
if len(fakeT.errors) != 1 {
t.Errorf("Unexpected error set: %#v", fakeT.errors)
}
}

View File

@ -1,44 +0,0 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package testing
import (
"io/ioutil"
"os"
)
// MkTmpdir creates a temporary directory based upon the prefix passed in.
// If successful, it returns the temporary directory path. The directory can be
// deleted with a call to "os.RemoveAll(...)".
// In case of error, it'll return an empty string and the error.
func MkTmpdir(prefix string) (string, error) {
tmpDir, err := ioutil.TempDir(os.TempDir(), prefix)
if err != nil {
return "", err
}
return tmpDir, nil
}
// MkTmpdir does the same work as "MkTmpdir", except in case of
// errors, it'll trigger a panic.
func MkTmpdirOrDie(prefix string) string {
tmpDir, err := MkTmpdir(prefix)
if err != nil {
panic(err)
}
return tmpDir
}

View File

@ -1,61 +0,0 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
"go_test",
)
go_test(
name = "go_default_test",
srcs = [
"default_rate_limiters_test.go",
"delaying_queue_test.go",
"rate_limitting_queue_test.go",
],
importpath = "k8s.io/client-go/util/workqueue",
library = ":go_default_library",
deps = [
"//vendor/k8s.io/apimachinery/pkg/util/clock:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/util/wait:go_default_library",
],
)
go_library(
name = "go_default_library",
srcs = [
"default_rate_limiters.go",
"delaying_queue.go",
"doc.go",
"metrics.go",
"parallelizer.go",
"queue.go",
"rate_limitting_queue.go",
],
importpath = "k8s.io/client-go/util/workqueue",
deps = [
"//vendor/github.com/juju/ratelimit:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/util/clock:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/util/runtime:go_default_library",
],
)
go_test(
name = "go_default_xtest",
srcs = ["queue_test.go"],
importpath = "k8s.io/client-go/util/workqueue_test",
deps = ["//vendor/k8s.io/client-go/util/workqueue:go_default_library"],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
)

View File

@ -1,211 +0,0 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package workqueue
import (
"math"
"sync"
"time"
"github.com/juju/ratelimit"
)
type RateLimiter interface {
// When gets an item and gets to decide how long that item should wait
When(item interface{}) time.Duration
// Forget indicates that an item is finished being retried. Doesn't matter whether its for perm failing
// or for success, we'll stop tracking it
Forget(item interface{})
// NumRequeues returns back how many failures the item has had
NumRequeues(item interface{}) int
}
// DefaultControllerRateLimiter is a no-arg constructor for a default rate limiter for a workqueue. It has
// both overall and per-item rate limitting. The overall is a token bucket and the per-item is exponential
func DefaultControllerRateLimiter() RateLimiter {
return NewMaxOfRateLimiter(
NewItemExponentialFailureRateLimiter(5*time.Millisecond, 1000*time.Second),
// 10 qps, 100 bucket size. This is only for retry speed and its only the overall factor (not per item)
&BucketRateLimiter{Bucket: ratelimit.NewBucketWithRate(float64(10), int64(100))},
)
}
// BucketRateLimiter adapts a standard bucket to the workqueue ratelimiter API
type BucketRateLimiter struct {
*ratelimit.Bucket
}
var _ RateLimiter = &BucketRateLimiter{}
func (r *BucketRateLimiter) When(item interface{}) time.Duration {
return r.Bucket.Take(1)
}
func (r *BucketRateLimiter) NumRequeues(item interface{}) int {
return 0
}
func (r *BucketRateLimiter) Forget(item interface{}) {
}
// ItemExponentialFailureRateLimiter does a simple baseDelay*10^<num-failures> limit
// dealing with max failures and expiration are up to the caller
type ItemExponentialFailureRateLimiter struct {
failuresLock sync.Mutex
failures map[interface{}]int
baseDelay time.Duration
maxDelay time.Duration
}
var _ RateLimiter = &ItemExponentialFailureRateLimiter{}
func NewItemExponentialFailureRateLimiter(baseDelay time.Duration, maxDelay time.Duration) RateLimiter {
return &ItemExponentialFailureRateLimiter{
failures: map[interface{}]int{},
baseDelay: baseDelay,
maxDelay: maxDelay,
}
}
func DefaultItemBasedRateLimiter() RateLimiter {
return NewItemExponentialFailureRateLimiter(time.Millisecond, 1000*time.Second)
}
func (r *ItemExponentialFailureRateLimiter) When(item interface{}) time.Duration {
r.failuresLock.Lock()
defer r.failuresLock.Unlock()
exp := r.failures[item]
r.failures[item] = r.failures[item] + 1
// The backoff is capped such that 'calculated' value never overflows.
backoff := float64(r.baseDelay.Nanoseconds()) * math.Pow(2, float64(exp))
if backoff > math.MaxInt64 {
return r.maxDelay
}
calculated := time.Duration(backoff)
if calculated > r.maxDelay {
return r.maxDelay
}
return calculated
}
func (r *ItemExponentialFailureRateLimiter) NumRequeues(item interface{}) int {
r.failuresLock.Lock()
defer r.failuresLock.Unlock()
return r.failures[item]
}
func (r *ItemExponentialFailureRateLimiter) Forget(item interface{}) {
r.failuresLock.Lock()
defer r.failuresLock.Unlock()
delete(r.failures, item)
}
// ItemFastSlowRateLimiter does a quick retry for a certain number of attempts, then a slow retry after that
type ItemFastSlowRateLimiter struct {
failuresLock sync.Mutex
failures map[interface{}]int
maxFastAttempts int
fastDelay time.Duration
slowDelay time.Duration
}
var _ RateLimiter = &ItemFastSlowRateLimiter{}
func NewItemFastSlowRateLimiter(fastDelay, slowDelay time.Duration, maxFastAttempts int) RateLimiter {
return &ItemFastSlowRateLimiter{
failures: map[interface{}]int{},
fastDelay: fastDelay,
slowDelay: slowDelay,
maxFastAttempts: maxFastAttempts,
}
}
func (r *ItemFastSlowRateLimiter) When(item interface{}) time.Duration {
r.failuresLock.Lock()
defer r.failuresLock.Unlock()
r.failures[item] = r.failures[item] + 1
if r.failures[item] <= r.maxFastAttempts {
return r.fastDelay
}
return r.slowDelay
}
func (r *ItemFastSlowRateLimiter) NumRequeues(item interface{}) int {
r.failuresLock.Lock()
defer r.failuresLock.Unlock()
return r.failures[item]
}
func (r *ItemFastSlowRateLimiter) Forget(item interface{}) {
r.failuresLock.Lock()
defer r.failuresLock.Unlock()
delete(r.failures, item)
}
// MaxOfRateLimiter calls every RateLimiter and returns the worst case response
// When used with a token bucket limiter, the burst could be apparently exceeded in cases where particular items
// were separately delayed a longer time.
type MaxOfRateLimiter struct {
limiters []RateLimiter
}
func (r *MaxOfRateLimiter) When(item interface{}) time.Duration {
ret := time.Duration(0)
for _, limiter := range r.limiters {
curr := limiter.When(item)
if curr > ret {
ret = curr
}
}
return ret
}
func NewMaxOfRateLimiter(limiters ...RateLimiter) RateLimiter {
return &MaxOfRateLimiter{limiters: limiters}
}
func (r *MaxOfRateLimiter) NumRequeues(item interface{}) int {
ret := 0
for _, limiter := range r.limiters {
curr := limiter.NumRequeues(item)
if curr > ret {
ret = curr
}
}
return ret
}
func (r *MaxOfRateLimiter) Forget(item interface{}) {
for _, limiter := range r.limiters {
limiter.Forget(item)
}
}

View File

@ -1,184 +0,0 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package workqueue
import (
"testing"
"time"
)
func TestItemExponentialFailureRateLimiter(t *testing.T) {
limiter := NewItemExponentialFailureRateLimiter(1*time.Millisecond, 1*time.Second)
if e, a := 1*time.Millisecond, limiter.When("one"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := 2*time.Millisecond, limiter.When("one"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := 4*time.Millisecond, limiter.When("one"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := 8*time.Millisecond, limiter.When("one"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := 16*time.Millisecond, limiter.When("one"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := 5, limiter.NumRequeues("one"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := 1*time.Millisecond, limiter.When("two"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := 2*time.Millisecond, limiter.When("two"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := 2, limiter.NumRequeues("two"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
limiter.Forget("one")
if e, a := 0, limiter.NumRequeues("one"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := 1*time.Millisecond, limiter.When("one"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
}
func TestItemExponentialFailureRateLimiterOverFlow(t *testing.T) {
limiter := NewItemExponentialFailureRateLimiter(1*time.Millisecond, 1000*time.Second)
for i := 0; i < 5; i++ {
limiter.When("one")
}
if e, a := 32*time.Millisecond, limiter.When("one"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
for i := 0; i < 1000; i++ {
limiter.When("overflow1")
}
if e, a := 1000*time.Second, limiter.When("overflow1"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
limiter = NewItemExponentialFailureRateLimiter(1*time.Minute, 1000*time.Hour)
for i := 0; i < 2; i++ {
limiter.When("two")
}
if e, a := 4*time.Minute, limiter.When("two"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
for i := 0; i < 1000; i++ {
limiter.When("overflow2")
}
if e, a := 1000*time.Hour, limiter.When("overflow2"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
}
func TestItemFastSlowRateLimiter(t *testing.T) {
limiter := NewItemFastSlowRateLimiter(5*time.Millisecond, 10*time.Second, 3)
if e, a := 5*time.Millisecond, limiter.When("one"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := 5*time.Millisecond, limiter.When("one"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := 5*time.Millisecond, limiter.When("one"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := 10*time.Second, limiter.When("one"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := 10*time.Second, limiter.When("one"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := 5, limiter.NumRequeues("one"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := 5*time.Millisecond, limiter.When("two"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := 5*time.Millisecond, limiter.When("two"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := 2, limiter.NumRequeues("two"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
limiter.Forget("one")
if e, a := 0, limiter.NumRequeues("one"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := 5*time.Millisecond, limiter.When("one"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
}
func TestMaxOfRateLimiter(t *testing.T) {
limiter := NewMaxOfRateLimiter(
NewItemFastSlowRateLimiter(5*time.Millisecond, 3*time.Second, 3),
NewItemExponentialFailureRateLimiter(1*time.Millisecond, 1*time.Second),
)
if e, a := 5*time.Millisecond, limiter.When("one"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := 5*time.Millisecond, limiter.When("one"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := 5*time.Millisecond, limiter.When("one"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := 3*time.Second, limiter.When("one"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := 3*time.Second, limiter.When("one"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := 5, limiter.NumRequeues("one"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := 5*time.Millisecond, limiter.When("two"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := 5*time.Millisecond, limiter.When("two"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := 2, limiter.NumRequeues("two"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
limiter.Forget("one")
if e, a := 0, limiter.NumRequeues("one"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := 5*time.Millisecond, limiter.When("one"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
}

View File

@ -1,257 +0,0 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package workqueue
import (
"container/heap"
"time"
"k8s.io/apimachinery/pkg/util/clock"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
)
// DelayingInterface is an Interface that can Add an item at a later time. This makes it easier to
// requeue items after failures without ending up in a hot-loop.
type DelayingInterface interface {
Interface
// AddAfter adds an item to the workqueue after the indicated duration has passed
AddAfter(item interface{}, duration time.Duration)
}
// NewDelayingQueue constructs a new workqueue with delayed queuing ability
func NewDelayingQueue() DelayingInterface {
return newDelayingQueue(clock.RealClock{}, "")
}
func NewNamedDelayingQueue(name string) DelayingInterface {
return newDelayingQueue(clock.RealClock{}, name)
}
func newDelayingQueue(clock clock.Clock, name string) DelayingInterface {
ret := &delayingType{
Interface: NewNamed(name),
clock: clock,
heartbeat: clock.Tick(maxWait),
stopCh: make(chan struct{}),
waitingForAddCh: make(chan *waitFor, 1000),
metrics: newRetryMetrics(name),
}
go ret.waitingLoop()
return ret
}
// delayingType wraps an Interface and provides delayed re-enquing
type delayingType struct {
Interface
// clock tracks time for delayed firing
clock clock.Clock
// stopCh lets us signal a shutdown to the waiting loop
stopCh chan struct{}
// heartbeat ensures we wait no more than maxWait before firing
//
// TODO: replace with Ticker (and add to clock) so this can be cleaned up.
// clock.Tick will leak.
heartbeat <-chan time.Time
// waitingForAddCh is a buffered channel that feeds waitingForAdd
waitingForAddCh chan *waitFor
// metrics counts the number of retries
metrics retryMetrics
}
// waitFor holds the data to add and the time it should be added
type waitFor struct {
data t
readyAt time.Time
// index in the priority queue (heap)
index int
}
// waitForPriorityQueue implements a priority queue for waitFor items.
//
// waitForPriorityQueue implements heap.Interface. The item occuring next in
// time (i.e., the item with the smallest readyAt) is at the root (index 0).
// Peek returns this minimum item at index 0. Pop returns the minimum item after
// it has been removed from the queue and placed at index Len()-1 by
// container/heap. Push adds an item at index Len(), and container/heap
// percolates it into the correct location.
type waitForPriorityQueue []*waitFor
func (pq waitForPriorityQueue) Len() int {
return len(pq)
}
func (pq waitForPriorityQueue) Less(i, j int) bool {
return pq[i].readyAt.Before(pq[j].readyAt)
}
func (pq waitForPriorityQueue) Swap(i, j int) {
pq[i], pq[j] = pq[j], pq[i]
pq[i].index = i
pq[j].index = j
}
// Push adds an item to the queue. Push should not be called directly; instead,
// use `heap.Push`.
func (pq *waitForPriorityQueue) Push(x interface{}) {
n := len(*pq)
item := x.(*waitFor)
item.index = n
*pq = append(*pq, item)
}
// Pop removes an item from the queue. Pop should not be called directly;
// instead, use `heap.Pop`.
func (pq *waitForPriorityQueue) Pop() interface{} {
n := len(*pq)
item := (*pq)[n-1]
item.index = -1
*pq = (*pq)[0:(n - 1)]
return item
}
// Peek returns the item at the beginning of the queue, without removing the
// item or otherwise mutating the queue. It is safe to call directly.
func (pq waitForPriorityQueue) Peek() interface{} {
return pq[0]
}
// ShutDown gives a way to shut off this queue
func (q *delayingType) ShutDown() {
q.Interface.ShutDown()
close(q.stopCh)
}
// AddAfter adds the given item to the work queue after the given delay
func (q *delayingType) AddAfter(item interface{}, duration time.Duration) {
// don't add if we're already shutting down
if q.ShuttingDown() {
return
}
q.metrics.retry()
// immediately add things with no delay
if duration <= 0 {
q.Add(item)
return
}
select {
case <-q.stopCh:
// unblock if ShutDown() is called
case q.waitingForAddCh <- &waitFor{data: item, readyAt: q.clock.Now().Add(duration)}:
}
}
// maxWait keeps a max bound on the wait time. It's just insurance against weird things happening.
// Checking the queue every 10 seconds isn't expensive and we know that we'll never end up with an
// expired item sitting for more than 10 seconds.
const maxWait = 10 * time.Second
// waitingLoop runs until the workqueue is shutdown and keeps a check on the list of items to be added.
func (q *delayingType) waitingLoop() {
defer utilruntime.HandleCrash()
// Make a placeholder channel to use when there are no items in our list
never := make(<-chan time.Time)
waitingForQueue := &waitForPriorityQueue{}
heap.Init(waitingForQueue)
waitingEntryByData := map[t]*waitFor{}
for {
if q.Interface.ShuttingDown() {
return
}
now := q.clock.Now()
// Add ready entries
for waitingForQueue.Len() > 0 {
entry := waitingForQueue.Peek().(*waitFor)
if entry.readyAt.After(now) {
break
}
entry = heap.Pop(waitingForQueue).(*waitFor)
q.Add(entry.data)
delete(waitingEntryByData, entry.data)
}
// Set up a wait for the first item's readyAt (if one exists)
nextReadyAt := never
if waitingForQueue.Len() > 0 {
entry := waitingForQueue.Peek().(*waitFor)
nextReadyAt = q.clock.After(entry.readyAt.Sub(now))
}
select {
case <-q.stopCh:
return
case <-q.heartbeat:
// continue the loop, which will add ready items
case <-nextReadyAt:
// continue the loop, which will add ready items
case waitEntry := <-q.waitingForAddCh:
if waitEntry.readyAt.After(q.clock.Now()) {
insert(waitingForQueue, waitingEntryByData, waitEntry)
} else {
q.Add(waitEntry.data)
}
drained := false
for !drained {
select {
case waitEntry := <-q.waitingForAddCh:
if waitEntry.readyAt.After(q.clock.Now()) {
insert(waitingForQueue, waitingEntryByData, waitEntry)
} else {
q.Add(waitEntry.data)
}
default:
drained = true
}
}
}
}
}
// insert adds the entry to the priority queue, or updates the readyAt if it already exists in the queue
func insert(q *waitForPriorityQueue, knownEntries map[t]*waitFor, entry *waitFor) {
// if the entry already exists, update the time only if it would cause the item to be queued sooner
existing, exists := knownEntries[entry.data]
if exists {
if existing.readyAt.After(entry.readyAt) {
existing.readyAt = entry.readyAt
heap.Fix(q, existing.index)
}
return
}
heap.Push(q, entry)
knownEntries[entry.data] = entry
}

View File

@ -1,255 +0,0 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package workqueue
import (
"fmt"
"math/rand"
"reflect"
"testing"
"time"
"k8s.io/apimachinery/pkg/util/clock"
"k8s.io/apimachinery/pkg/util/wait"
)
func TestSimpleQueue(t *testing.T) {
fakeClock := clock.NewFakeClock(time.Now())
q := newDelayingQueue(fakeClock, "")
first := "foo"
q.AddAfter(first, 50*time.Millisecond)
if err := waitForWaitingQueueToFill(q); err != nil {
t.Fatalf("unexpected err: %v", err)
}
if q.Len() != 0 {
t.Errorf("should not have added")
}
fakeClock.Step(60 * time.Millisecond)
if err := waitForAdded(q, 1); err != nil {
t.Errorf("should have added")
}
item, _ := q.Get()
q.Done(item)
// step past the next heartbeat
fakeClock.Step(10 * time.Second)
err := wait.Poll(1*time.Millisecond, 30*time.Millisecond, func() (done bool, err error) {
if q.Len() > 0 {
return false, fmt.Errorf("added to queue")
}
return false, nil
})
if err != wait.ErrWaitTimeout {
t.Errorf("expected timeout, got: %v", err)
}
if q.Len() != 0 {
t.Errorf("should not have added")
}
}
func TestDeduping(t *testing.T) {
fakeClock := clock.NewFakeClock(time.Now())
q := newDelayingQueue(fakeClock, "")
first := "foo"
q.AddAfter(first, 50*time.Millisecond)
if err := waitForWaitingQueueToFill(q); err != nil {
t.Fatalf("unexpected err: %v", err)
}
q.AddAfter(first, 70*time.Millisecond)
if err := waitForWaitingQueueToFill(q); err != nil {
t.Fatalf("unexpected err: %v", err)
}
if q.Len() != 0 {
t.Errorf("should not have added")
}
// step past the first block, we should receive now
fakeClock.Step(60 * time.Millisecond)
if err := waitForAdded(q, 1); err != nil {
t.Errorf("should have added")
}
item, _ := q.Get()
q.Done(item)
// step past the second add
fakeClock.Step(20 * time.Millisecond)
if q.Len() != 0 {
t.Errorf("should not have added")
}
// test again, but this time the earlier should override
q.AddAfter(first, 50*time.Millisecond)
q.AddAfter(first, 30*time.Millisecond)
if err := waitForWaitingQueueToFill(q); err != nil {
t.Fatalf("unexpected err: %v", err)
}
if q.Len() != 0 {
t.Errorf("should not have added")
}
fakeClock.Step(40 * time.Millisecond)
if err := waitForAdded(q, 1); err != nil {
t.Errorf("should have added")
}
item, _ = q.Get()
q.Done(item)
// step past the second add
fakeClock.Step(20 * time.Millisecond)
if q.Len() != 0 {
t.Errorf("should not have added")
}
if q.Len() != 0 {
t.Errorf("should not have added")
}
}
func TestAddTwoFireEarly(t *testing.T) {
fakeClock := clock.NewFakeClock(time.Now())
q := newDelayingQueue(fakeClock, "")
first := "foo"
second := "bar"
third := "baz"
q.AddAfter(first, 1*time.Second)
q.AddAfter(second, 50*time.Millisecond)
if err := waitForWaitingQueueToFill(q); err != nil {
t.Fatalf("unexpected err: %v", err)
}
if q.Len() != 0 {
t.Errorf("should not have added")
}
fakeClock.Step(60 * time.Millisecond)
if err := waitForAdded(q, 1); err != nil {
t.Fatalf("unexpected err: %v", err)
}
item, _ := q.Get()
if !reflect.DeepEqual(item, second) {
t.Errorf("expected %v, got %v", second, item)
}
q.AddAfter(third, 2*time.Second)
fakeClock.Step(1 * time.Second)
if err := waitForAdded(q, 1); err != nil {
t.Fatalf("unexpected err: %v", err)
}
item, _ = q.Get()
if !reflect.DeepEqual(item, first) {
t.Errorf("expected %v, got %v", first, item)
}
fakeClock.Step(2 * time.Second)
if err := waitForAdded(q, 1); err != nil {
t.Fatalf("unexpected err: %v", err)
}
item, _ = q.Get()
if !reflect.DeepEqual(item, third) {
t.Errorf("expected %v, got %v", third, item)
}
}
func TestCopyShifting(t *testing.T) {
fakeClock := clock.NewFakeClock(time.Now())
q := newDelayingQueue(fakeClock, "")
first := "foo"
second := "bar"
third := "baz"
q.AddAfter(first, 1*time.Second)
q.AddAfter(second, 500*time.Millisecond)
q.AddAfter(third, 250*time.Millisecond)
if err := waitForWaitingQueueToFill(q); err != nil {
t.Fatalf("unexpected err: %v", err)
}
if q.Len() != 0 {
t.Errorf("should not have added")
}
fakeClock.Step(2 * time.Second)
if err := waitForAdded(q, 3); err != nil {
t.Fatalf("unexpected err: %v", err)
}
actualFirst, _ := q.Get()
if !reflect.DeepEqual(actualFirst, third) {
t.Errorf("expected %v, got %v", third, actualFirst)
}
actualSecond, _ := q.Get()
if !reflect.DeepEqual(actualSecond, second) {
t.Errorf("expected %v, got %v", second, actualSecond)
}
actualThird, _ := q.Get()
if !reflect.DeepEqual(actualThird, first) {
t.Errorf("expected %v, got %v", first, actualThird)
}
}
func BenchmarkDelayingQueue_AddAfter(b *testing.B) {
r := rand.New(rand.NewSource(time.Now().Unix()))
fakeClock := clock.NewFakeClock(time.Now())
q := newDelayingQueue(fakeClock, "")
// Add items
for n := 0; n < b.N; n++ {
data := fmt.Sprintf("%d", n)
q.AddAfter(data, time.Duration(r.Int63n(int64(10*time.Minute))))
}
// Exercise item removal as well
fakeClock.Step(11 * time.Minute)
for n := 0; n < b.N; n++ {
_, _ = q.Get()
}
}
func waitForAdded(q DelayingInterface, depth int) error {
return wait.Poll(1*time.Millisecond, 10*time.Second, func() (done bool, err error) {
if q.Len() == depth {
return true, nil
}
return false, nil
})
}
func waitForWaitingQueueToFill(q DelayingInterface) error {
return wait.Poll(1*time.Millisecond, 10*time.Second, func() (done bool, err error) {
if len(q.(*delayingType).waitingForAddCh) == 0 {
return true, nil
}
return false, nil
})
}

View File

@ -1,26 +0,0 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package workqueue provides a simple queue that supports the following
// features:
// * Fair: items processed in the order in which they are added.
// * Stingy: a single item will not be processed multiple times concurrently,
// and if an item is added multiple times before it can be processed, it
// will only be processed once.
// * Multiple consumers and producers. In particular, it is allowed for an
// item to be reenqueued while it is being processed.
// * Shutdown notifications.
package workqueue

View File

@ -1,195 +0,0 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package workqueue
import (
"sync"
"time"
)
// This file provides abstractions for setting the provider (e.g., prometheus)
// of metrics.
type queueMetrics interface {
add(item t)
get(item t)
done(item t)
}
// GaugeMetric represents a single numerical value that can arbitrarily go up
// and down.
type GaugeMetric interface {
Inc()
Dec()
}
// CounterMetric represents a single numerical value that only ever
// goes up.
type CounterMetric interface {
Inc()
}
// SummaryMetric captures individual observations.
type SummaryMetric interface {
Observe(float64)
}
type noopMetric struct{}
func (noopMetric) Inc() {}
func (noopMetric) Dec() {}
func (noopMetric) Observe(float64) {}
type defaultQueueMetrics struct {
// current depth of a workqueue
depth GaugeMetric
// total number of adds handled by a workqueue
adds CounterMetric
// how long an item stays in a workqueue
latency SummaryMetric
// how long processing an item from a workqueue takes
workDuration SummaryMetric
addTimes map[t]time.Time
processingStartTimes map[t]time.Time
}
func (m *defaultQueueMetrics) add(item t) {
if m == nil {
return
}
m.adds.Inc()
m.depth.Inc()
if _, exists := m.addTimes[item]; !exists {
m.addTimes[item] = time.Now()
}
}
func (m *defaultQueueMetrics) get(item t) {
if m == nil {
return
}
m.depth.Dec()
m.processingStartTimes[item] = time.Now()
if startTime, exists := m.addTimes[item]; exists {
m.latency.Observe(sinceInMicroseconds(startTime))
delete(m.addTimes, item)
}
}
func (m *defaultQueueMetrics) done(item t) {
if m == nil {
return
}
if startTime, exists := m.processingStartTimes[item]; exists {
m.workDuration.Observe(sinceInMicroseconds(startTime))
delete(m.processingStartTimes, item)
}
}
// Gets the time since the specified start in microseconds.
func sinceInMicroseconds(start time.Time) float64 {
return float64(time.Since(start).Nanoseconds() / time.Microsecond.Nanoseconds())
}
type retryMetrics interface {
retry()
}
type defaultRetryMetrics struct {
retries CounterMetric
}
func (m *defaultRetryMetrics) retry() {
if m == nil {
return
}
m.retries.Inc()
}
// MetricsProvider generates various metrics used by the queue.
type MetricsProvider interface {
NewDepthMetric(name string) GaugeMetric
NewAddsMetric(name string) CounterMetric
NewLatencyMetric(name string) SummaryMetric
NewWorkDurationMetric(name string) SummaryMetric
NewRetriesMetric(name string) CounterMetric
}
type noopMetricsProvider struct{}
func (_ noopMetricsProvider) NewDepthMetric(name string) GaugeMetric {
return noopMetric{}
}
func (_ noopMetricsProvider) NewAddsMetric(name string) CounterMetric {
return noopMetric{}
}
func (_ noopMetricsProvider) NewLatencyMetric(name string) SummaryMetric {
return noopMetric{}
}
func (_ noopMetricsProvider) NewWorkDurationMetric(name string) SummaryMetric {
return noopMetric{}
}
func (_ noopMetricsProvider) NewRetriesMetric(name string) CounterMetric {
return noopMetric{}
}
var metricsFactory = struct {
metricsProvider MetricsProvider
setProviders sync.Once
}{
metricsProvider: noopMetricsProvider{},
}
func newQueueMetrics(name string) queueMetrics {
var ret *defaultQueueMetrics
if len(name) == 0 {
return ret
}
return &defaultQueueMetrics{
depth: metricsFactory.metricsProvider.NewDepthMetric(name),
adds: metricsFactory.metricsProvider.NewAddsMetric(name),
latency: metricsFactory.metricsProvider.NewLatencyMetric(name),
workDuration: metricsFactory.metricsProvider.NewWorkDurationMetric(name),
addTimes: map[t]time.Time{},
processingStartTimes: map[t]time.Time{},
}
}
func newRetryMetrics(name string) retryMetrics {
var ret *defaultRetryMetrics
if len(name) == 0 {
return ret
}
return &defaultRetryMetrics{
retries: metricsFactory.metricsProvider.NewRetriesMetric(name),
}
}
// SetProvider sets the metrics provider of the metricsFactory.
func SetProvider(metricsProvider MetricsProvider) {
metricsFactory.setProviders.Do(func() {
metricsFactory.metricsProvider = metricsProvider
})
}

View File

@ -1,52 +0,0 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package workqueue
import (
"sync"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
)
type DoWorkPieceFunc func(piece int)
// Parallelize is a very simple framework that allow for parallelizing
// N independent pieces of work.
func Parallelize(workers, pieces int, doWorkPiece DoWorkPieceFunc) {
toProcess := make(chan int, pieces)
for i := 0; i < pieces; i++ {
toProcess <- i
}
close(toProcess)
if pieces < workers {
workers = pieces
}
wg := sync.WaitGroup{}
wg.Add(workers)
for i := 0; i < workers; i++ {
go func() {
defer utilruntime.HandleCrash()
defer wg.Done()
for piece := range toProcess {
doWorkPiece(piece)
}
}()
}
wg.Wait()
}

View File

@ -1,172 +0,0 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package workqueue
import (
"sync"
)
type Interface interface {
Add(item interface{})
Len() int
Get() (item interface{}, shutdown bool)
Done(item interface{})
ShutDown()
ShuttingDown() bool
}
// New constructs a new work queue (see the package comment).
func New() *Type {
return NewNamed("")
}
func NewNamed(name string) *Type {
return &Type{
dirty: set{},
processing: set{},
cond: sync.NewCond(&sync.Mutex{}),
metrics: newQueueMetrics(name),
}
}
// Type is a work queue (see the package comment).
type Type struct {
// queue defines the order in which we will work on items. Every
// element of queue should be in the dirty set and not in the
// processing set.
queue []t
// dirty defines all of the items that need to be processed.
dirty set
// Things that are currently being processed are in the processing set.
// These things may be simultaneously in the dirty set. When we finish
// processing something and remove it from this set, we'll check if
// it's in the dirty set, and if so, add it to the queue.
processing set
cond *sync.Cond
shuttingDown bool
metrics queueMetrics
}
type empty struct{}
type t interface{}
type set map[t]empty
func (s set) has(item t) bool {
_, exists := s[item]
return exists
}
func (s set) insert(item t) {
s[item] = empty{}
}
func (s set) delete(item t) {
delete(s, item)
}
// Add marks item as needing processing.
func (q *Type) Add(item interface{}) {
q.cond.L.Lock()
defer q.cond.L.Unlock()
if q.shuttingDown {
return
}
if q.dirty.has(item) {
return
}
q.metrics.add(item)
q.dirty.insert(item)
if q.processing.has(item) {
return
}
q.queue = append(q.queue, item)
q.cond.Signal()
}
// Len returns the current queue length, for informational purposes only. You
// shouldn't e.g. gate a call to Add() or Get() on Len() being a particular
// value, that can't be synchronized properly.
func (q *Type) Len() int {
q.cond.L.Lock()
defer q.cond.L.Unlock()
return len(q.queue)
}
// Get blocks until it can return an item to be processed. If shutdown = true,
// the caller should end their goroutine. You must call Done with item when you
// have finished processing it.
func (q *Type) Get() (item interface{}, shutdown bool) {
q.cond.L.Lock()
defer q.cond.L.Unlock()
for len(q.queue) == 0 && !q.shuttingDown {
q.cond.Wait()
}
if len(q.queue) == 0 {
// We must be shutting down.
return nil, true
}
item, q.queue = q.queue[0], q.queue[1:]
q.metrics.get(item)
q.processing.insert(item)
q.dirty.delete(item)
return item, false
}
// Done marks item as done processing, and if it has been marked as dirty again
// while it was being processed, it will be re-added to the queue for
// re-processing.
func (q *Type) Done(item interface{}) {
q.cond.L.Lock()
defer q.cond.L.Unlock()
q.metrics.done(item)
q.processing.delete(item)
if q.dirty.has(item) {
q.queue = append(q.queue, item)
q.cond.Signal()
}
}
// ShutDown will cause q to ignore all new items added to it. As soon as the
// worker goroutines have drained the existing items in the queue, they will be
// instructed to exit.
func (q *Type) ShutDown() {
q.cond.L.Lock()
defer q.cond.L.Unlock()
q.shuttingDown = true
q.cond.Broadcast()
}
func (q *Type) ShuttingDown() bool {
q.cond.L.Lock()
defer q.cond.L.Unlock()
return q.shuttingDown
}

View File

@ -1,161 +0,0 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package workqueue_test
import (
"sync"
"testing"
"time"
"k8s.io/client-go/util/workqueue"
)
func TestBasic(t *testing.T) {
// If something is seriously wrong this test will never complete.
q := workqueue.New()
// Start producers
const producers = 50
producerWG := sync.WaitGroup{}
producerWG.Add(producers)
for i := 0; i < producers; i++ {
go func(i int) {
defer producerWG.Done()
for j := 0; j < 50; j++ {
q.Add(i)
time.Sleep(time.Millisecond)
}
}(i)
}
// Start consumers
const consumers = 10
consumerWG := sync.WaitGroup{}
consumerWG.Add(consumers)
for i := 0; i < consumers; i++ {
go func(i int) {
defer consumerWG.Done()
for {
item, quit := q.Get()
if item == "added after shutdown!" {
t.Errorf("Got an item added after shutdown.")
}
if quit {
return
}
t.Logf("Worker %v: begin processing %v", i, item)
time.Sleep(3 * time.Millisecond)
t.Logf("Worker %v: done processing %v", i, item)
q.Done(item)
}
}(i)
}
producerWG.Wait()
q.ShutDown()
q.Add("added after shutdown!")
consumerWG.Wait()
}
func TestAddWhileProcessing(t *testing.T) {
q := workqueue.New()
// Start producers
const producers = 50
producerWG := sync.WaitGroup{}
producerWG.Add(producers)
for i := 0; i < producers; i++ {
go func(i int) {
defer producerWG.Done()
q.Add(i)
}(i)
}
// Start consumers
const consumers = 10
consumerWG := sync.WaitGroup{}
consumerWG.Add(consumers)
for i := 0; i < consumers; i++ {
go func(i int) {
defer consumerWG.Done()
// Every worker will re-add every item up to two times.
// This tests the dirty-while-processing case.
counters := map[interface{}]int{}
for {
item, quit := q.Get()
if quit {
return
}
counters[item]++
if counters[item] < 2 {
q.Add(item)
}
q.Done(item)
}
}(i)
}
producerWG.Wait()
q.ShutDown()
consumerWG.Wait()
}
func TestLen(t *testing.T) {
q := workqueue.New()
q.Add("foo")
if e, a := 1, q.Len(); e != a {
t.Errorf("Expected %v, got %v", e, a)
}
q.Add("bar")
if e, a := 2, q.Len(); e != a {
t.Errorf("Expected %v, got %v", e, a)
}
q.Add("foo") // should not increase the queue length.
if e, a := 2, q.Len(); e != a {
t.Errorf("Expected %v, got %v", e, a)
}
}
func TestReinsert(t *testing.T) {
q := workqueue.New()
q.Add("foo")
// Start processing
i, _ := q.Get()
if i != "foo" {
t.Errorf("Expected %v, got %v", "foo", i)
}
// Add it back while processing
q.Add(i)
// Finish it up
q.Done(i)
// It should be back on the queue
i, _ = q.Get()
if i != "foo" {
t.Errorf("Expected %v, got %v", "foo", i)
}
// Finish that one up
q.Done(i)
if a := q.Len(); a != 0 {
t.Errorf("Expected queue to be empty. Has %v items", a)
}
}

View File

@ -1,69 +0,0 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package workqueue
// RateLimitingInterface is an interface that rate limits items being added to the queue.
type RateLimitingInterface interface {
DelayingInterface
// AddRateLimited adds an item to the workqueue after the rate limiter says its ok
AddRateLimited(item interface{})
// Forget indicates that an item is finished being retried. Doesn't matter whether its for perm failing
// or for success, we'll stop the rate limiter from tracking it. This only clears the `rateLimiter`, you
// still have to call `Done` on the queue.
Forget(item interface{})
// NumRequeues returns back how many times the item was requeued
NumRequeues(item interface{}) int
}
// NewRateLimitingQueue constructs a new workqueue with rateLimited queuing ability
// Remember to call Forget! If you don't, you may end up tracking failures forever.
func NewRateLimitingQueue(rateLimiter RateLimiter) RateLimitingInterface {
return &rateLimitingType{
DelayingInterface: NewDelayingQueue(),
rateLimiter: rateLimiter,
}
}
func NewNamedRateLimitingQueue(rateLimiter RateLimiter, name string) RateLimitingInterface {
return &rateLimitingType{
DelayingInterface: NewNamedDelayingQueue(name),
rateLimiter: rateLimiter,
}
}
// rateLimitingType wraps an Interface and provides rateLimited re-enquing
type rateLimitingType struct {
DelayingInterface
rateLimiter RateLimiter
}
// AddRateLimited AddAfter's the item based on the time when the rate limiter says its ok
func (q *rateLimitingType) AddRateLimited(item interface{}) {
q.DelayingInterface.AddAfter(item, q.rateLimiter.When(item))
}
func (q *rateLimitingType) NumRequeues(item interface{}) int {
return q.rateLimiter.NumRequeues(item)
}
func (q *rateLimitingType) Forget(item interface{}) {
q.rateLimiter.Forget(item)
}

View File

@ -1,75 +0,0 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package workqueue
import (
"testing"
"time"
"k8s.io/apimachinery/pkg/util/clock"
)
func TestRateLimitingQueue(t *testing.T) {
limiter := NewItemExponentialFailureRateLimiter(1*time.Millisecond, 1*time.Second)
queue := NewRateLimitingQueue(limiter).(*rateLimitingType)
fakeClock := clock.NewFakeClock(time.Now())
delayingQueue := &delayingType{
Interface: New(),
clock: fakeClock,
heartbeat: fakeClock.Tick(maxWait),
stopCh: make(chan struct{}),
waitingForAddCh: make(chan *waitFor, 1000),
metrics: newRetryMetrics(""),
}
queue.DelayingInterface = delayingQueue
queue.AddRateLimited("one")
waitEntry := <-delayingQueue.waitingForAddCh
if e, a := 1*time.Millisecond, waitEntry.readyAt.Sub(fakeClock.Now()); e != a {
t.Errorf("expected %v, got %v", e, a)
}
queue.AddRateLimited("one")
waitEntry = <-delayingQueue.waitingForAddCh
if e, a := 2*time.Millisecond, waitEntry.readyAt.Sub(fakeClock.Now()); e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := 2, queue.NumRequeues("one"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
queue.AddRateLimited("two")
waitEntry = <-delayingQueue.waitingForAddCh
if e, a := 1*time.Millisecond, waitEntry.readyAt.Sub(fakeClock.Now()); e != a {
t.Errorf("expected %v, got %v", e, a)
}
queue.AddRateLimited("two")
waitEntry = <-delayingQueue.waitingForAddCh
if e, a := 2*time.Millisecond, waitEntry.readyAt.Sub(fakeClock.Now()); e != a {
t.Errorf("expected %v, got %v", e, a)
}
queue.Forget("one")
if e, a := 0, queue.NumRequeues("one"); e != a {
t.Errorf("expected %v, got %v", e, a)
}
queue.AddRateLimited("one")
waitEntry = <-delayingQueue.waitingForAddCh
if e, a := 1*time.Millisecond, waitEntry.readyAt.Sub(fakeClock.Now()); e != a {
t.Errorf("expected %v, got %v", e, a)
}
}