reconcile merge

Signed-off-by: Huamin Chen <hchen@redhat.com>
This commit is contained in:
Huamin Chen
2019-01-15 16:20:41 +00:00
parent 85b8415024
commit e46099a504
2425 changed files with 271763 additions and 40453 deletions

View File

@ -1,48 +0,0 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
"go_test",
)
go_test(
name = "go_default_test",
srcs = [
"cache_test.go",
"round_trippers_test.go",
"transport_test.go",
],
embed = [":go_default_library"],
)
go_library(
name = "go_default_library",
srcs = [
"cache.go",
"config.go",
"round_trippers.go",
"transport.go",
],
importpath = "k8s.io/client-go/transport",
deps = [
"//vendor/github.com/golang/glog:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/util/net:go_default_library",
],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [
":package-srcs",
"//staging/src/k8s.io/client-go/transport/spdy:all-srcs",
],
tags = ["automanaged"],
)

View File

@ -43,7 +43,9 @@ type tlsCacheKey struct {
caData string
certData string
keyData string
getCert string
serverName string
dial string
}
func (t tlsCacheKey) String() string {
@ -51,7 +53,7 @@ func (t tlsCacheKey) String() string {
if len(t.keyData) > 0 {
keyText = "<redacted>"
}
return fmt.Sprintf("insecure:%v, caData:%#v, certData:%#v, keyData:%s, serverName:%s", t.insecure, t.caData, t.certData, keyText, t.serverName)
return fmt.Sprintf("insecure:%v, caData:%#v, certData:%#v, keyData:%s, getCert: %s, serverName:%s, dial:%s", t.insecure, t.caData, t.certData, keyText, t.getCert, t.serverName, t.dial)
}
func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
@ -75,7 +77,7 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
return nil, err
}
// The options didn't require a custom TLS config
if tlsConfig == nil {
if tlsConfig == nil && config.Dial == nil {
return http.DefaultTransport, nil
}
@ -84,7 +86,7 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
dial = (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial
}).DialContext
}
// Cache a single transport for these options
c.transports[key] = utilnet.SetTransportDefaults(&http.Transport{
@ -92,7 +94,7 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: tlsConfig,
MaxIdleConnsPerHost: idleConnsPerHost,
Dial: dial,
DialContext: dial,
})
return c.transports[key], nil
}
@ -108,6 +110,8 @@ func tlsConfigKey(c *Config) (tlsCacheKey, error) {
caData: string(c.TLS.CAData),
certData: string(c.TLS.CertData),
keyData: string(c.TLS.KeyData),
getCert: fmt.Sprintf("%p", c.TLS.GetCert),
serverName: c.TLS.ServerName,
dial: fmt.Sprintf("%p", c.Dial),
}, nil
}

View File

@ -17,6 +17,9 @@ limitations under the License.
package transport
import (
"context"
"crypto/tls"
"net"
"net/http"
"testing"
)
@ -51,8 +54,12 @@ func TestTLSConfigKey(t *testing.T) {
}
// Make sure config fields that affect the tls config affect the cache key
dialer := net.Dialer{}
getCert := func() (*tls.Certificate, error) { return nil, nil }
uniqueConfigurations := map[string]*Config{
"no tls": {},
"dialer": {Dial: dialer.DialContext},
"dialer2": {Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }},
"insecure": {TLS: TLSConfig{Insecure: true}},
"cadata 1": {TLS: TLSConfig{CAData: []byte{1}}},
"cadata 2": {TLS: TLSConfig{CAData: []byte{2}}},
@ -101,14 +108,27 @@ func TestTLSConfigKey(t *testing.T) {
KeyData: []byte{1},
},
},
"getCert1": {
TLS: TLSConfig{
KeyData: []byte{1},
GetCert: getCert,
},
},
"getCert2": {
TLS: TLSConfig{
KeyData: []byte{1},
GetCert: func() (*tls.Certificate, error) { return nil, nil },
},
},
"getCert1, key 2": {
TLS: TLSConfig{
KeyData: []byte{2},
GetCert: getCert,
},
},
}
for nameA, valueA := range uniqueConfigurations {
for nameB, valueB := range uniqueConfigurations {
// Don't compare to ourselves
if nameA == nameB {
continue
}
keyA, err := tlsConfigKey(valueA)
if err != nil {
t.Errorf("Unexpected error for %q: %v", nameA, err)
@ -119,6 +139,15 @@ func TestTLSConfigKey(t *testing.T) {
t.Errorf("Unexpected error for %q: %v", nameB, err)
continue
}
// Make sure we get the same key on the same config
if nameA == nameB {
if keyA != keyB {
t.Errorf("Expected identical cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB)
}
continue
}
if keyA == keyB {
t.Errorf("Expected unique cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB)
continue

View File

@ -17,6 +17,8 @@ limitations under the License.
package transport
import (
"context"
"crypto/tls"
"net"
"net/http"
)
@ -53,7 +55,7 @@ type Config struct {
WrapTransport func(rt http.RoundTripper) http.RoundTripper
// Dial specifies the dial function for creating unencrypted TCP connections.
Dial func(network, addr string) (net.Conn, error)
Dial func(ctx context.Context, network, address string) (net.Conn, error)
}
// ImpersonationConfig has all the available impersonation options
@ -83,7 +85,12 @@ func (c *Config) HasTokenAuth() bool {
// HasCertAuth returns whether the configuration has certificate authentication or not.
func (c *Config) HasCertAuth() bool {
return len(c.TLS.CertData) != 0 || len(c.TLS.CertFile) != 0
return (len(c.TLS.CertData) != 0 || len(c.TLS.CertFile) != 0) && (len(c.TLS.KeyData) != 0 || len(c.TLS.KeyFile) != 0)
}
// HasCertCallbacks returns whether the configuration has certificate callback or not.
func (c *Config) HasCertCallback() bool {
return c.TLS.GetCert != nil
}
// TLSConfig holds the information needed to set up a TLS transport.
@ -98,4 +105,6 @@ type TLSConfig struct {
CAData []byte // Bytes of the PEM-encoded server trusted root certificates. Supercedes CAFile.
CertData []byte // Bytes of the PEM-encoded client certificate. Supercedes CertFile.
KeyData []byte // Bytes of the PEM-encoded client key. Supercedes KeyFile.
GetCert func() (*tls.Certificate, error) // Callback that returns a TLS client certificate. CertData, CertFile, KeyData and KeyFile supercede this field.
}

View File

@ -22,7 +22,7 @@ import (
"strings"
"time"
"github.com/golang/glog"
"k8s.io/klog"
utilnet "k8s.io/apimachinery/pkg/util/net"
)
@ -62,13 +62,13 @@ func HTTPWrappersForConfig(config *Config, rt http.RoundTripper) (http.RoundTrip
// DebugWrappers wraps a round tripper and logs based on the current log level.
func DebugWrappers(rt http.RoundTripper) http.RoundTripper {
switch {
case bool(glog.V(9)):
case bool(klog.V(9)):
rt = newDebuggingRoundTripper(rt, debugCurlCommand, debugURLTiming, debugResponseHeaders)
case bool(glog.V(8)):
case bool(klog.V(8)):
rt = newDebuggingRoundTripper(rt, debugJustURL, debugRequestHeaders, debugResponseStatus, debugResponseHeaders)
case bool(glog.V(7)):
case bool(klog.V(7)):
rt = newDebuggingRoundTripper(rt, debugJustURL, debugRequestHeaders, debugResponseStatus)
case bool(glog.V(6)):
case bool(klog.V(6)):
rt = newDebuggingRoundTripper(rt, debugURLTiming)
}
@ -129,7 +129,7 @@ func SetAuthProxyHeaders(req *http.Request, username string, groups []string, ex
}
for key, values := range extra {
for _, value := range values {
req.Header.Add("X-Remote-Extra-"+key, value)
req.Header.Add("X-Remote-Extra-"+headerKeyEscape(key), value)
}
}
}
@ -138,7 +138,7 @@ func (rt *authProxyRoundTripper) CancelRequest(req *http.Request) {
if canceler, ok := rt.rt.(requestCanceler); ok {
canceler.CancelRequest(req)
} else {
glog.Errorf("CancelRequest not implemented")
klog.Errorf("CancelRequest not implemented by %T", rt.rt)
}
}
@ -166,7 +166,7 @@ func (rt *userAgentRoundTripper) CancelRequest(req *http.Request) {
if canceler, ok := rt.rt.(requestCanceler); ok {
canceler.CancelRequest(req)
} else {
glog.Errorf("CancelRequest not implemented")
klog.Errorf("CancelRequest not implemented by %T", rt.rt)
}
}
@ -197,7 +197,7 @@ func (rt *basicAuthRoundTripper) CancelRequest(req *http.Request) {
if canceler, ok := rt.rt.(requestCanceler); ok {
canceler.CancelRequest(req)
} else {
glog.Errorf("CancelRequest not implemented")
klog.Errorf("CancelRequest not implemented by %T", rt.rt)
}
}
@ -246,7 +246,7 @@ func (rt *impersonatingRoundTripper) RoundTrip(req *http.Request) (*http.Respons
}
for k, vv := range rt.impersonate.Extra {
for _, v := range vv {
req.Header.Add(ImpersonateUserExtraHeaderPrefix+k, v)
req.Header.Add(ImpersonateUserExtraHeaderPrefix+headerKeyEscape(k), v)
}
}
@ -257,7 +257,7 @@ func (rt *impersonatingRoundTripper) CancelRequest(req *http.Request) {
if canceler, ok := rt.delegate.(requestCanceler); ok {
canceler.CancelRequest(req)
} else {
glog.Errorf("CancelRequest not implemented")
klog.Errorf("CancelRequest not implemented by %T", rt.delegate)
}
}
@ -288,7 +288,7 @@ func (rt *bearerAuthRoundTripper) CancelRequest(req *http.Request) {
if canceler, ok := rt.rt.(requestCanceler); ok {
canceler.CancelRequest(req)
} else {
glog.Errorf("CancelRequest not implemented")
klog.Errorf("CancelRequest not implemented by %T", rt.rt)
}
}
@ -335,7 +335,7 @@ func (r *requestInfo) toCurl() string {
}
}
return fmt.Sprintf("curl -k -v -X%s %s %s", r.RequestVerb, headers, r.RequestURL)
return fmt.Sprintf("curl -k -v -X%s %s '%s'", r.RequestVerb, headers, r.RequestURL)
}
// debuggingRoundTripper will display information about the requests passing
@ -372,7 +372,7 @@ func (rt *debuggingRoundTripper) CancelRequest(req *http.Request) {
if canceler, ok := rt.delegatedRoundTripper.(requestCanceler); ok {
canceler.CancelRequest(req)
} else {
glog.Errorf("CancelRequest not implemented")
klog.Errorf("CancelRequest not implemented by %T", rt.delegatedRoundTripper)
}
}
@ -380,17 +380,17 @@ func (rt *debuggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, e
reqInfo := newRequestInfo(req)
if rt.levels[debugJustURL] {
glog.Infof("%s %s", reqInfo.RequestVerb, reqInfo.RequestURL)
klog.Infof("%s %s", reqInfo.RequestVerb, reqInfo.RequestURL)
}
if rt.levels[debugCurlCommand] {
glog.Infof("%s", reqInfo.toCurl())
klog.Infof("%s", reqInfo.toCurl())
}
if rt.levels[debugRequestHeaders] {
glog.Infof("Request Headers:")
klog.Infof("Request Headers:")
for key, values := range reqInfo.RequestHeaders {
for _, value := range values {
glog.Infof(" %s: %s", key, value)
klog.Infof(" %s: %s", key, value)
}
}
}
@ -402,16 +402,16 @@ func (rt *debuggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, e
reqInfo.complete(response, err)
if rt.levels[debugURLTiming] {
glog.Infof("%s %s %s in %d milliseconds", reqInfo.RequestVerb, reqInfo.RequestURL, reqInfo.ResponseStatus, reqInfo.Duration.Nanoseconds()/int64(time.Millisecond))
klog.Infof("%s %s %s in %d milliseconds", reqInfo.RequestVerb, reqInfo.RequestURL, reqInfo.ResponseStatus, reqInfo.Duration.Nanoseconds()/int64(time.Millisecond))
}
if rt.levels[debugResponseStatus] {
glog.Infof("Response Status: %s in %d milliseconds", reqInfo.ResponseStatus, reqInfo.Duration.Nanoseconds()/int64(time.Millisecond))
klog.Infof("Response Status: %s in %d milliseconds", reqInfo.ResponseStatus, reqInfo.Duration.Nanoseconds()/int64(time.Millisecond))
}
if rt.levels[debugResponseHeaders] {
glog.Infof("Response Headers:")
klog.Infof("Response Headers:")
for key, values := range reqInfo.ResponseHeaders {
for _, value := range values {
glog.Infof(" %s: %s", key, value)
klog.Infof(" %s: %s", key, value)
}
}
}
@ -422,3 +422,110 @@ func (rt *debuggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, e
func (rt *debuggingRoundTripper) WrappedRoundTripper() http.RoundTripper {
return rt.delegatedRoundTripper
}
func legalHeaderByte(b byte) bool {
return int(b) < len(legalHeaderKeyBytes) && legalHeaderKeyBytes[b]
}
func shouldEscape(b byte) bool {
// url.PathUnescape() returns an error if any '%' is not followed by two
// hexadecimal digits, so we'll intentionally encode it.
return !legalHeaderByte(b) || b == '%'
}
func headerKeyEscape(key string) string {
buf := strings.Builder{}
for i := 0; i < len(key); i++ {
b := key[i]
if shouldEscape(b) {
// %-encode bytes that should be escaped:
// https://tools.ietf.org/html/rfc3986#section-2.1
fmt.Fprintf(&buf, "%%%02X", b)
continue
}
buf.WriteByte(b)
}
return buf.String()
}
// legalHeaderKeyBytes was copied from net/http/lex.go's isTokenTable.
// See https://httpwg.github.io/specs/rfc7230.html#rule.token.separators
var legalHeaderKeyBytes = [127]bool{
'%': true,
'!': true,
'#': true,
'$': true,
'&': true,
'\'': true,
'*': true,
'+': true,
'-': true,
'.': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'W': true,
'V': true,
'X': true,
'Y': true,
'Z': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'|': true,
'~': true,
}

View File

@ -18,6 +18,7 @@ package transport
import (
"net/http"
"net/url"
"reflect"
"strings"
"testing"
@ -125,6 +126,32 @@ func TestImpersonationRoundTripper(t *testing.T) {
ImpersonateUserExtraHeaderPrefix + "Second": {"B", "b"},
},
},
{
name: "escape handling",
impersonationConfig: ImpersonationConfig{
UserName: "user",
Extra: map[string][]string{
"test.example.com/thing.thing": {"A", "a"},
},
},
expected: map[string][]string{
ImpersonateUserHeader: {"user"},
ImpersonateUserExtraHeaderPrefix + `Test.example.com%2fthing.thing`: {"A", "a"},
},
},
{
name: "double escape handling",
impersonationConfig: ImpersonationConfig{
UserName: "user",
Extra: map[string][]string{
"test.example.com/thing.thing%20another.thing": {"A", "a"},
},
},
expected: map[string][]string{
ImpersonateUserHeader: {"user"},
ImpersonateUserExtraHeaderPrefix + `Test.example.com%2fthing.thing%2520another.thing`: {"A", "a"},
},
},
}
for _, tc := range tcs {
@ -159,9 +186,10 @@ func TestImpersonationRoundTripper(t *testing.T) {
func TestAuthProxyRoundTripper(t *testing.T) {
for n, tc := range map[string]struct {
username string
groups []string
extra map[string][]string
username string
groups []string
extra map[string][]string
expectedExtra map[string][]string
}{
"allfields": {
username: "user",
@ -170,6 +198,34 @@ func TestAuthProxyRoundTripper(t *testing.T) {
"one": {"alpha", "bravo"},
"two": {"charlie", "delta"},
},
expectedExtra: map[string][]string{
"one": {"alpha", "bravo"},
"two": {"charlie", "delta"},
},
},
"escaped extra": {
username: "user",
groups: []string{"groupA", "groupB"},
extra: map[string][]string{
"one": {"alpha", "bravo"},
"example.com/two": {"charlie", "delta"},
},
expectedExtra: map[string][]string{
"one": {"alpha", "bravo"},
"example.com%2ftwo": {"charlie", "delta"},
},
},
"double escaped extra": {
username: "user",
groups: []string{"groupA", "groupB"},
extra: map[string][]string{
"one": {"alpha", "bravo"},
"example.com/two%20three": {"charlie", "delta"},
},
expectedExtra: map[string][]string{
"one": {"alpha", "bravo"},
"example.com%2ftwo%2520three": {"charlie", "delta"},
},
},
} {
rt := &testRoundTripper{}
@ -210,9 +266,64 @@ func TestAuthProxyRoundTripper(t *testing.T) {
actualExtra[extraKey] = append(actualExtra[key], values...)
}
}
if e, a := tc.extra, actualExtra; !reflect.DeepEqual(e, a) {
if e, a := tc.expectedExtra, actualExtra; !reflect.DeepEqual(e, a) {
t.Errorf("%s expected %v, got %v", n, e, a)
continue
}
}
}
// TestHeaderEscapeRoundTrip tests to see if foo == url.PathUnescape(headerEscape(foo))
// This behavior is important for client -> API server transmission of extra values.
func TestHeaderEscapeRoundTrip(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
key string
}{
{
name: "alpha",
key: "alphabetical",
},
{
name: "alphanumeric",
key: "alph4num3r1c",
},
{
name: "percent encoded",
key: "percent%20encoded",
},
{
name: "almost percent encoded",
key: "almost%zzpercent%xxencoded",
},
{
name: "illegal char & percent encoding",
key: "example.com/percent%20encoded",
},
{
name: "weird unicode stuff",
key: "example.com/ᛒᚥᛏᛖᚥᚢとロビン",
},
{
name: "header legal chars",
key: "abc123!#$+.-_*\\^`~|'",
},
{
name: "legal path, illegal header",
key: "@=:",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
escaped := headerKeyEscape(tc.key)
unescaped, err := url.PathUnescape(escaped)
if err != nil {
t.Fatalf("url.PathUnescape(%q) returned error: %v", escaped, err)
}
if tc.key != unescaped {
t.Errorf("url.PathUnescape(headerKeyEscape(%q)) returned %q, wanted %q", tc.key, unescaped, tc.key)
}
})
}
}

View File

@ -1,30 +0,0 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
)
go_library(
name = "go_default_library",
srcs = ["spdy.go"],
importpath = "k8s.io/client-go/transport/spdy",
deps = [
"//vendor/k8s.io/apimachinery/pkg/util/httpstream:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/util/httpstream/spdy:go_default_library",
"//vendor/k8s.io/client-go/rest:go_default_library",
],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
)

View File

@ -38,7 +38,7 @@ func RoundTripperFor(config *restclient.Config) (http.RoundTripper, Upgrader, er
if err != nil {
return nil, nil, err
}
upgradeRoundTripper := spdy.NewRoundTripper(tlsConfig, true)
upgradeRoundTripper := spdy.NewRoundTripper(tlsConfig, true, false)
wrapper, err := restclient.HTTPWrappersForConfig(config, upgradeRoundTripper)
if err != nil {
return nil, nil, err

View File

@ -28,7 +28,7 @@ import (
// or transport level security defined by the provided Config.
func New(config *Config) (http.RoundTripper, error) {
// Set transport level security
if config.Transport != nil && (config.HasCA() || config.HasCertAuth() || config.TLS.Insecure) {
if config.Transport != nil && (config.HasCA() || config.HasCertAuth() || config.HasCertCallback() || config.TLS.Insecure) {
return nil, fmt.Errorf("using a custom transport with TLS certificate options or the insecure flag is not allowed")
}
@ -52,7 +52,7 @@ func New(config *Config) (http.RoundTripper, error) {
// TLSConfigFor returns a tls.Config that will provide the transport level security defined
// by the provided Config. Will return nil if no transport level security is requested.
func TLSConfigFor(c *Config) (*tls.Config, error) {
if !(c.HasCA() || c.HasCertAuth() || c.TLS.Insecure) {
if !(c.HasCA() || c.HasCertAuth() || c.HasCertCallback() || c.TLS.Insecure || len(c.TLS.ServerName) > 0) {
return nil, nil
}
if c.HasCA() && c.TLS.Insecure {
@ -75,12 +75,40 @@ func TLSConfigFor(c *Config) (*tls.Config, error) {
tlsConfig.RootCAs = rootCertPool(c.TLS.CAData)
}
var staticCert *tls.Certificate
if c.HasCertAuth() {
// If key/cert were provided, verify them before setting up
// tlsConfig.GetClientCertificate.
cert, err := tls.X509KeyPair(c.TLS.CertData, c.TLS.KeyData)
if err != nil {
return nil, err
}
tlsConfig.Certificates = []tls.Certificate{cert}
staticCert = &cert
}
if c.HasCertAuth() || c.HasCertCallback() {
tlsConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
// Note: static key/cert data always take precedence over cert
// callback.
if staticCert != nil {
return staticCert, nil
}
if c.HasCertCallback() {
cert, err := c.TLS.GetCert()
if err != nil {
return nil, err
}
// GetCert may return empty value, meaning no cert.
if cert != nil {
return cert, nil
}
}
// Both c.TLS.CertData/KeyData were unset and GetCert didn't return
// anything. Return an empty tls.Certificate, no client cert will
// be sent to the server.
return &tls.Certificate{}, nil
}
}
return tlsConfig, nil

View File

@ -17,6 +17,8 @@ limitations under the License.
package transport
import (
"crypto/tls"
"errors"
"net/http"
"testing"
)
@ -91,16 +93,37 @@ stR0Yiw0buV6DL/moUO0HIM9Bjh96HJp+LxiIS6UCdIhMPp5HoQa
func TestNew(t *testing.T) {
testCases := map[string]struct {
Config *Config
Err bool
TLS bool
Default bool
Config *Config
Err bool
TLS bool
TLSCert bool
TLSErr bool
Default bool
Insecure bool
DefaultRoots bool
}{
"default transport": {
Default: true,
Config: &Config{},
},
"insecure": {
TLS: true,
Insecure: true,
DefaultRoots: true,
Config: &Config{TLS: TLSConfig{
Insecure: true,
}},
},
"server name": {
TLS: true,
DefaultRoots: true,
Config: &Config{TLS: TLSConfig{
ServerName: "foo",
}},
},
"ca transport": {
TLS: true,
Config: &Config{
@ -128,7 +151,8 @@ func TestNew(t *testing.T) {
},
"cert transport": {
TLS: true,
TLS: true,
TLSCert: true,
Config: &Config{
TLS: TLSConfig{
CAData: []byte(rootCACert),
@ -158,7 +182,8 @@ func TestNew(t *testing.T) {
},
},
"key data overriding bad file cert transport": {
TLS: true,
TLS: true,
TLSCert: true,
Config: &Config{
TLS: TLSConfig{
CAData: []byte(rootCACert),
@ -168,37 +193,120 @@ func TestNew(t *testing.T) {
},
},
},
"callback cert and key": {
TLS: true,
TLSCert: true,
Config: &Config{
TLS: TLSConfig{
CAData: []byte(rootCACert),
GetCert: func() (*tls.Certificate, error) {
crt, err := tls.X509KeyPair([]byte(certData), []byte(keyData))
return &crt, err
},
},
},
},
"cert callback error": {
TLS: true,
TLSCert: true,
TLSErr: true,
Config: &Config{
TLS: TLSConfig{
CAData: []byte(rootCACert),
GetCert: func() (*tls.Certificate, error) {
return nil, errors.New("GetCert failure")
},
},
},
},
"cert data overrides empty callback result": {
TLS: true,
TLSCert: true,
Config: &Config{
TLS: TLSConfig{
CAData: []byte(rootCACert),
GetCert: func() (*tls.Certificate, error) {
return nil, nil
},
CertData: []byte(certData),
KeyData: []byte(keyData),
},
},
},
"callback returns nothing": {
TLS: true,
TLSCert: true,
Config: &Config{
TLS: TLSConfig{
CAData: []byte(rootCACert),
GetCert: func() (*tls.Certificate, error) {
return nil, nil
},
},
},
},
}
for k, testCase := range testCases {
transport, err := New(testCase.Config)
switch {
case testCase.Err && err == nil:
t.Errorf("%s: unexpected non-error", k)
continue
case !testCase.Err && err != nil:
t.Errorf("%s: unexpected error: %v", k, err)
continue
}
t.Run(k, func(t *testing.T) {
rt, err := New(testCase.Config)
switch {
case testCase.Err && err == nil:
t.Fatal("unexpected non-error")
case !testCase.Err && err != nil:
t.Fatalf("unexpected error: %v", err)
}
if testCase.Err {
return
}
switch {
case testCase.Default && transport != http.DefaultTransport:
t.Errorf("%s: expected the default transport, got %#v", k, transport)
continue
case !testCase.Default && transport == http.DefaultTransport:
t.Errorf("%s: expected non-default transport, got %#v", k, transport)
continue
}
switch {
case testCase.Default && rt != http.DefaultTransport:
t.Fatalf("got %#v, expected the default transport", rt)
case !testCase.Default && rt == http.DefaultTransport:
t.Fatalf("got %#v, expected non-default transport", rt)
}
// We only know how to check TLSConfig on http.Transports
if transport, ok := transport.(*http.Transport); ok {
// We only know how to check TLSConfig on http.Transports
transport := rt.(*http.Transport)
switch {
case testCase.TLS && transport.TLSClientConfig == nil:
t.Errorf("%s: expected TLSClientConfig, got %#v", k, transport)
continue
t.Fatalf("got %#v, expected TLSClientConfig", transport)
case !testCase.TLS && transport.TLSClientConfig != nil:
t.Errorf("%s: expected no TLSClientConfig, got %#v", k, transport)
continue
t.Fatalf("got %#v, expected no TLSClientConfig", transport)
}
}
if !testCase.TLS {
return
}
switch {
case testCase.DefaultRoots && transport.TLSClientConfig.RootCAs != nil:
t.Fatalf("got %#v, expected nil root CAs", transport.TLSClientConfig.RootCAs)
case !testCase.DefaultRoots && transport.TLSClientConfig.RootCAs == nil:
t.Fatalf("got %#v, expected non-nil root CAs", transport.TLSClientConfig.RootCAs)
}
switch {
case testCase.Insecure != transport.TLSClientConfig.InsecureSkipVerify:
t.Fatalf("got %#v, expected %#v", transport.TLSClientConfig.InsecureSkipVerify, testCase.Insecure)
}
switch {
case testCase.TLSCert && transport.TLSClientConfig.GetClientCertificate == nil:
t.Fatalf("got %#v, expected TLSClientConfig.GetClientCertificate", transport.TLSClientConfig)
case !testCase.TLSCert && transport.TLSClientConfig.GetClientCertificate != nil:
t.Fatalf("got %#v, expected no TLSClientConfig.GetClientCertificate", transport.TLSClientConfig)
}
if !testCase.TLSCert {
return
}
_, err = transport.TLSClientConfig.GetClientCertificate(nil)
switch {
case testCase.TLSErr && err == nil:
t.Error("got nil error from GetClientCertificate, expected non-nil")
case !testCase.TLSErr && err != nil:
t.Errorf("got error from GetClientCertificate: %q, expected nil", err)
}
})
}
}