Fresh dep ensure

This commit is contained in:
Mike Cronce
2018-11-26 13:23:56 -05:00
parent 93cb8a04d7
commit 407478ab9a
9016 changed files with 551394 additions and 279685 deletions

View File

@ -1,50 +0,0 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
"go_test",
)
go_test(
name = "go_default_test",
srcs = [
"http_test.go",
"interface_test.go",
"port_range_test.go",
"port_split_test.go",
"util_test.go",
],
embed = [":go_default_library"],
deps = ["//vendor/github.com/spf13/pflag:go_default_library"],
)
go_library(
name = "go_default_library",
srcs = [
"http.go",
"interface.go",
"port_range.go",
"port_split.go",
"util.go",
],
importpath = "k8s.io/apimachinery/pkg/util/net",
deps = [
"//vendor/github.com/golang/glog:go_default_library",
"//vendor/golang.org/x/net/http2:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/util/sets:go_default_library",
],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
)

View File

@ -19,6 +19,7 @@ package net
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"fmt"
"io"
@ -30,8 +31,8 @@ import (
"strconv"
"strings"
"github.com/golang/glog"
"golang.org/x/net/http2"
"k8s.io/klog"
)
// JoinPreservingTrailingSlash does a path.Join of the specified elements,
@ -61,6 +62,9 @@ func JoinPreservingTrailingSlash(elem ...string) string {
// differentiate probable errors in connection behavior between normal "this is
// disconnected" should use the method.
func IsProbableEOF(err error) bool {
if err == nil {
return false
}
if uerr, ok := err.(*url.Error); ok {
err = uerr.Err
}
@ -87,8 +91,9 @@ func SetOldTransportDefaults(t *http.Transport) *http.Transport {
// ProxierWithNoProxyCIDR allows CIDR rules in NO_PROXY
t.Proxy = NewProxierWithNoProxyCIDR(http.ProxyFromEnvironment)
}
if t.Dial == nil {
t.Dial = defaultTransport.Dial
// If no custom dialer is set, use the default context dialer
if t.DialContext == nil && t.Dial == nil {
t.DialContext = defaultTransport.DialContext
}
if t.TLSHandshakeTimeout == 0 {
t.TLSHandshakeTimeout = defaultTransport.TLSHandshakeTimeout
@ -102,10 +107,10 @@ func SetTransportDefaults(t *http.Transport) *http.Transport {
t = SetOldTransportDefaults(t)
// Allow clients to disable http2 if needed.
if s := os.Getenv("DISABLE_HTTP2"); len(s) > 0 {
glog.Infof("HTTP2 has been explicitly disabled")
klog.Infof("HTTP2 has been explicitly disabled")
} else {
if err := http2.ConfigureTransport(t); err != nil {
glog.Warningf("Transport failed http2 configuration: %v", err)
klog.Warningf("Transport failed http2 configuration: %v", err)
}
}
return t
@ -116,7 +121,7 @@ type RoundTripperWrapper interface {
WrappedRoundTripper() http.RoundTripper
}
type DialFunc func(net, addr string) (net.Conn, error)
type DialFunc func(ctx context.Context, net, addr string) (net.Conn, error)
func DialerFor(transport http.RoundTripper) (DialFunc, error) {
if transport == nil {
@ -125,7 +130,18 @@ func DialerFor(transport http.RoundTripper) (DialFunc, error) {
switch transport := transport.(type) {
case *http.Transport:
return transport.Dial, nil
// transport.DialContext takes precedence over transport.Dial
if transport.DialContext != nil {
return transport.DialContext, nil
}
// adapt transport.Dial to the DialWithContext signature
if transport.Dial != nil {
return func(ctx context.Context, net, addr string) (net.Conn, error) {
return transport.Dial(net, addr)
}, nil
}
// otherwise return nil
return nil, nil
case RoundTripperWrapper:
return DialerFor(transport.WrappedRoundTripper())
default:
@ -163,10 +179,8 @@ func FormatURL(scheme string, host string, port int, path string) *url.URL {
}
func GetHTTPClient(req *http.Request) string {
if userAgent, ok := req.Header["User-Agent"]; ok {
if len(userAgent) > 0 {
return userAgent[0]
}
if ua := req.UserAgent(); len(ua) != 0 {
return ua
}
return "unknown"
}
@ -307,9 +321,10 @@ type Dialer interface {
// ConnectWithRedirects uses dialer to send req, following up to 10 redirects (relative to
// originalLocation). It returns the opened net.Conn and the raw response bytes.
func ConnectWithRedirects(originalMethod string, originalLocation *url.URL, header http.Header, originalBody io.Reader, dialer Dialer) (net.Conn, []byte, error) {
// If requireSameHostRedirects is true, only redirects to the same host are permitted.
func ConnectWithRedirects(originalMethod string, originalLocation *url.URL, header http.Header, originalBody io.Reader, dialer Dialer, requireSameHostRedirects bool) (net.Conn, []byte, error) {
const (
maxRedirects = 10
maxRedirects = 9 // Fail on the 10th redirect
maxResponseSize = 16384 // play it safe to allow the potential for lots of / large headers
)
@ -353,7 +368,7 @@ redirectLoop:
resp, err := http.ReadResponse(respReader, nil)
if err != nil {
// Unable to read the backend response; let the client handle it.
glog.Warningf("Error reading backend response: %v", err)
klog.Warningf("Error reading backend response: %v", err)
break redirectLoop
}
@ -373,10 +388,6 @@ redirectLoop:
resp.Body.Close() // not used
// Reset the connection.
intermediateConn.Close()
intermediateConn = nil
// Prepare to follow the redirect.
redirectStr := resp.Header.Get("Location")
if redirectStr == "" {
@ -390,6 +401,15 @@ redirectLoop:
if err != nil {
return nil, nil, fmt.Errorf("malformed Location header: %v", err)
}
// Only follow redirects to the same host. Otherwise, propagate the redirect response back.
if requireSameHostRedirects && location.Hostname() != originalLocation.Hostname() {
break redirectLoop
}
// Reset the connection.
intermediateConn.Close()
intermediateConn = nil
}
connToReturn := intermediateConn

View File

@ -19,14 +19,23 @@ limitations under the License.
package net
import (
"bufio"
"bytes"
"crypto/tls"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"reflect"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"k8s.io/apimachinery/pkg/util/wait"
)
func TestGetClientIP(t *testing.T) {
@ -280,3 +289,153 @@ func TestJoinPreservingTrailingSlash(t *testing.T) {
})
}
}
func TestConnectWithRedirects(t *testing.T) {
tests := []struct {
desc string
redirects []string
method string // initial request method, empty == GET
expectError bool
expectedRedirects int
newPort bool // special case different port test
}{{
desc: "relative redirects allowed",
redirects: []string{"/ok"},
expectedRedirects: 1,
}, {
desc: "redirects to the same host are allowed",
redirects: []string{"http://HOST/ok"}, // HOST replaced with server address in test
expectedRedirects: 1,
}, {
desc: "POST redirects to GET",
method: http.MethodPost,
redirects: []string{"/ok"},
expectedRedirects: 1,
}, {
desc: "PUT redirects to GET",
method: http.MethodPut,
redirects: []string{"/ok"},
expectedRedirects: 1,
}, {
desc: "DELETE redirects to GET",
method: http.MethodDelete,
redirects: []string{"/ok"},
expectedRedirects: 1,
}, {
desc: "9 redirects are allowed",
redirects: []string{"/1", "/2", "/3", "/4", "/5", "/6", "/7", "/8", "/9"},
expectedRedirects: 9,
}, {
desc: "10 redirects are forbidden",
redirects: []string{"/1", "/2", "/3", "/4", "/5", "/6", "/7", "/8", "/9", "/10"},
expectError: true,
}, {
desc: "redirect to different host are prevented",
redirects: []string{"http://example.com/foo"},
expectedRedirects: 0,
}, {
desc: "multiple redirect to different host forbidden",
redirects: []string{"/1", "/2", "/3", "http://example.com/foo"},
expectedRedirects: 3,
}, {
desc: "redirect to different port is allowed",
redirects: []string{"http://HOST/foo"},
expectedRedirects: 1,
newPort: true,
}}
const resultString = "Test output"
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
redirectCount := 0
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
// Verify redirect request.
if redirectCount > 0 {
expectedURL, err := url.Parse(test.redirects[redirectCount-1])
require.NoError(t, err, "test URL error")
assert.Equal(t, req.URL.Path, expectedURL.Path, "unknown redirect path")
assert.Equal(t, http.MethodGet, req.Method, "redirects must always be GET")
}
if redirectCount < len(test.redirects) {
http.Redirect(w, req, test.redirects[redirectCount], http.StatusFound)
redirectCount++
} else if redirectCount == len(test.redirects) {
w.Write([]byte(resultString))
} else {
t.Errorf("unexpected number of redirects %d to %s", redirectCount, req.URL.String())
}
}))
defer s.Close()
u, err := url.Parse(s.URL)
require.NoError(t, err, "Error parsing server URL")
host := u.Host
// Special case new-port test with a secondary server.
if test.newPort {
s2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte(resultString))
}))
defer s2.Close()
u2, err := url.Parse(s2.URL)
require.NoError(t, err, "Error parsing secondary server URL")
// Sanity check: secondary server uses same hostname, different port.
require.Equal(t, u.Hostname(), u2.Hostname(), "sanity check: same hostname")
require.NotEqual(t, u.Port(), u2.Port(), "sanity check: different port")
// Redirect to the secondary server.
host = u2.Host
}
// Update redirect URLs with actual host.
for i := range test.redirects {
test.redirects[i] = strings.Replace(test.redirects[i], "HOST", host, 1)
}
method := test.method
if method == "" {
method = http.MethodGet
}
netdialer := &net.Dialer{
Timeout: wait.ForeverTestTimeout,
KeepAlive: wait.ForeverTestTimeout,
}
dialer := DialerFunc(func(req *http.Request) (net.Conn, error) {
conn, err := netdialer.Dial("tcp", req.URL.Host)
if err != nil {
return conn, err
}
if err = req.Write(conn); err != nil {
require.NoError(t, conn.Close())
return nil, fmt.Errorf("error sending request: %v", err)
}
return conn, err
})
conn, rawResponse, err := ConnectWithRedirects(method, u, http.Header{} /*body*/, nil, dialer, true)
if test.expectError {
require.Error(t, err, "expected request error")
return
}
require.NoError(t, err, "unexpected request error")
assert.NoError(t, conn.Close(), "error closing connection")
resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(rawResponse)), nil)
require.NoError(t, err, "unexpected request error")
result, err := ioutil.ReadAll(resp.Body)
require.NoError(t, resp.Body.Close())
if test.expectedRedirects < len(test.redirects) {
// Expect the last redirect to be returned.
assert.Equal(t, http.StatusFound, resp.StatusCode, "Final response is not a redirect")
assert.Equal(t, test.redirects[len(test.redirects)-1], resp.Header.Get("Location"))
assert.NotEqual(t, resultString, string(result), "wrong content")
} else {
assert.Equal(t, resultString, string(result), "stream content does not match")
}
})
}
}

View File

@ -26,7 +26,7 @@ import (
"strings"
"github.com/golang/glog"
"k8s.io/klog"
)
type AddressFamily uint
@ -53,6 +53,28 @@ type RouteFile struct {
parse func(input io.Reader) ([]Route, error)
}
// noRoutesError can be returned by ChooseBindAddress() in case of no routes
type noRoutesError struct {
message string
}
func (e noRoutesError) Error() string {
return e.message
}
// IsNoRoutesError checks if an error is of type noRoutesError
func IsNoRoutesError(err error) bool {
if err == nil {
return false
}
switch err.(type) {
case noRoutesError:
return true
default:
return false
}
}
var (
v4File = RouteFile{name: ipv4RouteFile, parse: getIPv4DefaultRoutes}
v6File = RouteFile{name: ipv6RouteFile, parse: getIPv6DefaultRoutes}
@ -171,7 +193,7 @@ func isInterfaceUp(intf *net.Interface) bool {
return false
}
if intf.Flags&net.FlagUp != 0 {
glog.V(4).Infof("Interface %v is up", intf.Name)
klog.V(4).Infof("Interface %v is up", intf.Name)
return true
}
return false
@ -186,20 +208,20 @@ func isLoopbackOrPointToPoint(intf *net.Interface) bool {
func getMatchingGlobalIP(addrs []net.Addr, family AddressFamily) (net.IP, error) {
if len(addrs) > 0 {
for i := range addrs {
glog.V(4).Infof("Checking addr %s.", addrs[i].String())
klog.V(4).Infof("Checking addr %s.", addrs[i].String())
ip, _, err := net.ParseCIDR(addrs[i].String())
if err != nil {
return nil, err
}
if memberOf(ip, family) {
if ip.IsGlobalUnicast() {
glog.V(4).Infof("IP found %v", ip)
klog.V(4).Infof("IP found %v", ip)
return ip, nil
} else {
glog.V(4).Infof("Non-global unicast address found %v", ip)
klog.V(4).Infof("Non-global unicast address found %v", ip)
}
} else {
glog.V(4).Infof("%v is not an IPv%d address", ip, int(family))
klog.V(4).Infof("%v is not an IPv%d address", ip, int(family))
}
}
@ -219,13 +241,13 @@ func getIPFromInterface(intfName string, forFamily AddressFamily, nw networkInte
if err != nil {
return nil, err
}
glog.V(4).Infof("Interface %q has %d addresses :%v.", intfName, len(addrs), addrs)
klog.V(4).Infof("Interface %q has %d addresses :%v.", intfName, len(addrs), addrs)
matchingIP, err := getMatchingGlobalIP(addrs, forFamily)
if err != nil {
return nil, err
}
if matchingIP != nil {
glog.V(4).Infof("Found valid IPv%d address %v for interface %q.", int(forFamily), matchingIP, intfName)
klog.V(4).Infof("Found valid IPv%d address %v for interface %q.", int(forFamily), matchingIP, intfName)
return matchingIP, nil
}
}
@ -253,14 +275,14 @@ func chooseIPFromHostInterfaces(nw networkInterfacer) (net.IP, error) {
return nil, fmt.Errorf("no interfaces found on host.")
}
for _, family := range []AddressFamily{familyIPv4, familyIPv6} {
glog.V(4).Infof("Looking for system interface with a global IPv%d address", uint(family))
klog.V(4).Infof("Looking for system interface with a global IPv%d address", uint(family))
for _, intf := range intfs {
if !isInterfaceUp(&intf) {
glog.V(4).Infof("Skipping: down interface %q", intf.Name)
klog.V(4).Infof("Skipping: down interface %q", intf.Name)
continue
}
if isLoopbackOrPointToPoint(&intf) {
glog.V(4).Infof("Skipping: LB or P2P interface %q", intf.Name)
klog.V(4).Infof("Skipping: LB or P2P interface %q", intf.Name)
continue
}
addrs, err := nw.Addrs(&intf)
@ -268,7 +290,7 @@ func chooseIPFromHostInterfaces(nw networkInterfacer) (net.IP, error) {
return nil, err
}
if len(addrs) == 0 {
glog.V(4).Infof("Skipping: no addresses on interface %q", intf.Name)
klog.V(4).Infof("Skipping: no addresses on interface %q", intf.Name)
continue
}
for _, addr := range addrs {
@ -277,15 +299,15 @@ func chooseIPFromHostInterfaces(nw networkInterfacer) (net.IP, error) {
return nil, fmt.Errorf("Unable to parse CIDR for interface %q: %s", intf.Name, err)
}
if !memberOf(ip, family) {
glog.V(4).Infof("Skipping: no address family match for %q on interface %q.", ip, intf.Name)
klog.V(4).Infof("Skipping: no address family match for %q on interface %q.", ip, intf.Name)
continue
}
// TODO: Decide if should open up to allow IPv6 LLAs in future.
if !ip.IsGlobalUnicast() {
glog.V(4).Infof("Skipping: non-global address %q on interface %q.", ip, intf.Name)
klog.V(4).Infof("Skipping: non-global address %q on interface %q.", ip, intf.Name)
continue
}
glog.V(4).Infof("Found global unicast address %q on interface %q.", ip, intf.Name)
klog.V(4).Infof("Found global unicast address %q on interface %q.", ip, intf.Name)
return ip, nil
}
}
@ -347,7 +369,9 @@ func getAllDefaultRoutes() ([]Route, error) {
v6Routes, _ := v6File.extract()
routes = append(routes, v6Routes...)
if len(routes) == 0 {
return nil, fmt.Errorf("No default routes.")
return nil, noRoutesError{
message: fmt.Sprintf("no default routes found in %q or %q", v4File.name, v6File.name),
}
}
return routes, nil
}
@ -357,23 +381,23 @@ func getAllDefaultRoutes() ([]Route, error) {
// an IPv4 IP, and then will look at each IPv6 route for an IPv6 IP.
func chooseHostInterfaceFromRoute(routes []Route, nw networkInterfacer) (net.IP, error) {
for _, family := range []AddressFamily{familyIPv4, familyIPv6} {
glog.V(4).Infof("Looking for default routes with IPv%d addresses", uint(family))
klog.V(4).Infof("Looking for default routes with IPv%d addresses", uint(family))
for _, route := range routes {
if route.Family != family {
continue
}
glog.V(4).Infof("Default route transits interface %q", route.Interface)
klog.V(4).Infof("Default route transits interface %q", route.Interface)
finalIP, err := getIPFromInterface(route.Interface, family, nw)
if err != nil {
return nil, err
}
if finalIP != nil {
glog.V(4).Infof("Found active IP %v ", finalIP)
klog.V(4).Infof("Found active IP %v ", finalIP)
return finalIP, nil
}
}
}
glog.V(4).Infof("No active IP found by looking at default routes")
klog.V(4).Infof("No active IP found by looking at default routes")
return nil, fmt.Errorf("unable to select an IP from default routes.")
}

View File

@ -669,7 +669,7 @@ func TestGetAllDefaultRoutes(t *testing.T) {
expected []Route
errStrFrag string
}{
{"no routes", noInternetConnection, v6noDefaultRoutes, 0, nil, "No default routes"},
{"no routes", noInternetConnection, v6noDefaultRoutes, 0, nil, "no default routes"},
{"only v4 route", gatewayfirst, v6noDefaultRoutes, 1, routeV4, ""},
{"only v6 route", noInternetConnection, v6gatewayfirst, 1, routeV6, ""},
{"v4 and v6 routes", gatewayfirst, v6gatewayfirst, 2, bothRoutes, ""},

View File

@ -43,14 +43,19 @@ func (pr PortRange) String() string {
return fmt.Sprintf("%d-%d", pr.Base, pr.Base+pr.Size-1)
}
// Set parses a string of the form "min-max", inclusive at both ends, and
// Set parses a string of the form "value", "min-max", or "min+offset", inclusive at both ends, and
// sets the PortRange from it. This is part of the flag.Value and pflag.Value
// interfaces.
func (pr *PortRange) Set(value string) error {
value = strings.TrimSpace(value)
const (
SinglePortNotation = 1 << iota
HyphenNotation
PlusNotation
)
// TODO: Accept "80" syntax
// TODO: Accept "80+8" syntax
value = strings.TrimSpace(value)
hyphenIndex := strings.Index(value, "-")
plusIndex := strings.Index(value, "+")
if value == "" {
pr.Base = 0
@ -58,20 +63,51 @@ func (pr *PortRange) Set(value string) error {
return nil
}
hyphenIndex := strings.Index(value, "-")
if hyphenIndex == -1 {
return fmt.Errorf("expected hyphen in port range")
var err error
var low, high int
var notation int
if plusIndex == -1 && hyphenIndex == -1 {
notation |= SinglePortNotation
}
if hyphenIndex != -1 {
notation |= HyphenNotation
}
if plusIndex != -1 {
notation |= PlusNotation
}
var err error
var low int
var high int
low, err = strconv.Atoi(value[:hyphenIndex])
if err == nil {
switch notation {
case SinglePortNotation:
var port int
port, err = strconv.Atoi(value)
if err != nil {
return err
}
low = port
high = port
case HyphenNotation:
low, err = strconv.Atoi(value[:hyphenIndex])
if err != nil {
return err
}
high, err = strconv.Atoi(value[hyphenIndex+1:])
}
if err != nil {
return fmt.Errorf("unable to parse port range: %s: %v", value, err)
if err != nil {
return err
}
case PlusNotation:
var offset int
low, err = strconv.Atoi(value[:plusIndex])
if err != nil {
return err
}
offset, err = strconv.Atoi(value[plusIndex+1:])
if err != nil {
return err
}
high = low + offset
default:
return fmt.Errorf("unable to parse port range: %s", value)
}
if low > 65535 || high > 65535 {

View File

@ -34,13 +34,21 @@ func TestPortRange(t *testing.T) {
{" 100-200 ", true, "100-200", 200, 201},
{"0-0", true, "0-0", 0, 1},
{"", true, "", -1, 0},
{"100", false, "", -1, -1},
{"100", true, "100-100", 100, 101},
{"100 - 200", false, "", -1, -1},
{"-100", false, "", -1, -1},
{"100-", false, "", -1, -1},
{"200-100", false, "", -1, -1},
{"60000-70000", false, "", -1, -1},
{"70000-80000", false, "", -1, -1},
{"70000+80000", false, "", -1, -1},
{"1+0", true, "1-1", 1, 2},
{"0+0", true, "0-0", 0, 1},
{"1+-1", false, "", -1, -1},
{"1-+1", false, "", -1, -1},
{"100+200", true, "100-300", 300, 301},
{"1+65535", false, "", -1, -1},
{"0+65535", true, "0-65535", 65535, 65536},
}
for i := range testCases {
@ -52,7 +60,7 @@ func TestPortRange(t *testing.T) {
t.Errorf("expected success, got %q", err)
continue
} else if err == nil && tc.success == false {
t.Errorf("expected failure")
t.Errorf("expected failure %#v", testCases[i])
continue
} else if tc.success {
if f.String() != tc.expected {