Vendor cleanup

Signed-off-by: Madhu Rajanna <mrajanna@redhat.com>
This commit is contained in:
Madhu Rajanna
2019-01-16 18:11:54 +05:30
parent 661818bd79
commit 0f836c62fa
16816 changed files with 20 additions and 4611100 deletions

View File

@ -1,83 +0,0 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cache
import (
"sync"
)
const (
shardsCount int = 32
)
type Cache []*cacheShard
func NewCache(maxSize int) Cache {
if maxSize < shardsCount {
maxSize = shardsCount
}
cache := make(Cache, shardsCount)
for i := 0; i < shardsCount; i++ {
cache[i] = &cacheShard{
items: make(map[uint64]interface{}),
maxSize: maxSize / shardsCount,
}
}
return cache
}
func (c Cache) getShard(index uint64) *cacheShard {
return c[index%uint64(shardsCount)]
}
// Returns true if object already existed, false otherwise.
func (c *Cache) Add(index uint64, obj interface{}) bool {
return c.getShard(index).add(index, obj)
}
func (c *Cache) Get(index uint64) (obj interface{}, found bool) {
return c.getShard(index).get(index)
}
type cacheShard struct {
items map[uint64]interface{}
sync.RWMutex
maxSize int
}
// Returns true if object already existed, false otherwise.
func (s *cacheShard) add(index uint64, obj interface{}) bool {
s.Lock()
defer s.Unlock()
_, isOverwrite := s.items[index]
if !isOverwrite && len(s.items) >= s.maxSize {
var randomKey uint64
for randomKey = range s.items {
break
}
delete(s.items, randomKey)
}
s.items[index] = obj
return isOverwrite
}
func (s *cacheShard) get(index uint64) (obj interface{}, found bool) {
s.RLock()
defer s.RUnlock()
obj, found = s.items[index]
return
}

View File

@ -1,90 +0,0 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cache
import (
"testing"
)
const (
maxTestCacheSize int = shardsCount * 2
)
func ExpectEntry(t *testing.T, cache Cache, index uint64, expectedValue interface{}) bool {
elem, found := cache.Get(index)
if !found {
t.Errorf("Expected to find entry with key %d", index)
return false
} else if elem != expectedValue {
t.Errorf("Expected to find %v, got %v", expectedValue, elem)
return false
}
return true
}
func TestBasic(t *testing.T) {
cache := NewCache(maxTestCacheSize)
cache.Add(1, "xxx")
ExpectEntry(t, cache, 1, "xxx")
}
func TestOverflow(t *testing.T) {
cache := NewCache(maxTestCacheSize)
for i := 0; i < maxTestCacheSize+1; i++ {
cache.Add(uint64(i), "xxx")
}
foundIndexes := make([]uint64, 0)
for i := 0; i < maxTestCacheSize+1; i++ {
_, found := cache.Get(uint64(i))
if found {
foundIndexes = append(foundIndexes, uint64(i))
}
}
if len(foundIndexes) != maxTestCacheSize {
t.Errorf("Expect to find %d elements, got %d %v", maxTestCacheSize, len(foundIndexes), foundIndexes)
}
}
func TestOverwrite(t *testing.T) {
cache := NewCache(maxTestCacheSize)
cache.Add(1, "xxx")
ExpectEntry(t, cache, 1, "xxx")
cache.Add(1, "yyy")
ExpectEntry(t, cache, 1, "yyy")
}
// TestEvict this test will fail sporatically depending on what add()
// selects for the randomKey to be evicted. Ensure that randomKey
// is never the key we most recently added. Since the chance of failure
// on each evict is 50%, if we do it 7 times, it should catch the problem
// if it exists >99% of the time.
func TestEvict(t *testing.T) {
cache := NewCache(shardsCount)
var found bool
for retry := 0; retry < 7; retry++ {
cache.Add(uint64(shardsCount), "xxx")
found = ExpectEntry(t, cache, uint64(shardsCount), "xxx")
if !found {
break
}
cache.Add(0, "xxx")
found = ExpectEntry(t, cache, 0, "xxx")
if !found {
break
}
}
}

View File

@ -1,102 +0,0 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cache
import (
"sync"
"time"
"github.com/hashicorp/golang-lru"
)
// Clock defines an interface for obtaining the current time
type Clock interface {
Now() time.Time
}
// realClock implements the Clock interface by calling time.Now()
type realClock struct{}
func (realClock) Now() time.Time { return time.Now() }
// LRUExpireCache is a cache that ensures the mostly recently accessed keys are returned with
// a ttl beyond which keys are forcibly expired.
type LRUExpireCache struct {
// clock is used to obtain the current time
clock Clock
cache *lru.Cache
lock sync.Mutex
}
// NewLRUExpireCache creates an expiring cache with the given size
func NewLRUExpireCache(maxSize int) *LRUExpireCache {
return NewLRUExpireCacheWithClock(maxSize, realClock{})
}
// NewLRUExpireCacheWithClock creates an expiring cache with the given size, using the specified clock to obtain the current time.
func NewLRUExpireCacheWithClock(maxSize int, clock Clock) *LRUExpireCache {
cache, err := lru.New(maxSize)
if err != nil {
// if called with an invalid size
panic(err)
}
return &LRUExpireCache{clock: clock, cache: cache}
}
type cacheEntry struct {
value interface{}
expireTime time.Time
}
// Add adds the value to the cache at key with the specified maximum duration.
func (c *LRUExpireCache) Add(key interface{}, value interface{}, ttl time.Duration) {
c.lock.Lock()
defer c.lock.Unlock()
c.cache.Add(key, &cacheEntry{value, c.clock.Now().Add(ttl)})
}
// Get returns the value at the specified key from the cache if it exists and is not
// expired, or returns false.
func (c *LRUExpireCache) Get(key interface{}) (interface{}, bool) {
c.lock.Lock()
defer c.lock.Unlock()
e, ok := c.cache.Get(key)
if !ok {
return nil, false
}
if c.clock.Now().After(e.(*cacheEntry).expireTime) {
c.cache.Remove(key)
return nil, false
}
return e.(*cacheEntry).value, true
}
// Remove removes the specified key from the cache if it exists
func (c *LRUExpireCache) Remove(key interface{}) {
c.lock.Lock()
defer c.lock.Unlock()
c.cache.Remove(key)
}
// Keys returns all the keys in the cache, even if they are expired. Subsequent calls to
// get may return not found. It returns all keys from oldest to newest.
func (c *LRUExpireCache) Keys() []interface{} {
c.lock.Lock()
defer c.lock.Unlock()
return c.cache.Keys()
}

View File

@ -1,68 +0,0 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cache
import (
"testing"
"time"
"k8s.io/apimachinery/pkg/util/clock"
"github.com/golang/groupcache/lru"
)
func expectEntry(t *testing.T, c *LRUExpireCache, key lru.Key, value interface{}) {
result, ok := c.Get(key)
if !ok || result != value {
t.Errorf("Expected cache[%v]: %v, got %v", key, value, result)
}
}
func expectNotEntry(t *testing.T, c *LRUExpireCache, key lru.Key) {
if result, ok := c.Get(key); ok {
t.Errorf("Expected cache[%v] to be empty, got %v", key, result)
}
}
func TestSimpleGet(t *testing.T) {
c := NewLRUExpireCache(10)
c.Add("long-lived", "12345", 10*time.Hour)
expectEntry(t, c, "long-lived", "12345")
}
func TestExpiredGet(t *testing.T) {
fakeClock := clock.NewFakeClock(time.Now())
c := NewLRUExpireCacheWithClock(10, fakeClock)
c.Add("short-lived", "12345", 1*time.Millisecond)
// ensure the entry expired
fakeClock.Step(2 * time.Millisecond)
expectNotEntry(t, c, "short-lived")
}
func TestLRUOverflow(t *testing.T) {
c := NewLRUExpireCache(4)
c.Add("elem1", "1", 10*time.Hour)
c.Add("elem2", "2", 10*time.Hour)
c.Add("elem3", "3", 10*time.Hour)
c.Add("elem4", "4", 10*time.Hour)
c.Add("elem5", "5", 10*time.Hour)
expectNotEntry(t, c, "elem1")
expectEntry(t, c, "elem2", "2")
expectEntry(t, c, "elem3", "3")
expectEntry(t, c, "elem4", "4")
expectEntry(t, c, "elem5", "5")
}

View File

@ -1,196 +0,0 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package clock
import (
"testing"
"time"
)
var (
_ = Clock(RealClock{})
_ = Clock(&FakeClock{})
_ = Clock(&IntervalClock{})
_ = Timer(&realTimer{})
_ = Timer(&fakeTimer{})
_ = Ticker(&realTicker{})
_ = Ticker(&fakeTicker{})
)
func TestFakeClock(t *testing.T) {
startTime := time.Now()
tc := NewFakeClock(startTime)
tc.Step(time.Second)
now := tc.Now()
if now.Sub(startTime) != time.Second {
t.Errorf("input: %s now=%s gap=%s expected=%s", startTime, now, now.Sub(startTime), time.Second)
}
tt := tc.Now()
tc.SetTime(tt.Add(time.Hour))
if tc.Now().Sub(tt) != time.Hour {
t.Errorf("input: %s now=%s gap=%s expected=%s", tt, tc.Now(), tc.Now().Sub(tt), time.Hour)
}
}
func TestFakeClockSleep(t *testing.T) {
startTime := time.Now()
tc := NewFakeClock(startTime)
tc.Sleep(time.Duration(1) * time.Hour)
now := tc.Now()
if now.Sub(startTime) != time.Hour {
t.Errorf("Fake sleep failed, expected time to advance by one hour, instead, its %v", now.Sub(startTime))
}
}
func TestFakeAfter(t *testing.T) {
tc := NewFakeClock(time.Now())
if tc.HasWaiters() {
t.Errorf("unexpected waiter?")
}
oneSec := tc.After(time.Second)
if !tc.HasWaiters() {
t.Errorf("unexpected lack of waiter?")
}
oneOhOneSec := tc.After(time.Second + time.Millisecond)
twoSec := tc.After(2 * time.Second)
select {
case <-oneSec:
t.Errorf("unexpected channel read")
case <-oneOhOneSec:
t.Errorf("unexpected channel read")
case <-twoSec:
t.Errorf("unexpected channel read")
default:
}
tc.Step(999 * time.Millisecond)
select {
case <-oneSec:
t.Errorf("unexpected channel read")
case <-oneOhOneSec:
t.Errorf("unexpected channel read")
case <-twoSec:
t.Errorf("unexpected channel read")
default:
}
tc.Step(time.Millisecond)
select {
case <-oneSec:
// Expected!
case <-oneOhOneSec:
t.Errorf("unexpected channel read")
case <-twoSec:
t.Errorf("unexpected channel read")
default:
t.Errorf("unexpected non-channel read")
}
tc.Step(time.Millisecond)
select {
case <-oneSec:
// should not double-trigger!
t.Errorf("unexpected channel read")
case <-oneOhOneSec:
// Expected!
case <-twoSec:
t.Errorf("unexpected channel read")
default:
t.Errorf("unexpected non-channel read")
}
}
func TestFakeTick(t *testing.T) {
tc := NewFakeClock(time.Now())
if tc.HasWaiters() {
t.Errorf("unexpected waiter?")
}
oneSec := tc.NewTicker(time.Second).C()
if !tc.HasWaiters() {
t.Errorf("unexpected lack of waiter?")
}
oneOhOneSec := tc.NewTicker(time.Second + time.Millisecond).C()
twoSec := tc.NewTicker(2 * time.Second).C()
select {
case <-oneSec:
t.Errorf("unexpected channel read")
case <-oneOhOneSec:
t.Errorf("unexpected channel read")
case <-twoSec:
t.Errorf("unexpected channel read")
default:
}
tc.Step(999 * time.Millisecond) // t=.999
select {
case <-oneSec:
t.Errorf("unexpected channel read")
case <-oneOhOneSec:
t.Errorf("unexpected channel read")
case <-twoSec:
t.Errorf("unexpected channel read")
default:
}
tc.Step(time.Millisecond) // t=1.000
select {
case <-oneSec:
// Expected!
case <-oneOhOneSec:
t.Errorf("unexpected channel read")
case <-twoSec:
t.Errorf("unexpected channel read")
default:
t.Errorf("unexpected non-channel read")
}
tc.Step(time.Millisecond) // t=1.001
select {
case <-oneSec:
// should not double-trigger!
t.Errorf("unexpected channel read")
case <-oneOhOneSec:
// Expected!
case <-twoSec:
t.Errorf("unexpected channel read")
default:
t.Errorf("unexpected non-channel read")
}
tc.Step(time.Second) // t=2.001
tc.Step(time.Second) // t=3.001
tc.Step(time.Second) // t=4.001
tc.Step(time.Second) // t=5.001
// The one second ticker should not accumulate ticks
accumulatedTicks := 0
drained := false
for !drained {
select {
case <-oneSec:
accumulatedTicks++
default:
drained = true
}
}
if accumulatedTicks != 1 {
t.Errorf("unexpected number of accumulated ticks: %d", accumulatedTicks)
}
}

View File

@ -1,313 +0,0 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package diff
import (
"bytes"
"encoding/json"
"fmt"
"reflect"
"sort"
"strings"
"text/tabwriter"
"github.com/davecgh/go-spew/spew"
"k8s.io/apimachinery/pkg/util/validation/field"
)
// StringDiff diffs a and b and returns a human readable diff.
func StringDiff(a, b string) string {
ba := []byte(a)
bb := []byte(b)
out := []byte{}
i := 0
for ; i < len(ba) && i < len(bb); i++ {
if ba[i] != bb[i] {
break
}
out = append(out, ba[i])
}
out = append(out, []byte("\n\nA: ")...)
out = append(out, ba[i:]...)
out = append(out, []byte("\n\nB: ")...)
out = append(out, bb[i:]...)
out = append(out, []byte("\n\n")...)
return string(out)
}
// ObjectDiff writes the two objects out as JSON and prints out the identical part of
// the objects followed by the remaining part of 'a' and finally the remaining part of 'b'.
// For debugging tests.
func ObjectDiff(a, b interface{}) string {
ab, err := json.Marshal(a)
if err != nil {
panic(fmt.Sprintf("a: %v", err))
}
bb, err := json.Marshal(b)
if err != nil {
panic(fmt.Sprintf("b: %v", err))
}
return StringDiff(string(ab), string(bb))
}
// ObjectGoPrintDiff is like ObjectDiff, but uses go-spew to print the objects,
// which shows absolutely everything by recursing into every single pointer
// (go's %#v formatters OTOH stop at a certain point). This is needed when you
// can't figure out why reflect.DeepEqual is returning false and nothing is
// showing you differences. This will.
func ObjectGoPrintDiff(a, b interface{}) string {
s := spew.ConfigState{DisableMethods: true}
return StringDiff(
s.Sprintf("%#v", a),
s.Sprintf("%#v", b),
)
}
func ObjectReflectDiff(a, b interface{}) string {
vA, vB := reflect.ValueOf(a), reflect.ValueOf(b)
if vA.Type() != vB.Type() {
return fmt.Sprintf("type A %T and type B %T do not match", a, b)
}
diffs := objectReflectDiff(field.NewPath("object"), vA, vB)
if len(diffs) == 0 {
return "<no diffs>"
}
out := []string{""}
for _, d := range diffs {
elidedA, elidedB := limit(d.a, d.b, 80)
out = append(out,
fmt.Sprintf("%s:", d.path),
fmt.Sprintf(" a: %s", elidedA),
fmt.Sprintf(" b: %s", elidedB),
)
}
return strings.Join(out, "\n")
}
// limit:
// 1. stringifies aObj and bObj
// 2. elides identical prefixes if either is too long
// 3. elides remaining content from the end if either is too long
func limit(aObj, bObj interface{}, max int) (string, string) {
elidedPrefix := ""
elidedASuffix := ""
elidedBSuffix := ""
a, b := fmt.Sprintf("%#v", aObj), fmt.Sprintf("%#v", bObj)
if aObj != nil && bObj != nil {
if aType, bType := fmt.Sprintf("%T", aObj), fmt.Sprintf("%T", bObj); aType != bType {
a = fmt.Sprintf("%s (%s)", a, aType)
b = fmt.Sprintf("%s (%s)", b, bType)
}
}
for {
switch {
case len(a) > max && len(a) > 4 && len(b) > 4 && a[:4] == b[:4]:
// a is too long, b has data, and the first several characters are the same
elidedPrefix = "..."
a = a[2:]
b = b[2:]
case len(b) > max && len(b) > 4 && len(a) > 4 && a[:4] == b[:4]:
// b is too long, a has data, and the first several characters are the same
elidedPrefix = "..."
a = a[2:]
b = b[2:]
case len(a) > max:
a = a[:max]
elidedASuffix = "..."
case len(b) > max:
b = b[:max]
elidedBSuffix = "..."
default:
// both are short enough
return elidedPrefix + a + elidedASuffix, elidedPrefix + b + elidedBSuffix
}
}
}
func public(s string) bool {
if len(s) == 0 {
return false
}
return s[:1] == strings.ToUpper(s[:1])
}
type diff struct {
path *field.Path
a, b interface{}
}
type orderedDiffs []diff
func (d orderedDiffs) Len() int { return len(d) }
func (d orderedDiffs) Swap(i, j int) { d[i], d[j] = d[j], d[i] }
func (d orderedDiffs) Less(i, j int) bool {
a, b := d[i].path.String(), d[j].path.String()
if a < b {
return true
}
return false
}
func objectReflectDiff(path *field.Path, a, b reflect.Value) []diff {
switch a.Type().Kind() {
case reflect.Struct:
var changes []diff
for i := 0; i < a.Type().NumField(); i++ {
if !public(a.Type().Field(i).Name) {
if reflect.DeepEqual(a.Interface(), b.Interface()) {
continue
}
return []diff{{path: path, a: fmt.Sprintf("%#v", a), b: fmt.Sprintf("%#v", b)}}
}
if sub := objectReflectDiff(path.Child(a.Type().Field(i).Name), a.Field(i), b.Field(i)); len(sub) > 0 {
changes = append(changes, sub...)
}
}
return changes
case reflect.Ptr, reflect.Interface:
if a.IsNil() || b.IsNil() {
switch {
case a.IsNil() && b.IsNil():
return nil
case a.IsNil():
return []diff{{path: path, a: nil, b: b.Interface()}}
default:
return []diff{{path: path, a: a.Interface(), b: nil}}
}
}
return objectReflectDiff(path, a.Elem(), b.Elem())
case reflect.Chan:
if !reflect.DeepEqual(a.Interface(), b.Interface()) {
return []diff{{path: path, a: a.Interface(), b: b.Interface()}}
}
return nil
case reflect.Slice:
lA, lB := a.Len(), b.Len()
l := lA
if lB < lA {
l = lB
}
if lA == lB && lA == 0 {
if a.IsNil() != b.IsNil() {
return []diff{{path: path, a: a.Interface(), b: b.Interface()}}
}
return nil
}
var diffs []diff
for i := 0; i < l; i++ {
if !reflect.DeepEqual(a.Index(i), b.Index(i)) {
diffs = append(diffs, objectReflectDiff(path.Index(i), a.Index(i), b.Index(i))...)
}
}
for i := l; i < lA; i++ {
diffs = append(diffs, diff{path: path.Index(i), a: a.Index(i), b: nil})
}
for i := l; i < lB; i++ {
diffs = append(diffs, diff{path: path.Index(i), a: nil, b: b.Index(i)})
}
return diffs
case reflect.Map:
if reflect.DeepEqual(a.Interface(), b.Interface()) {
return nil
}
aKeys := make(map[interface{}]interface{})
for _, key := range a.MapKeys() {
aKeys[key.Interface()] = a.MapIndex(key).Interface()
}
var missing []diff
for _, key := range b.MapKeys() {
if _, ok := aKeys[key.Interface()]; ok {
delete(aKeys, key.Interface())
if reflect.DeepEqual(a.MapIndex(key).Interface(), b.MapIndex(key).Interface()) {
continue
}
missing = append(missing, objectReflectDiff(path.Key(fmt.Sprintf("%s", key.Interface())), a.MapIndex(key), b.MapIndex(key))...)
continue
}
missing = append(missing, diff{path: path.Key(fmt.Sprintf("%s", key.Interface())), a: nil, b: b.MapIndex(key).Interface()})
}
for key, value := range aKeys {
missing = append(missing, diff{path: path.Key(fmt.Sprintf("%s", key)), a: value, b: nil})
}
if len(missing) == 0 {
missing = append(missing, diff{path: path, a: a.Interface(), b: b.Interface()})
}
sort.Sort(orderedDiffs(missing))
return missing
default:
if reflect.DeepEqual(a.Interface(), b.Interface()) {
return nil
}
if !a.CanInterface() {
return []diff{{path: path, a: fmt.Sprintf("%#v", a), b: fmt.Sprintf("%#v", b)}}
}
return []diff{{path: path, a: a.Interface(), b: b.Interface()}}
}
}
// ObjectGoPrintSideBySide prints a and b as textual dumps side by side,
// enabling easy visual scanning for mismatches.
func ObjectGoPrintSideBySide(a, b interface{}) string {
s := spew.ConfigState{
Indent: " ",
// Extra deep spew.
DisableMethods: true,
}
sA := s.Sdump(a)
sB := s.Sdump(b)
linesA := strings.Split(sA, "\n")
linesB := strings.Split(sB, "\n")
width := 0
for _, s := range linesA {
l := len(s)
if l > width {
width = l
}
}
for _, s := range linesB {
l := len(s)
if l > width {
width = l
}
}
buf := &bytes.Buffer{}
w := tabwriter.NewWriter(buf, width, 0, 1, ' ', 0)
max := len(linesA)
if len(linesB) > max {
max = len(linesB)
}
for i := 0; i < max; i++ {
var a, b string
if i < len(linesA) {
a = linesA[i]
}
if i < len(linesB) {
b = linesB[i]
}
fmt.Fprintf(w, "%s\t%s\n", a, b)
}
w.Flush()
return buf.String()
}

View File

@ -1,148 +0,0 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package diff
import (
"testing"
)
func TestObjectReflectDiff(t *testing.T) {
type struct1 struct{ A []int }
testCases := map[string]struct {
a, b interface{}
out string
}{
"map": {
a: map[string]int{},
b: map[string]int{},
},
"detect nil map": {
a: map[string]int(nil),
b: map[string]int{},
out: `
object:
a: map[string]int(nil)
b: map[string]int{}`,
},
"detect map changes": {
a: map[string]int{"test": 1, "other": 2},
b: map[string]int{"test": 2, "third": 3},
out: `
object[other]:
a: 2
b: <nil>
object[test]:
a: 1
b: 2
object[third]:
a: <nil>
b: 3`,
},
"nil slice": {a: struct1{A: nil}, b: struct1{A: nil}},
"empty slice": {a: struct1{A: []int{}}, b: struct1{A: []int{}}},
"detect slice changes 1": {a: struct1{A: []int{1}}, b: struct1{A: []int{2}}, out: `
object.A[0]:
a: 1
b: 2`,
},
"detect slice changes 2": {a: struct1{A: []int{}}, b: struct1{A: []int{2}}, out: `
object.A[0]:
a: <nil>
b: 2`,
},
"detect slice changes 3": {a: struct1{A: []int{1}}, b: struct1{A: []int{}}, out: `
object.A[0]:
a: 1
b: <nil>`,
},
"detect nil vs empty slices": {a: struct1{A: nil}, b: struct1{A: []int{}}, out: `
object.A:
a: []int(nil)
b: []int{}`,
},
"display type differences": {a: []interface{}{int64(1)}, b: []interface{}{uint64(1)}, out: `
object[0]:
a: 1 (int64)
b: 0x1 (uint64)`,
},
}
for name, test := range testCases {
expect := test.out
if len(expect) == 0 {
expect = "<no diffs>"
}
if actual := ObjectReflectDiff(test.a, test.b); actual != expect {
t.Errorf("%s: unexpected output: %s", name, actual)
}
}
}
func TestStringDiff(t *testing.T) {
diff := StringDiff("aaabb", "aaacc")
expect := "aaa\n\nA: bb\n\nB: cc\n\n"
if diff != expect {
t.Errorf("diff returned %v", diff)
}
}
func TestLimit(t *testing.T) {
testcases := []struct {
a interface{}
b interface{}
expectA string
expectB string
}{
{
a: `short a`,
b: `short b`,
expectA: `"short a"`,
expectB: `"short b"`,
},
{
a: `short a`,
b: `long b needs truncating`,
expectA: `"short a"`,
expectB: `"long b ne...`,
},
{
a: `long a needs truncating`,
b: `long b needs truncating`,
expectA: `...g a needs ...`,
expectB: `...g b needs ...`,
},
{
a: `long common prefix with different stuff at the end of a`,
b: `long common prefix with different stuff at the end of b`,
expectA: `...end of a"`,
expectB: `...end of b"`,
},
{
a: `long common prefix with different stuff at the end of a`,
b: `long common prefix with different stuff at the end of b which continues`,
expectA: `...of a"`,
expectB: `...of b which...`,
},
}
for _, tc := range testcases {
a, b := limit(tc.a, tc.b, 10)
if a != tc.expectA || b != tc.expectB {
t.Errorf("limit(%q, %q)\n\texpected: %s, %s\n\tgot: %s, %s", tc.a, tc.b, tc.expectA, tc.expectB, a, b)
}
}
}

View File

@ -1,89 +0,0 @@
/*
Copyright 2018 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package duration
import (
"fmt"
"time"
)
// ShortHumanDuration returns a succint representation of the provided duration
// with limited precision for consumption by humans.
func ShortHumanDuration(d time.Duration) string {
// Allow deviation no more than 2 seconds(excluded) to tolerate machine time
// inconsistence, it can be considered as almost now.
if seconds := int(d.Seconds()); seconds < -1 {
return fmt.Sprintf("<invalid>")
} else if seconds < 0 {
return fmt.Sprintf("0s")
} else if seconds < 60 {
return fmt.Sprintf("%ds", seconds)
} else if minutes := int(d.Minutes()); minutes < 60 {
return fmt.Sprintf("%dm", minutes)
} else if hours := int(d.Hours()); hours < 24 {
return fmt.Sprintf("%dh", hours)
} else if hours < 24*365 {
return fmt.Sprintf("%dd", hours/24)
}
return fmt.Sprintf("%dy", int(d.Hours()/24/365))
}
// HumanDuration returns a succint representation of the provided duration
// with limited precision for consumption by humans. It provides ~2-3 significant
// figures of duration.
func HumanDuration(d time.Duration) string {
// Allow deviation no more than 2 seconds(excluded) to tolerate machine time
// inconsistence, it can be considered as almost now.
if seconds := int(d.Seconds()); seconds < -1 {
return fmt.Sprintf("<invalid>")
} else if seconds < 0 {
return fmt.Sprintf("0s")
} else if seconds < 60*2 {
return fmt.Sprintf("%ds", seconds)
}
minutes := int(d / time.Minute)
if minutes < 10 {
s := int(d/time.Second) % 60
if s == 0 {
return fmt.Sprintf("%dm", minutes)
}
return fmt.Sprintf("%dm%ds", minutes, s)
} else if minutes < 60*3 {
return fmt.Sprintf("%dm", minutes)
}
hours := int(d / time.Hour)
if hours < 8 {
m := int(d/time.Minute) % 60
if m == 0 {
return fmt.Sprintf("%dh", hours)
}
return fmt.Sprintf("%dh%dm", hours, m)
} else if hours < 48 {
return fmt.Sprintf("%dh", hours)
} else if hours < 24*8 {
h := hours % 24
if h == 0 {
return fmt.Sprintf("%dd", hours/24)
}
return fmt.Sprintf("%dd%dh", hours/24, h)
} else if hours < 24*365*2 {
return fmt.Sprintf("%dd", hours/24)
} else if hours < 24*365*8 {
return fmt.Sprintf("%dy%dd", hours/24/365, (hours/24)%365)
}
return fmt.Sprintf("%dy", int(hours/24/365))
}

View File

@ -1,47 +0,0 @@
/*
Copyright 2018 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package duration
import (
"testing"
"time"
)
func TestHumanDuration(t *testing.T) {
tests := []struct {
d time.Duration
want string
}{
{d: time.Second, want: "1s"},
{d: 70 * time.Second, want: "70s"},
{d: 190 * time.Second, want: "3m10s"},
{d: 70 * time.Minute, want: "70m"},
{d: 47 * time.Hour, want: "47h"},
{d: 49 * time.Hour, want: "2d1h"},
{d: (8*24 + 2) * time.Hour, want: "8d"},
{d: (367 * 24) * time.Hour, want: "367d"},
{d: (365*2*24 + 25) * time.Hour, want: "2y1d"},
{d: (365*8*24 + 2) * time.Hour, want: "8y"},
}
for _, tt := range tests {
t.Run(tt.d.String(), func(t *testing.T) {
if got := HumanDuration(tt.d); got != tt.want {
t.Errorf("HumanDuration() = %v, want %v", got, tt.want)
}
})
}
}

View File

@ -1,368 +0,0 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package errors
import (
"fmt"
"reflect"
"sort"
"testing"
)
func TestEmptyAggregate(t *testing.T) {
var slice []error
var agg Aggregate
var err error
agg = NewAggregate(slice)
if agg != nil {
t.Errorf("expected nil, got %#v", agg)
}
err = NewAggregate(slice)
if err != nil {
t.Errorf("expected nil, got %#v", err)
}
// This is not normally possible, but pedantry demands I test it.
agg = aggregate(slice) // empty aggregate
if s := agg.Error(); s != "" {
t.Errorf("expected empty string, got %q", s)
}
if s := agg.Errors(); len(s) != 0 {
t.Errorf("expected empty slice, got %#v", s)
}
err = agg.(error)
if s := err.Error(); s != "" {
t.Errorf("expected empty string, got %q", s)
}
}
func TestAggregateWithNil(t *testing.T) {
var slice []error
slice = []error{nil}
var agg Aggregate
var err error
agg = NewAggregate(slice)
if agg != nil {
t.Errorf("expected nil, got %#v", agg)
}
err = NewAggregate(slice)
if err != nil {
t.Errorf("expected nil, got %#v", err)
}
// Append a non-nil error
slice = append(slice, fmt.Errorf("err"))
agg = NewAggregate(slice)
if agg == nil {
t.Errorf("expected non-nil")
}
if s := agg.Error(); s != "err" {
t.Errorf("expected 'err', got %q", s)
}
if s := agg.Errors(); len(s) != 1 {
t.Errorf("expected one-element slice, got %#v", s)
}
if s := agg.Errors()[0].Error(); s != "err" {
t.Errorf("expected 'err', got %q", s)
}
err = agg.(error)
if err == nil {
t.Errorf("expected non-nil")
}
if s := err.Error(); s != "err" {
t.Errorf("expected 'err', got %q", s)
}
}
func TestSingularAggregate(t *testing.T) {
var slice []error = []error{fmt.Errorf("err")}
var agg Aggregate
var err error
agg = NewAggregate(slice)
if agg == nil {
t.Errorf("expected non-nil")
}
if s := agg.Error(); s != "err" {
t.Errorf("expected 'err', got %q", s)
}
if s := agg.Errors(); len(s) != 1 {
t.Errorf("expected one-element slice, got %#v", s)
}
if s := agg.Errors()[0].Error(); s != "err" {
t.Errorf("expected 'err', got %q", s)
}
err = agg.(error)
if err == nil {
t.Errorf("expected non-nil")
}
if s := err.Error(); s != "err" {
t.Errorf("expected 'err', got %q", s)
}
}
func TestPluralAggregate(t *testing.T) {
var slice []error = []error{fmt.Errorf("abc"), fmt.Errorf("123")}
var agg Aggregate
var err error
agg = NewAggregate(slice)
if agg == nil {
t.Errorf("expected non-nil")
}
if s := agg.Error(); s != "[abc, 123]" {
t.Errorf("expected '[abc, 123]', got %q", s)
}
if s := agg.Errors(); len(s) != 2 {
t.Errorf("expected two-elements slice, got %#v", s)
}
if s := agg.Errors()[0].Error(); s != "abc" {
t.Errorf("expected '[abc, 123]', got %q", s)
}
err = agg.(error)
if err == nil {
t.Errorf("expected non-nil")
}
if s := err.Error(); s != "[abc, 123]" {
t.Errorf("expected '[abc, 123]', got %q", s)
}
}
func TestFilterOut(t *testing.T) {
testCases := []struct {
err error
filter []Matcher
expected error
}{
{
nil,
[]Matcher{},
nil,
},
{
aggregate{},
[]Matcher{},
nil,
},
{
aggregate{fmt.Errorf("abc")},
[]Matcher{},
aggregate{fmt.Errorf("abc")},
},
{
aggregate{fmt.Errorf("abc")},
[]Matcher{func(err error) bool { return false }},
aggregate{fmt.Errorf("abc")},
},
{
aggregate{fmt.Errorf("abc")},
[]Matcher{func(err error) bool { return true }},
nil,
},
{
aggregate{fmt.Errorf("abc")},
[]Matcher{func(err error) bool { return false }, func(err error) bool { return false }},
aggregate{fmt.Errorf("abc")},
},
{
aggregate{fmt.Errorf("abc")},
[]Matcher{func(err error) bool { return false }, func(err error) bool { return true }},
nil,
},
{
aggregate{fmt.Errorf("abc"), fmt.Errorf("def"), fmt.Errorf("ghi")},
[]Matcher{func(err error) bool { return err.Error() == "def" }},
aggregate{fmt.Errorf("abc"), fmt.Errorf("ghi")},
},
{
aggregate{aggregate{fmt.Errorf("abc")}},
[]Matcher{},
aggregate{aggregate{fmt.Errorf("abc")}},
},
{
aggregate{aggregate{fmt.Errorf("abc"), aggregate{fmt.Errorf("def")}}},
[]Matcher{},
aggregate{aggregate{fmt.Errorf("abc"), aggregate{fmt.Errorf("def")}}},
},
{
aggregate{aggregate{fmt.Errorf("abc"), aggregate{fmt.Errorf("def")}}},
[]Matcher{func(err error) bool { return err.Error() == "def" }},
aggregate{aggregate{fmt.Errorf("abc")}},
},
}
for i, testCase := range testCases {
err := FilterOut(testCase.err, testCase.filter...)
if !reflect.DeepEqual(testCase.expected, err) {
t.Errorf("%d: expected %v, got %v", i, testCase.expected, err)
}
}
}
func TestFlatten(t *testing.T) {
testCases := []struct {
agg Aggregate
expected Aggregate
}{
{
nil,
nil,
},
{
aggregate{},
nil,
},
{
aggregate{fmt.Errorf("abc")},
aggregate{fmt.Errorf("abc")},
},
{
aggregate{fmt.Errorf("abc"), fmt.Errorf("def"), fmt.Errorf("ghi")},
aggregate{fmt.Errorf("abc"), fmt.Errorf("def"), fmt.Errorf("ghi")},
},
{
aggregate{aggregate{fmt.Errorf("abc")}},
aggregate{fmt.Errorf("abc")},
},
{
aggregate{aggregate{aggregate{fmt.Errorf("abc")}}},
aggregate{fmt.Errorf("abc")},
},
{
aggregate{aggregate{fmt.Errorf("abc"), aggregate{fmt.Errorf("def")}}},
aggregate{fmt.Errorf("abc"), fmt.Errorf("def")},
},
{
aggregate{aggregate{aggregate{fmt.Errorf("abc")}, fmt.Errorf("def"), aggregate{fmt.Errorf("ghi")}}},
aggregate{fmt.Errorf("abc"), fmt.Errorf("def"), fmt.Errorf("ghi")},
},
}
for i, testCase := range testCases {
agg := Flatten(testCase.agg)
if !reflect.DeepEqual(testCase.expected, agg) {
t.Errorf("%d: expected %v, got %v", i, testCase.expected, agg)
}
}
}
func TestCreateAggregateFromMessageCountMap(t *testing.T) {
testCases := []struct {
name string
mcm MessageCountMap
expected Aggregate
}{
{
"input has single instance of one message",
MessageCountMap{"abc": 1},
aggregate{fmt.Errorf("abc")},
},
{
"input has multiple messages",
MessageCountMap{"abc": 2, "ghi": 1},
aggregate{fmt.Errorf("abc (repeated 2 times)"), fmt.Errorf("ghi")},
},
{
"input has multiple messages",
MessageCountMap{"ghi": 1, "abc": 2},
aggregate{fmt.Errorf("abc (repeated 2 times)"), fmt.Errorf("ghi")},
},
}
var expected, agg []error
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
if testCase.expected != nil {
expected = testCase.expected.Errors()
sort.Slice(expected, func(i, j int) bool { return expected[i].Error() < expected[j].Error() })
}
if testCase.mcm != nil {
agg = CreateAggregateFromMessageCountMap(testCase.mcm).Errors()
sort.Slice(agg, func(i, j int) bool { return agg[i].Error() < agg[j].Error() })
}
if !reflect.DeepEqual(expected, agg) {
t.Errorf("expected %v, got %v", expected, agg)
}
})
}
}
func TestAggregateGoroutines(t *testing.T) {
testCases := []struct {
errs []error
expected map[string]bool // can't compare directly to Aggregate due to non-deterministic ordering
}{
{
[]error{},
nil,
},
{
[]error{nil},
nil,
},
{
[]error{nil, nil},
nil,
},
{
[]error{fmt.Errorf("1")},
map[string]bool{"1": true},
},
{
[]error{fmt.Errorf("1"), nil},
map[string]bool{"1": true},
},
{
[]error{fmt.Errorf("1"), fmt.Errorf("267")},
map[string]bool{"1": true, "267": true},
},
{
[]error{fmt.Errorf("1"), nil, fmt.Errorf("1234")},
map[string]bool{"1": true, "1234": true},
},
{
[]error{nil, fmt.Errorf("1"), nil, fmt.Errorf("1234"), fmt.Errorf("22")},
map[string]bool{"1": true, "1234": true, "22": true},
},
}
for i, testCase := range testCases {
funcs := make([]func() error, len(testCase.errs))
for i := range testCase.errs {
err := testCase.errs[i]
funcs[i] = func() error { return err }
}
agg := AggregateGoroutines(funcs...)
if agg == nil {
if len(testCase.expected) > 0 {
t.Errorf("%d: expected %v, got nil", i, testCase.expected)
}
continue
}
if len(agg.Errors()) != len(testCase.expected) {
t.Errorf("%d: expected %d errors in aggregate, got %v", i, len(testCase.expected), agg)
continue
}
for _, err := range agg.Errors() {
if !testCase.expected[err.Error()] {
t.Errorf("%d: expected %v, got aggregate containing %v", i, testCase.expected, err)
}
}
}
}

View File

@ -1,176 +0,0 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package framer
import (
"bytes"
"io"
"io/ioutil"
"testing"
)
func TestRead(t *testing.T) {
data := []byte{
0x00, 0x00, 0x00, 0x04,
0x01, 0x02, 0x03, 0x04,
0x00, 0x00, 0x00, 0x03,
0x05, 0x06, 0x07,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x01,
0x08,
}
b := bytes.NewBuffer(data)
r := NewLengthDelimitedFrameReader(ioutil.NopCloser(b))
buf := make([]byte, 1)
if n, err := r.Read(buf); err != io.ErrShortBuffer && n != 1 && bytes.Equal(buf, []byte{0x01}) {
t.Fatalf("unexpected: %v %d %v", err, n, buf)
}
if n, err := r.Read(buf); err != io.ErrShortBuffer && n != 1 && bytes.Equal(buf, []byte{0x02}) {
t.Fatalf("unexpected: %v %d %v", err, n, buf)
}
// read the remaining frame
buf = make([]byte, 2)
if n, err := r.Read(buf); err != nil && n != 2 && bytes.Equal(buf, []byte{0x03, 0x04}) {
t.Fatalf("unexpected: %v %d %v", err, n, buf)
}
// read with buffer equal to frame
buf = make([]byte, 3)
if n, err := r.Read(buf); err != nil && n != 3 && bytes.Equal(buf, []byte{0x05, 0x06, 0x07}) {
t.Fatalf("unexpected: %v %d %v", err, n, buf)
}
// read empty frame
buf = make([]byte, 3)
if n, err := r.Read(buf); err != nil && n != 0 && bytes.Equal(buf, []byte{}) {
t.Fatalf("unexpected: %v %d %v", err, n, buf)
}
// read with larger buffer than frame
buf = make([]byte, 3)
if n, err := r.Read(buf); err != nil && n != 1 && bytes.Equal(buf, []byte{0x08}) {
t.Fatalf("unexpected: %v %d %v", err, n, buf)
}
// read EOF
if n, err := r.Read(buf); err != io.EOF && n != 0 {
t.Fatalf("unexpected: %v %d", err, n)
}
}
func TestReadLarge(t *testing.T) {
data := []byte{
0x00, 0x00, 0x00, 0x04,
0x01, 0x02, 0x03, 0x04,
0x00, 0x00, 0x00, 0x03,
0x05, 0x06, 0x07,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x01,
0x08,
}
b := bytes.NewBuffer(data)
r := NewLengthDelimitedFrameReader(ioutil.NopCloser(b))
buf := make([]byte, 40)
if n, err := r.Read(buf); err != nil && n != 4 && bytes.Equal(buf, []byte{0x01, 0x02, 0x03, 0x04}) {
t.Fatalf("unexpected: %v %d %v", err, n, buf)
}
if n, err := r.Read(buf); err != nil && n != 3 && bytes.Equal(buf, []byte{0x05, 0x06, 0x7}) {
t.Fatalf("unexpected: %v %d %v", err, n, buf)
}
if n, err := r.Read(buf); err != nil && n != 0 && bytes.Equal(buf, []byte{}) {
t.Fatalf("unexpected: %v %d %v", err, n, buf)
}
if n, err := r.Read(buf); err != nil && n != 1 && bytes.Equal(buf, []byte{0x08}) {
t.Fatalf("unexpected: %v %d %v", err, n, buf)
}
// read EOF
if n, err := r.Read(buf); err != io.EOF && n != 0 {
t.Fatalf("unexpected: %v %d", err, n)
}
}
func TestReadInvalidFrame(t *testing.T) {
data := []byte{
0x00, 0x00, 0x00, 0x04,
0x01, 0x02,
}
b := bytes.NewBuffer(data)
r := NewLengthDelimitedFrameReader(ioutil.NopCloser(b))
buf := make([]byte, 1)
if n, err := r.Read(buf); err != io.ErrShortBuffer && n != 1 && bytes.Equal(buf, []byte{0x01}) {
t.Fatalf("unexpected: %v %d %v", err, n, buf)
}
// read the remaining frame
buf = make([]byte, 3)
if n, err := r.Read(buf); err != io.ErrUnexpectedEOF && n != 1 && bytes.Equal(buf, []byte{0x02}) {
t.Fatalf("unexpected: %v %d %v", err, n, buf)
}
// read EOF
if n, err := r.Read(buf); err != io.EOF && n != 0 {
t.Fatalf("unexpected: %v %d", err, n)
}
}
func TestJSONFrameReader(t *testing.T) {
b := bytes.NewBufferString("{\"test\":true}\n1\n[\"a\"]")
r := NewJSONFramedReader(ioutil.NopCloser(b))
buf := make([]byte, 20)
if n, err := r.Read(buf); err != nil || n != 13 || string(buf[:n]) != `{"test":true}` {
t.Fatalf("unexpected: %v %d %q", err, n, buf)
}
if n, err := r.Read(buf); err != nil || n != 1 || string(buf[:n]) != `1` {
t.Fatalf("unexpected: %v %d %q", err, n, buf)
}
if n, err := r.Read(buf); err != nil || n != 5 || string(buf[:n]) != `["a"]` {
t.Fatalf("unexpected: %v %d %q", err, n, buf)
}
if n, err := r.Read(buf); err != io.EOF || n != 0 {
t.Fatalf("unexpected: %v %d %q", err, n, buf)
}
}
func TestJSONFrameReaderShortBuffer(t *testing.T) {
b := bytes.NewBufferString("{\"test\":true}\n1\n[\"a\"]")
r := NewJSONFramedReader(ioutil.NopCloser(b))
buf := make([]byte, 3)
if n, err := r.Read(buf); err != io.ErrShortBuffer || n != 3 || string(buf[:n]) != `{"t` {
t.Fatalf("unexpected: %v %d %q", err, n, buf)
}
if n, err := r.Read(buf); err != io.ErrShortBuffer || n != 3 || string(buf[:n]) != `est` {
t.Fatalf("unexpected: %v %d %q", err, n, buf)
}
if n, err := r.Read(buf); err != io.ErrShortBuffer || n != 3 || string(buf[:n]) != `":t` {
t.Fatalf("unexpected: %v %d %q", err, n, buf)
}
if n, err := r.Read(buf); err != io.ErrShortBuffer || n != 3 || string(buf[:n]) != `rue` {
t.Fatalf("unexpected: %v %d %q", err, n, buf)
}
if n, err := r.Read(buf); err != nil || n != 1 || string(buf[:n]) != `}` {
t.Fatalf("unexpected: %v %d %q", err, n, buf)
}
if n, err := r.Read(buf); err != nil || n != 1 || string(buf[:n]) != `1` {
t.Fatalf("unexpected: %v %d %q", err, n, buf)
}
if n, err := r.Read(buf); err != io.ErrShortBuffer || n != 3 || string(buf[:n]) != `["a` {
t.Fatalf("unexpected: %v %d %q", err, n, buf)
}
if n, err := r.Read(buf); err != nil || n != 2 || string(buf[:n]) != `"]` {
t.Fatalf("unexpected: %v %d %q", err, n, buf)
}
if n, err := r.Read(buf); err != io.EOF || n != 0 {
t.Fatalf("unexpected: %v %d %q", err, n, buf)
}
}

View File

@ -1,19 +0,0 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package httpstream adds multiplexed streaming support to HTTP requests and
// responses via connection upgrades.
package httpstream // import "k8s.io/apimachinery/pkg/util/httpstream"

View File

@ -1,149 +0,0 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package httpstream
import (
"fmt"
"io"
"net/http"
"strings"
"time"
)
const (
HeaderConnection = "Connection"
HeaderUpgrade = "Upgrade"
HeaderProtocolVersion = "X-Stream-Protocol-Version"
HeaderAcceptedProtocolVersions = "X-Accepted-Stream-Protocol-Versions"
)
// NewStreamHandler defines a function that is called when a new Stream is
// received. If no error is returned, the Stream is accepted; otherwise,
// the stream is rejected. After the reply frame has been sent, replySent is closed.
type NewStreamHandler func(stream Stream, replySent <-chan struct{}) error
// NoOpNewStreamHandler is a stream handler that accepts a new stream and
// performs no other logic.
func NoOpNewStreamHandler(stream Stream, replySent <-chan struct{}) error { return nil }
// Dialer knows how to open a streaming connection to a server.
type Dialer interface {
// Dial opens a streaming connection to a server using one of the protocols
// specified (in order of most preferred to least preferred).
Dial(protocols ...string) (Connection, string, error)
}
// UpgradeRoundTripper is a type of http.RoundTripper that is able to upgrade
// HTTP requests to support multiplexed bidirectional streams. After RoundTrip()
// is invoked, if the upgrade is successful, clients may retrieve the upgraded
// connection by calling UpgradeRoundTripper.Connection().
type UpgradeRoundTripper interface {
http.RoundTripper
// NewConnection validates the response and creates a new Connection.
NewConnection(resp *http.Response) (Connection, error)
}
// ResponseUpgrader knows how to upgrade HTTP requests and responses to
// add streaming support to them.
type ResponseUpgrader interface {
// UpgradeResponse upgrades an HTTP response to one that supports multiplexed
// streams. newStreamHandler will be called asynchronously whenever the
// other end of the upgraded connection creates a new stream.
UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler NewStreamHandler) Connection
}
// Connection represents an upgraded HTTP connection.
type Connection interface {
// CreateStream creates a new Stream with the supplied headers.
CreateStream(headers http.Header) (Stream, error)
// Close resets all streams and closes the connection.
Close() error
// CloseChan returns a channel that is closed when the underlying connection is closed.
CloseChan() <-chan bool
// SetIdleTimeout sets the amount of time the connection may remain idle before
// it is automatically closed.
SetIdleTimeout(timeout time.Duration)
}
// Stream represents a bidirectional communications channel that is part of an
// upgraded connection.
type Stream interface {
io.ReadWriteCloser
// Reset closes both directions of the stream, indicating that neither client
// or server can use it any more.
Reset() error
// Headers returns the headers used to create the stream.
Headers() http.Header
// Identifier returns the stream's ID.
Identifier() uint32
}
// IsUpgradeRequest returns true if the given request is a connection upgrade request
func IsUpgradeRequest(req *http.Request) bool {
for _, h := range req.Header[http.CanonicalHeaderKey(HeaderConnection)] {
if strings.Contains(strings.ToLower(h), strings.ToLower(HeaderUpgrade)) {
return true
}
}
return false
}
func negotiateProtocol(clientProtocols, serverProtocols []string) string {
for i := range clientProtocols {
for j := range serverProtocols {
if clientProtocols[i] == serverProtocols[j] {
return clientProtocols[i]
}
}
}
return ""
}
// Handshake performs a subprotocol negotiation. If the client did request a
// subprotocol, Handshake will select the first common value found in
// serverProtocols. If a match is found, Handshake adds a response header
// indicating the chosen subprotocol. If no match is found, HTTP forbidden is
// returned, along with a response header containing the list of protocols the
// server can accept.
func Handshake(req *http.Request, w http.ResponseWriter, serverProtocols []string) (string, error) {
clientProtocols := req.Header[http.CanonicalHeaderKey(HeaderProtocolVersion)]
if len(clientProtocols) == 0 {
// Kube 1.0 clients didn't support subprotocol negotiation.
// TODO require clientProtocols once Kube 1.0 is no longer supported
return "", nil
}
if len(serverProtocols) == 0 {
// Kube 1.0 servers didn't support subprotocol negotiation. This is mainly for testing.
// TODO require serverProtocols once Kube 1.0 is no longer supported
return "", nil
}
negotiatedProtocol := negotiateProtocol(clientProtocols, serverProtocols)
if len(negotiatedProtocol) == 0 {
for i := range serverProtocols {
w.Header().Add(HeaderAcceptedProtocolVersions, serverProtocols[i])
}
err := fmt.Errorf("unable to upgrade: unable to negotiate protocol: client supports %v, server accepts %v", clientProtocols, serverProtocols)
http.Error(w, err.Error(), http.StatusForbidden)
return "", err
}
w.Header().Add(HeaderProtocolVersion, negotiatedProtocol)
return negotiatedProtocol, nil
}

View File

@ -1,125 +0,0 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package httpstream
import (
"net/http"
"reflect"
"testing"
)
type responseWriter struct {
header http.Header
statusCode *int
}
func newResponseWriter() *responseWriter {
return &responseWriter{
header: make(http.Header),
}
}
func (r *responseWriter) Header() http.Header {
return r.header
}
func (r *responseWriter) WriteHeader(code int) {
r.statusCode = &code
}
func (r *responseWriter) Write([]byte) (int, error) {
return 0, nil
}
func TestHandshake(t *testing.T) {
tests := map[string]struct {
clientProtocols []string
serverProtocols []string
expectedProtocol string
expectError bool
}{
"no client protocols": {
clientProtocols: []string{},
serverProtocols: []string{"a", "b"},
expectedProtocol: "",
},
"no common protocol": {
clientProtocols: []string{"c"},
serverProtocols: []string{"a", "b"},
expectedProtocol: "",
expectError: true,
},
"common protocol": {
clientProtocols: []string{"b"},
serverProtocols: []string{"a", "b"},
expectedProtocol: "b",
},
}
for name, test := range tests {
req, err := http.NewRequest("GET", "http://www.example.com/", nil)
if err != nil {
t.Fatalf("%s: error creating request: %v", name, err)
}
for _, p := range test.clientProtocols {
req.Header.Add(HeaderProtocolVersion, p)
}
w := newResponseWriter()
negotiated, err := Handshake(req, w, test.serverProtocols)
// verify negotiated protocol
if e, a := test.expectedProtocol, negotiated; e != a {
t.Errorf("%s: protocol: expected %q, got %q", name, e, a)
}
if test.expectError {
if err == nil {
t.Errorf("%s: expected error but did not get one", name)
}
if w.statusCode == nil {
t.Errorf("%s: expected w.statusCode to be set", name)
} else if e, a := http.StatusForbidden, *w.statusCode; e != a {
t.Errorf("%s: w.statusCode: expected %d, got %d", name, e, a)
}
if e, a := test.serverProtocols, w.Header()[HeaderAcceptedProtocolVersions]; !reflect.DeepEqual(e, a) {
t.Errorf("%s: accepted server protocols: expected %v, got %v", name, e, a)
}
continue
}
if !test.expectError && err != nil {
t.Errorf("%s: unexpected error: %v", name, err)
continue
}
if w.statusCode != nil {
t.Errorf("%s: unexpected non-nil w.statusCode: %d", name, w.statusCode)
}
if len(test.expectedProtocol) == 0 {
if len(w.Header()[HeaderProtocolVersion]) > 0 {
t.Errorf("%s: unexpected protocol version response header: %s", name, w.Header()[HeaderProtocolVersion])
}
continue
}
// verify response headers
if e, a := []string{test.expectedProtocol}, w.Header()[HeaderProtocolVersion]; !reflect.DeepEqual(e, a) {
t.Errorf("%s: protocol response header: expected %v, got %v", name, e, a)
}
}
}

View File

@ -1,145 +0,0 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package spdy
import (
"net"
"net/http"
"sync"
"time"
"github.com/docker/spdystream"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/klog"
)
// connection maintains state about a spdystream.Connection and its associated
// streams.
type connection struct {
conn *spdystream.Connection
streams []httpstream.Stream
streamLock sync.Mutex
newStreamHandler httpstream.NewStreamHandler
}
// NewClientConnection creates a new SPDY client connection.
func NewClientConnection(conn net.Conn) (httpstream.Connection, error) {
spdyConn, err := spdystream.NewConnection(conn, false)
if err != nil {
defer conn.Close()
return nil, err
}
return newConnection(spdyConn, httpstream.NoOpNewStreamHandler), nil
}
// NewServerConnection creates a new SPDY server connection. newStreamHandler
// will be invoked when the server receives a newly created stream from the
// client.
func NewServerConnection(conn net.Conn, newStreamHandler httpstream.NewStreamHandler) (httpstream.Connection, error) {
spdyConn, err := spdystream.NewConnection(conn, true)
if err != nil {
defer conn.Close()
return nil, err
}
return newConnection(spdyConn, newStreamHandler), nil
}
// newConnection returns a new connection wrapping conn. newStreamHandler
// will be invoked when the server receives a newly created stream from the
// client.
func newConnection(conn *spdystream.Connection, newStreamHandler httpstream.NewStreamHandler) httpstream.Connection {
c := &connection{conn: conn, newStreamHandler: newStreamHandler}
go conn.Serve(c.newSpdyStream)
return c
}
// createStreamResponseTimeout indicates how long to wait for the other side to
// acknowledge the new stream before timing out.
const createStreamResponseTimeout = 30 * time.Second
// Close first sends a reset for all of the connection's streams, and then
// closes the underlying spdystream.Connection.
func (c *connection) Close() error {
c.streamLock.Lock()
for _, s := range c.streams {
// calling Reset instead of Close ensures that all streams are fully torn down
s.Reset()
}
c.streams = make([]httpstream.Stream, 0)
c.streamLock.Unlock()
// now that all streams are fully torn down, it's safe to call close on the underlying connection,
// which should be able to terminate immediately at this point, instead of waiting for any
// remaining graceful stream termination.
return c.conn.Close()
}
// CreateStream creates a new stream with the specified headers and registers
// it with the connection.
func (c *connection) CreateStream(headers http.Header) (httpstream.Stream, error) {
stream, err := c.conn.CreateStream(headers, nil, false)
if err != nil {
return nil, err
}
if err = stream.WaitTimeout(createStreamResponseTimeout); err != nil {
return nil, err
}
c.registerStream(stream)
return stream, nil
}
// registerStream adds the stream s to the connection's list of streams that
// it owns.
func (c *connection) registerStream(s httpstream.Stream) {
c.streamLock.Lock()
c.streams = append(c.streams, s)
c.streamLock.Unlock()
}
// CloseChan returns a channel that, when closed, indicates that the underlying
// spdystream.Connection has been closed.
func (c *connection) CloseChan() <-chan bool {
return c.conn.CloseChan()
}
// newSpdyStream is the internal new stream handler used by spdystream.Connection.Serve.
// It calls connection's newStreamHandler, giving it the opportunity to accept or reject
// the stream. If newStreamHandler returns an error, the stream is rejected. If not, the
// stream is accepted and registered with the connection.
func (c *connection) newSpdyStream(stream *spdystream.Stream) {
replySent := make(chan struct{})
err := c.newStreamHandler(stream, replySent)
rejectStream := (err != nil)
if rejectStream {
klog.Warningf("Stream rejected: %v", err)
stream.Reset()
return
}
c.registerStream(stream)
stream.SendReply(http.Header{}, rejectStream)
close(replySent)
}
// SetIdleTimeout sets the amount of time the connection may remain idle before
// it is automatically closed.
func (c *connection) SetIdleTimeout(timeout time.Duration) {
c.conn.SetIdleTimeout(timeout)
}

View File

@ -1,164 +0,0 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package spdy
import (
"io"
"net"
"net/http"
"sync"
"testing"
"time"
"k8s.io/apimachinery/pkg/util/httpstream"
)
func runProxy(t *testing.T, backendUrl string, proxyUrl chan<- string, proxyDone chan<- struct{}) {
listener, err := net.Listen("tcp4", "localhost:0")
if err != nil {
t.Fatalf("error listening: %v", err)
}
defer listener.Close()
proxyUrl <- listener.Addr().String()
clientConn, err := listener.Accept()
if err != nil {
t.Errorf("proxy: error accepting client connection: %v", err)
return
}
backendConn, err := net.Dial("tcp4", backendUrl)
if err != nil {
t.Errorf("proxy: error dialing backend: %v", err)
return
}
defer backendConn.Close()
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
io.Copy(backendConn, clientConn)
}()
go func() {
defer wg.Done()
io.Copy(clientConn, backendConn)
}()
wg.Wait()
proxyDone <- struct{}{}
}
func runServer(t *testing.T, backendUrl chan<- string, serverDone chan<- struct{}) {
listener, err := net.Listen("tcp4", "localhost:0")
if err != nil {
t.Fatalf("server: error listening: %v", err)
}
defer listener.Close()
backendUrl <- listener.Addr().String()
conn, err := listener.Accept()
if err != nil {
t.Errorf("server: error accepting connection: %v", err)
return
}
streamChan := make(chan httpstream.Stream)
replySentChan := make(chan (<-chan struct{}))
spdyConn, err := NewServerConnection(conn, func(stream httpstream.Stream, replySent <-chan struct{}) error {
streamChan <- stream
replySentChan <- replySent
return nil
})
if err != nil {
t.Errorf("server: error creating spdy connection: %v", err)
return
}
stream := <-streamChan
replySent := <-replySentChan
<-replySent
buf := make([]byte, 1)
_, err = stream.Read(buf)
if err != io.EOF {
t.Errorf("server: unexpected read error: %v", err)
return
}
<-spdyConn.CloseChan()
raw := spdyConn.(*connection).conn
if err := raw.Wait(15 * time.Second); err != nil {
t.Errorf("server: timed out waiting for connection closure: %v", err)
}
serverDone <- struct{}{}
}
func TestConnectionCloseIsImmediateThroughAProxy(t *testing.T) {
serverDone := make(chan struct{})
backendUrlChan := make(chan string)
go runServer(t, backendUrlChan, serverDone)
backendUrl := <-backendUrlChan
proxyDone := make(chan struct{})
proxyUrlChan := make(chan string)
go runProxy(t, backendUrl, proxyUrlChan, proxyDone)
proxyUrl := <-proxyUrlChan
conn, err := net.Dial("tcp4", proxyUrl)
if err != nil {
t.Fatalf("client: error connecting to proxy: %v", err)
}
spdyConn, err := NewClientConnection(conn)
if err != nil {
t.Fatalf("client: error creating spdy connection: %v", err)
}
if _, err := spdyConn.CreateStream(http.Header{}); err != nil {
t.Fatalf("client: error creating stream: %v", err)
}
spdyConn.Close()
raw := spdyConn.(*connection).conn
if err := raw.Wait(15 * time.Second); err != nil {
t.Fatalf("client: timed out waiting for connection closure: %v", err)
}
expired := time.NewTimer(15 * time.Second)
defer expired.Stop()
i := 0
for {
select {
case <-expired.C:
t.Fatalf("timed out waiting for proxy and/or server closure")
case <-serverDone:
i++
case <-proxyDone:
i++
}
if i == 2 {
break
}
}
}

View File

@ -1,335 +0,0 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package spdy
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httputil"
"net/url"
"strings"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/serializer"
"k8s.io/apimachinery/pkg/util/httpstream"
utilnet "k8s.io/apimachinery/pkg/util/net"
"k8s.io/apimachinery/third_party/forked/golang/netutil"
)
// SpdyRoundTripper knows how to upgrade an HTTP request to one that supports
// multiplexed streams. After RoundTrip() is invoked, Conn will be set
// and usable. SpdyRoundTripper implements the UpgradeRoundTripper interface.
type SpdyRoundTripper struct {
//tlsConfig holds the TLS configuration settings to use when connecting
//to the remote server.
tlsConfig *tls.Config
/* TODO according to http://golang.org/pkg/net/http/#RoundTripper, a RoundTripper
must be safe for use by multiple concurrent goroutines. If this is absolutely
necessary, we could keep a map from http.Request to net.Conn. In practice,
a client will create an http.Client, set the transport to a new insteace of
SpdyRoundTripper, and use it a single time, so this hopefully won't be an issue.
*/
// conn is the underlying network connection to the remote server.
conn net.Conn
// Dialer is the dialer used to connect. Used if non-nil.
Dialer *net.Dialer
// proxier knows which proxy to use given a request, defaults to http.ProxyFromEnvironment
// Used primarily for mocking the proxy discovery in tests.
proxier func(req *http.Request) (*url.URL, error)
// followRedirects indicates if the round tripper should examine responses for redirects and
// follow them.
followRedirects bool
// requireSameHostRedirects restricts redirect following to only follow redirects to the same host
// as the original request.
requireSameHostRedirects bool
}
var _ utilnet.TLSClientConfigHolder = &SpdyRoundTripper{}
var _ httpstream.UpgradeRoundTripper = &SpdyRoundTripper{}
var _ utilnet.Dialer = &SpdyRoundTripper{}
// NewRoundTripper creates a new SpdyRoundTripper that will use
// the specified tlsConfig.
func NewRoundTripper(tlsConfig *tls.Config, followRedirects, requireSameHostRedirects bool) httpstream.UpgradeRoundTripper {
return NewSpdyRoundTripper(tlsConfig, followRedirects, requireSameHostRedirects)
}
// NewSpdyRoundTripper creates a new SpdyRoundTripper that will use
// the specified tlsConfig. This function is mostly meant for unit tests.
func NewSpdyRoundTripper(tlsConfig *tls.Config, followRedirects, requireSameHostRedirects bool) *SpdyRoundTripper {
return &SpdyRoundTripper{
tlsConfig: tlsConfig,
followRedirects: followRedirects,
requireSameHostRedirects: requireSameHostRedirects,
}
}
// TLSClientConfig implements pkg/util/net.TLSClientConfigHolder for proper TLS checking during
// proxying with a spdy roundtripper.
func (s *SpdyRoundTripper) TLSClientConfig() *tls.Config {
return s.tlsConfig
}
// Dial implements k8s.io/apimachinery/pkg/util/net.Dialer.
func (s *SpdyRoundTripper) Dial(req *http.Request) (net.Conn, error) {
conn, err := s.dial(req)
if err != nil {
return nil, err
}
if err := req.Write(conn); err != nil {
conn.Close()
return nil, err
}
return conn, nil
}
// dial dials the host specified by req, using TLS if appropriate, optionally
// using a proxy server if one is configured via environment variables.
func (s *SpdyRoundTripper) dial(req *http.Request) (net.Conn, error) {
proxier := s.proxier
if proxier == nil {
proxier = utilnet.NewProxierWithNoProxyCIDR(http.ProxyFromEnvironment)
}
proxyURL, err := proxier(req)
if err != nil {
return nil, err
}
if proxyURL == nil {
return s.dialWithoutProxy(req.Context(), req.URL)
}
// ensure we use a canonical host with proxyReq
targetHost := netutil.CanonicalAddr(req.URL)
// proxying logic adapted from http://blog.h6t.eu/post/74098062923/golang-websocket-with-http-proxy-support
proxyReq := http.Request{
Method: "CONNECT",
URL: &url.URL{},
Host: targetHost,
}
if pa := s.proxyAuth(proxyURL); pa != "" {
proxyReq.Header = http.Header{}
proxyReq.Header.Set("Proxy-Authorization", pa)
}
proxyDialConn, err := s.dialWithoutProxy(req.Context(), proxyURL)
if err != nil {
return nil, err
}
proxyClientConn := httputil.NewProxyClientConn(proxyDialConn, nil)
_, err = proxyClientConn.Do(&proxyReq)
if err != nil && err != httputil.ErrPersistEOF {
return nil, err
}
rwc, _ := proxyClientConn.Hijack()
if req.URL.Scheme != "https" {
return rwc, nil
}
host, _, err := net.SplitHostPort(targetHost)
if err != nil {
return nil, err
}
tlsConfig := s.tlsConfig
switch {
case tlsConfig == nil:
tlsConfig = &tls.Config{ServerName: host}
case len(tlsConfig.ServerName) == 0:
tlsConfig = tlsConfig.Clone()
tlsConfig.ServerName = host
}
tlsConn := tls.Client(rwc, tlsConfig)
// need to manually call Handshake() so we can call VerifyHostname() below
if err := tlsConn.Handshake(); err != nil {
return nil, err
}
// Return if we were configured to skip validation
if tlsConfig.InsecureSkipVerify {
return tlsConn, nil
}
if err := tlsConn.VerifyHostname(tlsConfig.ServerName); err != nil {
return nil, err
}
return tlsConn, nil
}
// dialWithoutProxy dials the host specified by url, using TLS if appropriate.
func (s *SpdyRoundTripper) dialWithoutProxy(ctx context.Context, url *url.URL) (net.Conn, error) {
dialAddr := netutil.CanonicalAddr(url)
if url.Scheme == "http" {
if s.Dialer == nil {
var d net.Dialer
return d.DialContext(ctx, "tcp", dialAddr)
} else {
return s.Dialer.DialContext(ctx, "tcp", dialAddr)
}
}
// TODO validate the TLSClientConfig is set up?
var conn *tls.Conn
var err error
if s.Dialer == nil {
conn, err = tls.Dial("tcp", dialAddr, s.tlsConfig)
} else {
conn, err = tls.DialWithDialer(s.Dialer, "tcp", dialAddr, s.tlsConfig)
}
if err != nil {
return nil, err
}
// Return if we were configured to skip validation
if s.tlsConfig != nil && s.tlsConfig.InsecureSkipVerify {
return conn, nil
}
host, _, err := net.SplitHostPort(dialAddr)
if err != nil {
return nil, err
}
if s.tlsConfig != nil && len(s.tlsConfig.ServerName) > 0 {
host = s.tlsConfig.ServerName
}
err = conn.VerifyHostname(host)
if err != nil {
return nil, err
}
return conn, nil
}
// proxyAuth returns, for a given proxy URL, the value to be used for the Proxy-Authorization header
func (s *SpdyRoundTripper) proxyAuth(proxyURL *url.URL) string {
if proxyURL == nil || proxyURL.User == nil {
return ""
}
credentials := proxyURL.User.String()
encodedAuth := base64.StdEncoding.EncodeToString([]byte(credentials))
return fmt.Sprintf("Basic %s", encodedAuth)
}
// RoundTrip executes the Request and upgrades it. After a successful upgrade,
// clients may call SpdyRoundTripper.Connection() to retrieve the upgraded
// connection.
func (s *SpdyRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
header := utilnet.CloneHeader(req.Header)
header.Add(httpstream.HeaderConnection, httpstream.HeaderUpgrade)
header.Add(httpstream.HeaderUpgrade, HeaderSpdy31)
var (
conn net.Conn
rawResponse []byte
err error
)
if s.followRedirects {
conn, rawResponse, err = utilnet.ConnectWithRedirects(req.Method, req.URL, header, req.Body, s, s.requireSameHostRedirects)
} else {
clone := utilnet.CloneRequest(req)
clone.Header = header
conn, err = s.Dial(clone)
}
if err != nil {
return nil, err
}
responseReader := bufio.NewReader(
io.MultiReader(
bytes.NewBuffer(rawResponse),
conn,
),
)
resp, err := http.ReadResponse(responseReader, nil)
if err != nil {
if conn != nil {
conn.Close()
}
return nil, err
}
s.conn = conn
return resp, nil
}
// NewConnection validates the upgrade response, creating and returning a new
// httpstream.Connection if there were no errors.
func (s *SpdyRoundTripper) NewConnection(resp *http.Response) (httpstream.Connection, error) {
connectionHeader := strings.ToLower(resp.Header.Get(httpstream.HeaderConnection))
upgradeHeader := strings.ToLower(resp.Header.Get(httpstream.HeaderUpgrade))
if (resp.StatusCode != http.StatusSwitchingProtocols) || !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(HeaderSpdy31)) {
defer resp.Body.Close()
responseError := ""
responseErrorBytes, err := ioutil.ReadAll(resp.Body)
if err != nil {
responseError = "unable to read error from server response"
} else {
// TODO: I don't belong here, I should be abstracted from this class
if obj, _, err := statusCodecs.UniversalDecoder().Decode(responseErrorBytes, nil, &metav1.Status{}); err == nil {
if status, ok := obj.(*metav1.Status); ok {
return nil, &apierrors.StatusError{ErrStatus: *status}
}
}
responseError = string(responseErrorBytes)
responseError = strings.TrimSpace(responseError)
}
return nil, fmt.Errorf("unable to upgrade connection: %s", responseError)
}
return NewClientConnection(s.conn)
}
// statusScheme is private scheme for the decoding here until someone fixes the TODO in NewConnection
var statusScheme = runtime.NewScheme()
// ParameterCodec knows about query parameters used with the meta v1 API spec.
var statusCodecs = serializer.NewCodecFactory(statusScheme)
func init() {
statusScheme.AddUnversionedTypes(metav1.SchemeGroupVersion,
&metav1.Status{},
)
}

View File

@ -1,529 +0,0 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package spdy
import (
"crypto/tls"
"crypto/x509"
"encoding/base64"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync/atomic"
"testing"
"github.com/elazarl/goproxy"
"k8s.io/apimachinery/pkg/util/httpstream"
)
// be sure to unset environment variable https_proxy (if exported) before testing, otherwise the testing will fail unexpectedly.
func TestRoundTripAndNewConnection(t *testing.T) {
for _, redirect := range []bool{false, true} {
t.Run(fmt.Sprintf("redirect = %t", redirect), func(t *testing.T) {
localhostPool := x509.NewCertPool()
if !localhostPool.AppendCertsFromPEM(localhostCert) {
t.Errorf("error setting up localhostCert pool")
}
httpsServerInvalidHostname := func(h http.Handler) *httptest.Server {
cert, err := tls.X509KeyPair(exampleCert, exampleKey)
if err != nil {
t.Errorf("https (invalid hostname): proxy_test: %v", err)
}
ts := httptest.NewUnstartedServer(h)
ts.TLS = &tls.Config{
Certificates: []tls.Certificate{cert},
}
ts.StartTLS()
return ts
}
httpsServerValidHostname := func(h http.Handler) *httptest.Server {
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
if err != nil {
t.Errorf("https (valid hostname): proxy_test: %v", err)
}
ts := httptest.NewUnstartedServer(h)
ts.TLS = &tls.Config{
Certificates: []tls.Certificate{cert},
}
ts.StartTLS()
return ts
}
testCases := map[string]struct {
serverFunc func(http.Handler) *httptest.Server
proxyServerFunc func(http.Handler) *httptest.Server
proxyAuth *url.Userinfo
clientTLS *tls.Config
serverConnectionHeader string
serverUpgradeHeader string
serverStatusCode int
shouldError bool
}{
"no headers": {
serverFunc: httptest.NewServer,
serverConnectionHeader: "",
serverUpgradeHeader: "",
serverStatusCode: http.StatusSwitchingProtocols,
shouldError: true,
},
"no upgrade header": {
serverFunc: httptest.NewServer,
serverConnectionHeader: "Upgrade",
serverUpgradeHeader: "",
serverStatusCode: http.StatusSwitchingProtocols,
shouldError: true,
},
"no connection header": {
serverFunc: httptest.NewServer,
serverConnectionHeader: "",
serverUpgradeHeader: "SPDY/3.1",
serverStatusCode: http.StatusSwitchingProtocols,
shouldError: true,
},
"no switching protocol status code": {
serverFunc: httptest.NewServer,
serverConnectionHeader: "Upgrade",
serverUpgradeHeader: "SPDY/3.1",
serverStatusCode: http.StatusForbidden,
shouldError: true,
},
"http": {
serverFunc: httptest.NewServer,
serverConnectionHeader: "Upgrade",
serverUpgradeHeader: "SPDY/3.1",
serverStatusCode: http.StatusSwitchingProtocols,
shouldError: false,
},
"https (invalid hostname + InsecureSkipVerify)": {
serverFunc: httpsServerInvalidHostname,
clientTLS: &tls.Config{InsecureSkipVerify: true},
serverConnectionHeader: "Upgrade",
serverUpgradeHeader: "SPDY/3.1",
serverStatusCode: http.StatusSwitchingProtocols,
shouldError: false,
},
"https (invalid hostname + hostname verification)": {
serverFunc: httpsServerInvalidHostname,
clientTLS: &tls.Config{InsecureSkipVerify: false},
serverConnectionHeader: "Upgrade",
serverUpgradeHeader: "SPDY/3.1",
serverStatusCode: http.StatusSwitchingProtocols,
shouldError: true,
},
"https (valid hostname + RootCAs)": {
serverFunc: httpsServerValidHostname,
clientTLS: &tls.Config{RootCAs: localhostPool},
serverConnectionHeader: "Upgrade",
serverUpgradeHeader: "SPDY/3.1",
serverStatusCode: http.StatusSwitchingProtocols,
shouldError: false,
},
"proxied http->http": {
serverFunc: httptest.NewServer,
proxyServerFunc: httptest.NewServer,
serverConnectionHeader: "Upgrade",
serverUpgradeHeader: "SPDY/3.1",
serverStatusCode: http.StatusSwitchingProtocols,
shouldError: false,
},
"proxied https (invalid hostname + InsecureSkipVerify) -> http": {
serverFunc: httptest.NewServer,
proxyServerFunc: httpsServerInvalidHostname,
clientTLS: &tls.Config{InsecureSkipVerify: true},
serverConnectionHeader: "Upgrade",
serverUpgradeHeader: "SPDY/3.1",
serverStatusCode: http.StatusSwitchingProtocols,
shouldError: false,
},
"proxied https with auth (invalid hostname + InsecureSkipVerify) -> http": {
serverFunc: httptest.NewServer,
proxyServerFunc: httpsServerInvalidHostname,
proxyAuth: url.UserPassword("proxyuser", "proxypasswd"),
clientTLS: &tls.Config{InsecureSkipVerify: true},
serverConnectionHeader: "Upgrade",
serverUpgradeHeader: "SPDY/3.1",
serverStatusCode: http.StatusSwitchingProtocols,
shouldError: false,
},
"proxied https (invalid hostname + hostname verification) -> http": {
serverFunc: httptest.NewServer,
proxyServerFunc: httpsServerInvalidHostname,
clientTLS: &tls.Config{InsecureSkipVerify: false},
serverConnectionHeader: "Upgrade",
serverUpgradeHeader: "SPDY/3.1",
serverStatusCode: http.StatusSwitchingProtocols,
shouldError: true, // fails because the client doesn't trust the proxy
},
"proxied https (valid hostname + RootCAs) -> http": {
serverFunc: httptest.NewServer,
proxyServerFunc: httpsServerValidHostname,
clientTLS: &tls.Config{RootCAs: localhostPool},
serverConnectionHeader: "Upgrade",
serverUpgradeHeader: "SPDY/3.1",
serverStatusCode: http.StatusSwitchingProtocols,
shouldError: false,
},
"proxied https with auth (valid hostname + RootCAs) -> http": {
serverFunc: httptest.NewServer,
proxyServerFunc: httpsServerValidHostname,
proxyAuth: url.UserPassword("proxyuser", "proxypasswd"),
clientTLS: &tls.Config{RootCAs: localhostPool},
serverConnectionHeader: "Upgrade",
serverUpgradeHeader: "SPDY/3.1",
serverStatusCode: http.StatusSwitchingProtocols,
shouldError: false,
},
"proxied https (invalid hostname + InsecureSkipVerify) -> https (invalid hostname)": {
serverFunc: httpsServerInvalidHostname,
proxyServerFunc: httpsServerInvalidHostname,
clientTLS: &tls.Config{InsecureSkipVerify: true},
serverConnectionHeader: "Upgrade",
serverUpgradeHeader: "SPDY/3.1",
serverStatusCode: http.StatusSwitchingProtocols,
shouldError: false, // works because the test proxy ignores TLS errors
},
"proxied https with auth (invalid hostname + InsecureSkipVerify) -> https (invalid hostname)": {
serverFunc: httpsServerInvalidHostname,
proxyServerFunc: httpsServerInvalidHostname,
proxyAuth: url.UserPassword("proxyuser", "proxypasswd"),
clientTLS: &tls.Config{InsecureSkipVerify: true},
serverConnectionHeader: "Upgrade",
serverUpgradeHeader: "SPDY/3.1",
serverStatusCode: http.StatusSwitchingProtocols,
shouldError: false, // works because the test proxy ignores TLS errors
},
"proxied https (invalid hostname + hostname verification) -> https (invalid hostname)": {
serverFunc: httpsServerInvalidHostname,
proxyServerFunc: httpsServerInvalidHostname,
clientTLS: &tls.Config{InsecureSkipVerify: false},
serverConnectionHeader: "Upgrade",
serverUpgradeHeader: "SPDY/3.1",
serverStatusCode: http.StatusSwitchingProtocols,
shouldError: true, // fails because the client doesn't trust the proxy
},
"proxied https (valid hostname + RootCAs) -> https (valid hostname + RootCAs)": {
serverFunc: httpsServerValidHostname,
proxyServerFunc: httpsServerValidHostname,
clientTLS: &tls.Config{RootCAs: localhostPool},
serverConnectionHeader: "Upgrade",
serverUpgradeHeader: "SPDY/3.1",
serverStatusCode: http.StatusSwitchingProtocols,
shouldError: false,
},
"proxied https with auth (valid hostname + RootCAs) -> https (valid hostname + RootCAs)": {
serverFunc: httpsServerValidHostname,
proxyServerFunc: httpsServerValidHostname,
proxyAuth: url.UserPassword("proxyuser", "proxypasswd"),
clientTLS: &tls.Config{RootCAs: localhostPool},
serverConnectionHeader: "Upgrade",
serverUpgradeHeader: "SPDY/3.1",
serverStatusCode: http.StatusSwitchingProtocols,
shouldError: false,
},
}
for k, testCase := range testCases {
server := testCase.serverFunc(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if testCase.shouldError {
if e, a := httpstream.HeaderUpgrade, req.Header.Get(httpstream.HeaderConnection); e != a {
t.Fatalf("%s: Expected connection=upgrade header, got '%s", k, a)
}
w.Header().Set(httpstream.HeaderConnection, testCase.serverConnectionHeader)
w.Header().Set(httpstream.HeaderUpgrade, testCase.serverUpgradeHeader)
w.WriteHeader(testCase.serverStatusCode)
return
}
streamCh := make(chan httpstream.Stream)
responseUpgrader := NewResponseUpgrader()
spdyConn := responseUpgrader.UpgradeResponse(w, req, func(s httpstream.Stream, replySent <-chan struct{}) error {
streamCh <- s
return nil
})
if spdyConn == nil {
t.Fatalf("%s: unexpected nil spdyConn", k)
}
defer spdyConn.Close()
stream := <-streamCh
io.Copy(stream, stream)
}))
defer server.Close()
serverURL, err := url.Parse(server.URL)
if err != nil {
t.Fatalf("%s: Error creating request: %s", k, err)
}
req, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
t.Fatalf("%s: Error creating request: %s", k, err)
}
spdyTransport := NewSpdyRoundTripper(testCase.clientTLS, redirect, redirect)
var proxierCalled bool
var proxyCalledWithHost string
var proxyCalledWithAuth bool
var proxyCalledWithAuthHeader string
if testCase.proxyServerFunc != nil {
proxyHandler := goproxy.NewProxyHttpServer()
proxyHandler.OnRequest().HandleConnectFunc(func(host string, ctx *goproxy.ProxyCtx) (*goproxy.ConnectAction, string) {
proxyCalledWithHost = host
proxyAuthHeaderName := "Proxy-Authorization"
_, proxyCalledWithAuth = ctx.Req.Header[proxyAuthHeaderName]
proxyCalledWithAuthHeader = ctx.Req.Header.Get(proxyAuthHeaderName)
return goproxy.OkConnect, host
})
proxy := testCase.proxyServerFunc(proxyHandler)
spdyTransport.proxier = func(proxierReq *http.Request) (*url.URL, error) {
proxierCalled = true
proxyURL, err := url.Parse(proxy.URL)
if err != nil {
return nil, err
}
proxyURL.User = testCase.proxyAuth
return proxyURL, nil
}
defer proxy.Close()
}
client := &http.Client{Transport: spdyTransport}
resp, err := client.Do(req)
var conn httpstream.Connection
if err == nil {
conn, err = spdyTransport.NewConnection(resp)
}
haveErr := err != nil
if e, a := testCase.shouldError, haveErr; e != a {
t.Fatalf("%s: shouldError=%t, got %t: %v", k, e, a, err)
}
if testCase.shouldError {
continue
}
defer conn.Close()
if resp.StatusCode != http.StatusSwitchingProtocols {
t.Fatalf("%s: expected http 101 switching protocols, got %d", k, resp.StatusCode)
}
stream, err := conn.CreateStream(http.Header{})
if err != nil {
t.Fatalf("%s: error creating client stream: %s", k, err)
}
n, err := stream.Write([]byte("hello"))
if err != nil {
t.Fatalf("%s: error writing to stream: %s", k, err)
}
if n != 5 {
t.Fatalf("%s: Expected to write 5 bytes, but actually wrote %d", k, n)
}
b := make([]byte, 5)
n, err = stream.Read(b)
if err != nil {
t.Fatalf("%s: error reading from stream: %s", k, err)
}
if n != 5 {
t.Fatalf("%s: Expected to read 5 bytes, but actually read %d", k, n)
}
if e, a := "hello", string(b[0:n]); e != a {
t.Fatalf("%s: expected '%s', got '%s'", k, e, a)
}
if testCase.proxyServerFunc != nil {
if !proxierCalled {
t.Fatalf("%s: Expected to use a proxy but proxier in SpdyRoundTripper wasn't called", k)
}
if proxyCalledWithHost != serverURL.Host {
t.Fatalf("%s: Expected to see a call to the proxy for backend %q, got %q", k, serverURL.Host, proxyCalledWithHost)
}
}
var expectedProxyAuth string
if testCase.proxyAuth != nil {
encodedCredentials := base64.StdEncoding.EncodeToString([]byte(testCase.proxyAuth.String()))
expectedProxyAuth = "Basic " + encodedCredentials
}
if len(expectedProxyAuth) == 0 && proxyCalledWithAuth {
t.Fatalf("%s: Proxy authorization unexpected, got %q", k, proxyCalledWithAuthHeader)
}
if proxyCalledWithAuthHeader != expectedProxyAuth {
t.Fatalf("%s: Expected to see a call to the proxy with credentials %q, got %q", k, testCase.proxyAuth, proxyCalledWithAuthHeader)
}
}
})
}
}
func TestRoundTripRedirects(t *testing.T) {
tests := []struct {
redirects int32
expectSuccess bool
}{
{0, true},
{1, true},
{9, true},
{10, false},
}
for _, test := range tests {
t.Run(fmt.Sprintf("with %d redirects", test.redirects), func(t *testing.T) {
var redirects int32 = 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if redirects < test.redirects {
atomic.AddInt32(&redirects, 1)
http.Redirect(w, req, "redirect", http.StatusFound)
return
}
streamCh := make(chan httpstream.Stream)
responseUpgrader := NewResponseUpgrader()
spdyConn := responseUpgrader.UpgradeResponse(w, req, func(s httpstream.Stream, replySent <-chan struct{}) error {
streamCh <- s
return nil
})
if spdyConn == nil {
t.Fatalf("unexpected nil spdyConn")
}
defer spdyConn.Close()
stream := <-streamCh
io.Copy(stream, stream)
}))
defer server.Close()
req, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
t.Fatalf("Error creating request: %s", err)
}
spdyTransport := NewSpdyRoundTripper(nil, true, true)
client := &http.Client{Transport: spdyTransport}
resp, err := client.Do(req)
if test.expectSuccess {
if err != nil {
t.Fatalf("error calling Do: %v", err)
}
} else {
if err == nil {
t.Fatalf("expecting an error")
} else if !strings.Contains(err.Error(), "too many redirects") {
t.Fatalf("expecting too many redirects, got %v", err)
}
return
}
conn, err := spdyTransport.NewConnection(resp)
if err != nil {
t.Fatalf("error calling NewConnection: %v", err)
}
defer conn.Close()
if resp.StatusCode != http.StatusSwitchingProtocols {
t.Fatalf("expected http 101 switching protocols, got %d", resp.StatusCode)
}
stream, err := conn.CreateStream(http.Header{})
if err != nil {
t.Fatalf("error creating client stream: %s", err)
}
n, err := stream.Write([]byte("hello"))
if err != nil {
t.Fatalf("error writing to stream: %s", err)
}
if n != 5 {
t.Fatalf("Expected to write 5 bytes, but actually wrote %d", n)
}
b := make([]byte, 5)
n, err = stream.Read(b)
if err != nil {
t.Fatalf("error reading from stream: %s", err)
}
if n != 5 {
t.Fatalf("Expected to read 5 bytes, but actually read %d", n)
}
if e, a := "hello", string(b[0:n]); e != a {
t.Fatalf("expected '%s', got '%s'", e, a)
}
})
}
}
// exampleCert was generated from crypto/tls/generate_cert.go with the following command:
// go run generate_cert.go --rsa-bits 512 --host example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
var exampleCert = []byte(`-----BEGIN CERTIFICATE-----
MIIBdzCCASGgAwIBAgIRAOVTAdPnfbS5V85mfS90TfIwDQYJKoZIhvcNAQELBQAw
EjEQMA4GA1UEChMHQWNtZSBDbzAgFw03MDAxMDEwMDAwMDBaGA8yMDg0MDEyOTE2
MDAwMFowEjEQMA4GA1UEChMHQWNtZSBDbzBcMA0GCSqGSIb3DQEBAQUAA0sAMEgC
QQCoVSqeu8TBvF+70T7Jm4340YQNhds6IxjRoifenYodAO1dnKGrcbF266DJGunh
nIjQH7B12tduhl0fLK4Ezf7/AgMBAAGjUDBOMA4GA1UdDwEB/wQEAwICpDATBgNV
HSUEDDAKBggrBgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MBYGA1UdEQQPMA2CC2V4
YW1wbGUuY29tMA0GCSqGSIb3DQEBCwUAA0EAk1kVa5uZ/AzwYDVcS9bpM/czwjjV
xq3VeSCfmNa2uNjbFvodmCRwZOHUvipAMGCUCV6j5vMrJ8eMj8tCQ36W9A==
-----END CERTIFICATE-----`)
var exampleKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIBOgIBAAJBAKhVKp67xMG8X7vRPsmbjfjRhA2F2zojGNGiJ96dih0A7V2coatx
sXbroMka6eGciNAfsHXa126GXR8srgTN/v8CAwEAAQJASdzdD7vKsUwMIejGCUb1
fAnLTPfAY3lFCa+CmR89nE22dAoRDv+5RbnBsZ58BazPNJHrsVPRlfXB3OQmSQr0
SQIhANoJhs+xOJE/i8nJv0uAbzKyiD1YkvRkta0GpUOULyAVAiEAxaQus3E/SuqD
P7y5NeJnE7X6XkyC35zrsJRkz7orE8MCIHdDjsI8pjyNDeGqwUCDWE/a6DrmIDwe
emHSqMN2YvChAiEAnxLCM9NWaenOsaIoP+J1rDuvw+4499nJKVqGuVrSCRkCIEqK
4KSchPMc3x8M/uhw9oWTtKFmjA/PPh0FsWCdKrEy
-----END RSA PRIVATE KEY-----`)
// localhostCert was generated from crypto/tls/generate_cert.go with the following command:
// go run generate_cert.go --rsa-bits 512 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
var localhostCert = []byte(`-----BEGIN CERTIFICATE-----
MIIBjzCCATmgAwIBAgIRAKpi2WmTcFrVjxrl5n5YDUEwDQYJKoZIhvcNAQELBQAw
EjEQMA4GA1UEChMHQWNtZSBDbzAgFw03MDAxMDEwMDAwMDBaGA8yMDg0MDEyOTE2
MDAwMFowEjEQMA4GA1UEChMHQWNtZSBDbzBcMA0GCSqGSIb3DQEBAQUAA0sAMEgC
QQC9fEbRszP3t14Gr4oahV7zFObBI4TfA5i7YnlMXeLinb7MnvT4bkfOJzE6zktn
59zP7UiHs3l4YOuqrjiwM413AgMBAAGjaDBmMA4GA1UdDwEB/wQEAwICpDATBgNV
HSUEDDAKBggrBgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MC4GA1UdEQQnMCWCC2V4
YW1wbGUuY29thwR/AAABhxAAAAAAAAAAAAAAAAAAAAABMA0GCSqGSIb3DQEBCwUA
A0EAUsVE6KMnza/ZbodLlyeMzdo7EM/5nb5ywyOxgIOCf0OOLHsPS9ueGLQX9HEG
//yjTXuhNcUugExIjM/AIwAZPQ==
-----END CERTIFICATE-----`)
// localhostKey is the private key for localhostCert.
var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIBOwIBAAJBAL18RtGzM/e3XgavihqFXvMU5sEjhN8DmLtieUxd4uKdvsye9Phu
R84nMTrOS2fn3M/tSIezeXhg66quOLAzjXcCAwEAAQJBAKcRxH9wuglYLBdI/0OT
BLzfWPZCEw1vZmMR2FF1Fm8nkNOVDPleeVGTWoOEcYYlQbpTmkGSxJ6ya+hqRi6x
goECIQDx3+X49fwpL6B5qpJIJMyZBSCuMhH4B7JevhGGFENi3wIhAMiNJN5Q3UkL
IuSvv03kaPR5XVQ99/UeEetUgGvBcABpAiBJSBzVITIVCGkGc7d+RCf49KTCIklv
bGWObufAR8Ni4QIgWpILjW8dkGg8GOUZ0zaNA6Nvt6TIv2UWGJ4v5PoV98kCIQDx
rIiZs5QbKdycsv9gQJzwQAogC8o04X3Zz3dsoX+h4A==
-----END RSA PRIVATE KEY-----`)

View File

@ -1,107 +0,0 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package spdy
import (
"bufio"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync/atomic"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/apimachinery/pkg/util/runtime"
)
const HeaderSpdy31 = "SPDY/3.1"
// responseUpgrader knows how to upgrade HTTP responses. It
// implements the httpstream.ResponseUpgrader interface.
type responseUpgrader struct {
}
// connWrapper is used to wrap a hijacked connection and its bufio.Reader. All
// calls will be handled directly by the underlying net.Conn with the exception
// of Read and Close calls, which will consider data in the bufio.Reader. This
// ensures that data already inside the used bufio.Reader instance is also
// read.
type connWrapper struct {
net.Conn
closed int32
bufReader *bufio.Reader
}
func (w *connWrapper) Read(b []byte) (n int, err error) {
if atomic.LoadInt32(&w.closed) == 1 {
return 0, io.EOF
}
return w.bufReader.Read(b)
}
func (w *connWrapper) Close() error {
err := w.Conn.Close()
atomic.StoreInt32(&w.closed, 1)
return err
}
// NewResponseUpgrader returns a new httpstream.ResponseUpgrader that is
// capable of upgrading HTTP responses using SPDY/3.1 via the
// spdystream package.
func NewResponseUpgrader() httpstream.ResponseUpgrader {
return responseUpgrader{}
}
// UpgradeResponse upgrades an HTTP response to one that supports multiplexed
// streams. newStreamHandler will be called synchronously whenever the
// other end of the upgraded connection creates a new stream.
func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler httpstream.NewStreamHandler) httpstream.Connection {
connectionHeader := strings.ToLower(req.Header.Get(httpstream.HeaderConnection))
upgradeHeader := strings.ToLower(req.Header.Get(httpstream.HeaderUpgrade))
if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(HeaderSpdy31)) {
errorMsg := fmt.Sprintf("unable to upgrade: missing upgrade headers in request: %#v", req.Header)
http.Error(w, errorMsg, http.StatusBadRequest)
return nil
}
hijacker, ok := w.(http.Hijacker)
if !ok {
errorMsg := fmt.Sprintf("unable to upgrade: unable to hijack response")
http.Error(w, errorMsg, http.StatusInternalServerError)
return nil
}
w.Header().Add(httpstream.HeaderConnection, httpstream.HeaderUpgrade)
w.Header().Add(httpstream.HeaderUpgrade, HeaderSpdy31)
w.WriteHeader(http.StatusSwitchingProtocols)
conn, bufrw, err := hijacker.Hijack()
if err != nil {
runtime.HandleError(fmt.Errorf("unable to upgrade: error hijacking response: %v", err))
return nil
}
connWithBuf := &connWrapper{Conn: conn, bufReader: bufrw.Reader}
spdyConn, err := NewServerConnection(connWithBuf, newStreamHandler)
if err != nil {
runtime.HandleError(fmt.Errorf("unable to upgrade: error creating SPDY server connection: %v", err))
return nil
}
return spdyConn
}

View File

@ -1,93 +0,0 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package spdy
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestUpgradeResponse(t *testing.T) {
testCases := []struct {
connectionHeader string
upgradeHeader string
shouldError bool
}{
{
connectionHeader: "",
upgradeHeader: "",
shouldError: true,
},
{
connectionHeader: "Upgrade",
upgradeHeader: "",
shouldError: true,
},
{
connectionHeader: "",
upgradeHeader: "SPDY/3.1",
shouldError: true,
},
{
connectionHeader: "Upgrade",
upgradeHeader: "SPDY/3.1",
shouldError: false,
},
}
for i, testCase := range testCases {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
upgrader := NewResponseUpgrader()
conn := upgrader.UpgradeResponse(w, req, nil)
haveErr := conn == nil
if e, a := testCase.shouldError, haveErr; e != a {
t.Fatalf("%d: expected shouldErr=%t, got %t", i, testCase.shouldError, haveErr)
}
if haveErr {
return
}
if conn == nil {
t.Fatalf("%d: unexpected nil conn", i)
}
defer conn.Close()
}))
defer server.Close()
req, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
t.Fatalf("%d: error creating request: %s", i, err)
}
req.Header.Set("Connection", testCase.connectionHeader)
req.Header.Set("Upgrade", testCase.upgradeHeader)
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("%d: unexpected non-nil err from client.Do: %s", i, err)
}
if testCase.shouldError {
continue
}
if resp.StatusCode != http.StatusSwitchingProtocols {
t.Fatalf("%d: expected status 101 switching protocols, got %d", i, resp.StatusCode)
}
}
}

View File

@ -1,48 +0,0 @@
/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package initialization
import (
"k8s.io/apimachinery/pkg/api/meta"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
)
// IsInitialized returns if the initializers indicates means initialized.
func IsInitialized(initializers *metav1.Initializers) bool {
if initializers == nil {
return true
}
// Persisted objects will never be in this state. The initializer admission
// plugin will override metadata.initializers to nil. If the initializer
// admissio plugin is disabled, the generic registry always set
// metadata.initializers to nil. However, this function
// might be called before the object persisted, thus the check.
if len(initializers.Pending) == 0 && initializers.Result == nil {
return true
}
return false
}
// IsObjectInitialized returns if the object is initialized.
func IsObjectInitialized(obj runtime.Object) (bool, error) {
accessor, err := meta.Accessor(obj)
if err != nil {
return false, err
}
return IsInitialized(accessor.GetInitializers()), nil
}

View File

@ -1,43 +0,0 @@
/*
Copyright The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// This file was autogenerated by go-to-protobuf. Do not edit it manually!
syntax = 'proto2';
package k8s.io.apimachinery.pkg.util.intstr;
// Package-wide variables from generator "generated".
option go_package = "intstr";
// IntOrString is a type that can hold an int32 or a string. When used in
// JSON or YAML marshalling and unmarshalling, it produces or consumes the
// inner type. This allows you to have, for example, a JSON field that can
// accept a name or number.
// TODO: Rename to Int32OrString
//
// +protobuf=true
// +protobuf.options.(gogoproto.goproto_stringer)=false
// +k8s:openapi-gen=true
message IntOrString {
optional int64 type = 1;
optional int32 intVal = 2;
optional string strVal = 3;
}

View File

@ -1,183 +0,0 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package intstr
import (
"encoding/json"
"reflect"
"testing"
"sigs.k8s.io/yaml"
)
func TestFromInt(t *testing.T) {
i := FromInt(93)
if i.Type != Int || i.IntVal != 93 {
t.Errorf("Expected IntVal=93, got %+v", i)
}
}
func TestFromString(t *testing.T) {
i := FromString("76")
if i.Type != String || i.StrVal != "76" {
t.Errorf("Expected StrVal=\"76\", got %+v", i)
}
}
type IntOrStringHolder struct {
IOrS IntOrString `json:"val"`
}
func TestIntOrStringUnmarshalJSON(t *testing.T) {
cases := []struct {
input string
result IntOrString
}{
{"{\"val\": 123}", FromInt(123)},
{"{\"val\": \"123\"}", FromString("123")},
}
for _, c := range cases {
var result IntOrStringHolder
if err := json.Unmarshal([]byte(c.input), &result); err != nil {
t.Errorf("Failed to unmarshal input '%v': %v", c.input, err)
}
if result.IOrS != c.result {
t.Errorf("Failed to unmarshal input '%v': expected %+v, got %+v", c.input, c.result, result)
}
}
}
func TestIntOrStringMarshalJSON(t *testing.T) {
cases := []struct {
input IntOrString
result string
}{
{FromInt(123), "{\"val\":123}"},
{FromString("123"), "{\"val\":\"123\"}"},
}
for _, c := range cases {
input := IntOrStringHolder{c.input}
result, err := json.Marshal(&input)
if err != nil {
t.Errorf("Failed to marshal input '%v': %v", input, err)
}
if string(result) != c.result {
t.Errorf("Failed to marshal input '%v': expected: %+v, got %q", input, c.result, string(result))
}
}
}
func TestIntOrStringMarshalJSONUnmarshalYAML(t *testing.T) {
cases := []struct {
input IntOrString
}{
{FromInt(123)},
{FromString("123")},
}
for _, c := range cases {
input := IntOrStringHolder{c.input}
jsonMarshalled, err := json.Marshal(&input)
if err != nil {
t.Errorf("1: Failed to marshal input: '%v': %v", input, err)
}
var result IntOrStringHolder
err = yaml.Unmarshal(jsonMarshalled, &result)
if err != nil {
t.Errorf("2: Failed to unmarshal '%+v': %v", string(jsonMarshalled), err)
}
if !reflect.DeepEqual(input, result) {
t.Errorf("3: Failed to marshal input '%+v': got %+v", input, result)
}
}
}
func TestGetValueFromIntOrPercent(t *testing.T) {
tests := []struct {
input IntOrString
total int
roundUp bool
expectErr bool
expectVal int
}{
{
input: FromInt(123),
expectErr: false,
expectVal: 123,
},
{
input: FromString("90%"),
total: 100,
roundUp: true,
expectErr: false,
expectVal: 90,
},
{
input: FromString("90%"),
total: 95,
roundUp: true,
expectErr: false,
expectVal: 86,
},
{
input: FromString("90%"),
total: 95,
roundUp: false,
expectErr: false,
expectVal: 85,
},
{
input: FromString("%"),
expectErr: true,
},
{
input: FromString("90#"),
expectErr: true,
},
{
input: FromString("#%"),
expectErr: true,
},
}
for i, test := range tests {
t.Logf("test case %d", i)
value, err := GetValueFromIntOrPercent(&test.input, test.total, test.roundUp)
if test.expectErr && err == nil {
t.Errorf("expected error, but got none")
continue
}
if !test.expectErr && err != nil {
t.Errorf("unexpected err: %v", err)
continue
}
if test.expectVal != value {
t.Errorf("expected %v, but got %v", test.expectVal, value)
}
}
}
func TestGetValueFromIntOrPercentNil(t *testing.T) {
_, err := GetValueFromIntOrPercent(nil, 0, false)
if err == nil {
t.Errorf("expected error got none")
}
}

View File

@ -1,319 +0,0 @@
// +build go1.8
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package json
import (
"fmt"
"math"
"reflect"
"strconv"
"strings"
"testing"
)
func TestEvaluateTypes(t *testing.T) {
testCases := []struct {
In string
Data interface{}
Out string
Err bool
}{
// Invalid syntaxes
{
In: `x`,
Err: true,
},
{
In: ``,
Err: true,
},
// Null
{
In: `null`,
Data: nil,
Out: `null`,
},
// Booleans
{
In: `true`,
Data: true,
Out: `true`,
},
{
In: `false`,
Data: false,
Out: `false`,
},
// Integers
{
In: `0`,
Data: int64(0),
Out: `0`,
},
{
In: `-0`,
Data: int64(-0),
Out: `0`,
},
{
In: `1`,
Data: int64(1),
Out: `1`,
},
{
In: `2147483647`,
Data: int64(math.MaxInt32),
Out: `2147483647`,
},
{
In: `-2147483648`,
Data: int64(math.MinInt32),
Out: `-2147483648`,
},
{
In: `9223372036854775807`,
Data: int64(math.MaxInt64),
Out: `9223372036854775807`,
},
{
In: `-9223372036854775808`,
Data: int64(math.MinInt64),
Out: `-9223372036854775808`,
},
// Int overflow
{
In: `9223372036854775808`, // MaxInt64 + 1
Data: float64(9223372036854775808),
Out: `9223372036854776000`,
},
{
In: `-9223372036854775809`, // MinInt64 - 1
Data: float64(math.MinInt64),
Out: `-9223372036854776000`,
},
// Floats
{
In: `0.0`,
Data: float64(0),
Out: `0`,
},
{
In: `-0.0`,
Data: float64(-0.0),
Out: `-0`,
},
{
In: `0.5`,
Data: float64(0.5),
Out: `0.5`,
},
{
In: `1e3`,
Data: float64(1e3),
Out: `1000`,
},
{
In: `1.5`,
Data: float64(1.5),
Out: `1.5`,
},
{
In: `-0.3`,
Data: float64(-.3),
Out: `-0.3`,
},
{
// Largest representable float32
In: `3.40282346638528859811704183484516925440e+38`,
Data: float64(math.MaxFloat32),
Out: strconv.FormatFloat(math.MaxFloat32, 'g', -1, 64),
},
{
// Smallest float32 without losing precision
In: `1.175494351e-38`,
Data: float64(1.175494351e-38),
Out: `1.175494351e-38`,
},
{
// float32 closest to zero
In: `1.401298464324817070923729583289916131280e-45`,
Data: float64(math.SmallestNonzeroFloat32),
Out: strconv.FormatFloat(math.SmallestNonzeroFloat32, 'g', -1, 64),
},
{
// Largest representable float64
In: `1.797693134862315708145274237317043567981e+308`,
Data: float64(math.MaxFloat64),
Out: strconv.FormatFloat(math.MaxFloat64, 'g', -1, 64),
},
{
// Closest to zero without losing precision
In: `2.2250738585072014e-308`,
Data: float64(2.2250738585072014e-308),
Out: `2.2250738585072014e-308`,
},
{
// float64 closest to zero
In: `4.940656458412465441765687928682213723651e-324`,
Data: float64(math.SmallestNonzeroFloat64),
Out: strconv.FormatFloat(math.SmallestNonzeroFloat64, 'g', -1, 64),
},
{
// math.MaxFloat64 + 2 overflow
In: `1.7976931348623159e+308`,
Err: true,
},
// Strings
{
In: `""`,
Data: string(""),
Out: `""`,
},
{
In: `"0"`,
Data: string("0"),
Out: `"0"`,
},
{
In: `"A"`,
Data: string("A"),
Out: `"A"`,
},
{
In: `"Iñtërnâtiônàlizætiøn"`,
Data: string("Iñtërnâtiônàlizætiøn"),
Out: `"Iñtërnâtiônàlizætiøn"`,
},
// Arrays
{
In: `[]`,
Data: []interface{}{},
Out: `[]`,
},
{
In: `[` + strings.Join([]string{
`null`,
`true`,
`false`,
`0`,
`9223372036854775807`,
`0.0`,
`0.5`,
`1.0`,
`1.797693134862315708145274237317043567981e+308`,
`"0"`,
`"A"`,
`"Iñtërnâtiônàlizætiøn"`,
`[null,true,1,1.0,1.5]`,
`{"boolkey":true,"floatkey":1.0,"intkey":1,"nullkey":null}`,
}, ",") + `]`,
Data: []interface{}{
nil,
true,
false,
int64(0),
int64(math.MaxInt64),
float64(0.0),
float64(0.5),
float64(1.0),
float64(math.MaxFloat64),
string("0"),
string("A"),
string("Iñtërnâtiônàlizætiøn"),
[]interface{}{nil, true, int64(1), float64(1.0), float64(1.5)},
map[string]interface{}{"nullkey": nil, "boolkey": true, "intkey": int64(1), "floatkey": float64(1.0)},
},
Out: `[` + strings.Join([]string{
`null`,
`true`,
`false`,
`0`,
`9223372036854775807`,
`0`,
`0.5`,
`1`,
strconv.FormatFloat(math.MaxFloat64, 'g', -1, 64),
`"0"`,
`"A"`,
`"Iñtërnâtiônàlizætiøn"`,
`[null,true,1,1,1.5]`,
`{"boolkey":true,"floatkey":1,"intkey":1,"nullkey":null}`, // gets alphabetized by Marshal
}, ",") + `]`,
},
// Maps
{
In: `{}`,
Data: map[string]interface{}{},
Out: `{}`,
},
{
In: `{"boolkey":true,"floatkey":1.0,"intkey":1,"nullkey":null}`,
Data: map[string]interface{}{"nullkey": nil, "boolkey": true, "intkey": int64(1), "floatkey": float64(1.0)},
Out: `{"boolkey":true,"floatkey":1,"intkey":1,"nullkey":null}`, // gets alphabetized by Marshal
},
}
for _, tc := range testCases {
inputJSON := fmt.Sprintf(`{"data":%s}`, tc.In)
expectedJSON := fmt.Sprintf(`{"data":%s}`, tc.Out)
m := map[string]interface{}{}
err := Unmarshal([]byte(inputJSON), &m)
if tc.Err && err != nil {
// Expected error
continue
}
if err != nil {
t.Errorf("%s: error decoding: %v", tc.In, err)
continue
}
if tc.Err {
t.Errorf("%s: expected error, got none", tc.In)
continue
}
data, ok := m["data"]
if !ok {
t.Errorf("%s: decoded object missing data key: %#v", tc.In, m)
continue
}
if !reflect.DeepEqual(tc.Data, data) {
t.Errorf("%s: expected\n\t%#v (%v), got\n\t%#v (%v)", tc.In, tc.Data, reflect.TypeOf(tc.Data), data, reflect.TypeOf(data))
continue
}
outputJSON, err := Marshal(m)
if err != nil {
t.Errorf("%s: error encoding: %v", tc.In, err)
continue
}
if expectedJSON != string(outputJSON) {
t.Errorf("%s: expected\n\t%s, got\n\t%s", tc.In, expectedJSON, string(outputJSON))
continue
}
}
}

View File

@ -1,160 +0,0 @@
/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package jsonmergepatch
import (
"fmt"
"reflect"
"github.com/evanphx/json-patch"
"k8s.io/apimachinery/pkg/util/json"
"k8s.io/apimachinery/pkg/util/mergepatch"
)
// Create a 3-way merge patch based-on JSON merge patch.
// Calculate addition-and-change patch between current and modified.
// Calculate deletion patch between original and modified.
func CreateThreeWayJSONMergePatch(original, modified, current []byte, fns ...mergepatch.PreconditionFunc) ([]byte, error) {
if len(original) == 0 {
original = []byte(`{}`)
}
if len(modified) == 0 {
modified = []byte(`{}`)
}
if len(current) == 0 {
current = []byte(`{}`)
}
addAndChangePatch, err := jsonpatch.CreateMergePatch(current, modified)
if err != nil {
return nil, err
}
// Only keep addition and changes
addAndChangePatch, addAndChangePatchObj, err := keepOrDeleteNullInJsonPatch(addAndChangePatch, false)
if err != nil {
return nil, err
}
deletePatch, err := jsonpatch.CreateMergePatch(original, modified)
if err != nil {
return nil, err
}
// Only keep deletion
deletePatch, deletePatchObj, err := keepOrDeleteNullInJsonPatch(deletePatch, true)
if err != nil {
return nil, err
}
hasConflicts, err := mergepatch.HasConflicts(addAndChangePatchObj, deletePatchObj)
if err != nil {
return nil, err
}
if hasConflicts {
return nil, mergepatch.NewErrConflict(mergepatch.ToYAMLOrError(addAndChangePatchObj), mergepatch.ToYAMLOrError(deletePatchObj))
}
patch, err := jsonpatch.MergePatch(deletePatch, addAndChangePatch)
if err != nil {
return nil, err
}
var patchMap map[string]interface{}
err = json.Unmarshal(patch, &patchMap)
if err != nil {
return nil, fmt.Errorf("Failed to unmarshal patch for precondition check: %s", patch)
}
meetPreconditions, err := meetPreconditions(patchMap, fns...)
if err != nil {
return nil, err
}
if !meetPreconditions {
return nil, mergepatch.NewErrPreconditionFailed(patchMap)
}
return patch, nil
}
// keepOrDeleteNullInJsonPatch takes a json-encoded byte array and a boolean.
// It returns a filtered object and its corresponding json-encoded byte array.
// It is a wrapper of func keepOrDeleteNullInObj
func keepOrDeleteNullInJsonPatch(patch []byte, keepNull bool) ([]byte, map[string]interface{}, error) {
var patchMap map[string]interface{}
err := json.Unmarshal(patch, &patchMap)
if err != nil {
return nil, nil, err
}
filteredMap, err := keepOrDeleteNullInObj(patchMap, keepNull)
if err != nil {
return nil, nil, err
}
o, err := json.Marshal(filteredMap)
return o, filteredMap, err
}
// keepOrDeleteNullInObj will keep only the null value and delete all the others,
// if keepNull is true. Otherwise, it will delete all the null value and keep the others.
func keepOrDeleteNullInObj(m map[string]interface{}, keepNull bool) (map[string]interface{}, error) {
filteredMap := make(map[string]interface{})
var err error
for key, val := range m {
switch {
case keepNull && val == nil:
filteredMap[key] = nil
case val != nil:
switch typedVal := val.(type) {
case map[string]interface{}:
// Explicitly-set empty maps are treated as values instead of empty patches
if len(typedVal) == 0 {
if !keepNull {
filteredMap[key] = typedVal
}
continue
}
var filteredSubMap map[string]interface{}
filteredSubMap, err = keepOrDeleteNullInObj(typedVal, keepNull)
if err != nil {
return nil, err
}
// If the returned filtered submap was empty, this is an empty patch for the entire subdict, so the key
// should not be set
if len(filteredSubMap) != 0 {
filteredMap[key] = filteredSubMap
}
case []interface{}, string, float64, bool, int64, nil:
// Lists are always replaced in Json, no need to check each entry in the list.
if !keepNull {
filteredMap[key] = val
}
default:
return nil, fmt.Errorf("unknown type: %v", reflect.TypeOf(typedVal))
}
}
}
return filteredMap, nil
}
func meetPreconditions(patchObj map[string]interface{}, fns ...mergepatch.PreconditionFunc) (bool, error) {
// Apply the preconditions to the patch, and return an error if any of them fail.
for _, fn := range fns {
if !fn(patchObj) {
return false, fmt.Errorf("precondition failed for: %v", patchObj)
}
}
return true, nil
}

View File

@ -1,696 +0,0 @@
/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package jsonmergepatch
import (
"fmt"
"reflect"
"testing"
"github.com/davecgh/go-spew/spew"
"github.com/evanphx/json-patch"
"k8s.io/apimachinery/pkg/util/json"
"sigs.k8s.io/yaml"
)
type FilterNullTestCases struct {
TestCases []FilterNullTestCase
}
type FilterNullTestCase struct {
Description string
OriginalObj map[string]interface{}
ExpectedWithNull map[string]interface{}
ExpectedWithoutNull map[string]interface{}
}
var filterNullTestCaseData = []byte(`
testCases:
- description: nil original
originalObj: {}
expectedWithNull: {}
expectedWithoutNull: {}
- description: simple map
originalObj:
nilKey: null
nonNilKey: foo
expectedWithNull:
nilKey: null
expectedWithoutNull:
nonNilKey: foo
- description: simple map with all nil values
originalObj:
nilKey1: null
nilKey2: null
expectedWithNull:
nilKey1: null
nilKey2: null
expectedWithoutNull: {}
- description: simple map with all non-nil values
originalObj:
nonNilKey1: foo
nonNilKey2: bar
expectedWithNull: {}
expectedWithoutNull:
nonNilKey1: foo
nonNilKey2: bar
- description: nested map
originalObj:
mapKey:
nilKey: null
nonNilKey: foo
expectedWithNull:
mapKey:
nilKey: null
expectedWithoutNull:
mapKey:
nonNilKey: foo
- description: nested map that all subkeys are nil
originalObj:
mapKey:
nilKey1: null
nilKey2: null
expectedWithNull:
mapKey:
nilKey1: null
nilKey2: null
expectedWithoutNull: {}
- description: nested map that all subkeys are non-nil
originalObj:
mapKey:
nonNilKey1: foo
nonNilKey2: bar
expectedWithNull: {}
expectedWithoutNull:
mapKey:
nonNilKey1: foo
nonNilKey2: bar
- description: explicitly empty map as value
originalObj:
mapKey: {}
expectedWithNull: {}
expectedWithoutNull:
mapKey: {}
- description: explicitly empty nested map
originalObj:
mapKey:
nonNilKey: {}
expectedWithNull: {}
expectedWithoutNull:
mapKey:
nonNilKey: {}
- description: multiple expliclty empty nested maps
originalObj:
mapKey:
nonNilKey1: {}
nonNilKey2: {}
expectedWithNull: {}
expectedWithoutNull:
mapKey:
nonNilKey1: {}
nonNilKey2: {}
- description: nested map with non-null value as empty map
originalObj:
mapKey:
nonNilKey: {}
nilKey: null
expectedWithNull:
mapKey:
nilKey: null
expectedWithoutNull:
mapKey:
nonNilKey: {}
- description: empty list
originalObj:
listKey: []
expectedWithNull: {}
expectedWithoutNull:
listKey: []
- description: list of primitives
originalObj:
listKey:
- 1
- 2
expectedWithNull: {}
expectedWithoutNull:
listKey:
- 1
- 2
- description: list of maps
originalObj:
listKey:
- k1: v1
- k2: null
- k3: v3
k4: null
expectedWithNull: {}
expectedWithoutNull:
listKey:
- k1: v1
- k2: null
- k3: v3
k4: null
- description: list of different types
originalObj:
listKey:
- k1: v1
- k2: null
- v3
expectedWithNull: {}
expectedWithoutNull:
listKey:
- k1: v1
- k2: null
- v3
`)
func TestKeepOrDeleteNullInObj(t *testing.T) {
tc := FilterNullTestCases{}
err := yaml.Unmarshal(filterNullTestCaseData, &tc)
if err != nil {
t.Fatalf("can't unmarshal test cases: %s\n", err)
}
for _, test := range tc.TestCases {
resultWithNull, err := keepOrDeleteNullInObj(test.OriginalObj, true)
if err != nil {
t.Errorf("Failed in test case %q when trying to keep null values: %s", test.Description, err)
}
if !reflect.DeepEqual(test.ExpectedWithNull, resultWithNull) {
t.Errorf("Failed in test case %q when trying to keep null values:\nexpected expectedWithNull:\n%+v\nbut got:\n%+v\n", test.Description, test.ExpectedWithNull, resultWithNull)
}
resultWithoutNull, err := keepOrDeleteNullInObj(test.OriginalObj, false)
if err != nil {
t.Errorf("Failed in test case %q when trying to keep non-null values: %s", test.Description, err)
}
if !reflect.DeepEqual(test.ExpectedWithoutNull, resultWithoutNull) {
t.Errorf("Failed in test case %q when trying to keep non-null values:\n expected expectedWithoutNull:\n%+v\nbut got:\n%+v\n", test.Description, test.ExpectedWithoutNull, resultWithoutNull)
}
}
}
type JSONMergePatchTestCases struct {
TestCases []JSONMergePatchTestCase
}
type JSONMergePatchTestCase struct {
Description string
JSONMergePatchTestCaseData
}
type JSONMergePatchTestCaseData struct {
// Original is the original object (last-applied config in annotation)
Original map[string]interface{}
// Modified is the modified object (new config we want)
Modified map[string]interface{}
// Current is the current object (live config in the server)
Current map[string]interface{}
// ThreeWay is the expected three-way merge patch
ThreeWay map[string]interface{}
// Result is the expected object after applying the three-way patch on current object.
Result map[string]interface{}
}
var createJSONMergePatchTestCaseData = []byte(`
testCases:
- description: nil original
modified:
name: 1
value: 1
current:
name: 1
other: a
threeWay:
value: 1
result:
name: 1
value: 1
other: a
- description: nil patch
original:
name: 1
modified:
name: 1
current:
name: 1
threeWay:
{}
result:
name: 1
- description: add field to map
original:
name: 1
modified:
name: 1
value: 1
current:
name: 1
other: a
threeWay:
value: 1
result:
name: 1
value: 1
other: a
- description: add field to map with conflict
original:
name: 1
modified:
name: 1
value: 1
current:
name: a
other: a
threeWay:
name: 1
value: 1
result:
name: 1
value: 1
other: a
- description: add field and delete field from map
original:
name: 1
modified:
value: 1
current:
name: 1
other: a
threeWay:
name: null
value: 1
result:
value: 1
other: a
- description: add field and delete field from map with conflict
original:
name: 1
modified:
value: 1
current:
name: a
other: a
threeWay:
name: null
value: 1
result:
value: 1
other: a
- description: delete field from nested map
original:
simpleMap:
key1: 1
key2: 1
modified:
simpleMap:
key1: 1
current:
simpleMap:
key1: 1
key2: 1
other: a
threeWay:
simpleMap:
key2: null
result:
simpleMap:
key1: 1
other: a
- description: delete field from nested map with conflict
original:
simpleMap:
key1: 1
key2: 1
modified:
simpleMap:
key1: 1
current:
simpleMap:
key1: a
key2: 1
other: a
threeWay:
simpleMap:
key1: 1
key2: null
result:
simpleMap:
key1: 1
other: a
- description: delete all fields from map
original:
name: 1
value: 1
modified: {}
current:
name: 1
value: 1
other: a
threeWay:
name: null
value: null
result:
other: a
- description: delete all fields from map with conflict
original:
name: 1
value: 1
modified: {}
current:
name: 1
value: a
other: a
threeWay:
name: null
value: null
result:
other: a
- description: add field and delete all fields from map
original:
name: 1
value: 1
modified:
other: a
current:
name: 1
value: 1
other: a
threeWay:
name: null
value: null
result:
other: a
- description: add field and delete all fields from map with conflict
original:
name: 1
value: 1
modified:
other: a
current:
name: 1
value: 1
other: b
threeWay:
name: null
value: null
other: a
result:
other: a
- description: replace list of scalars
original:
intList:
- 1
- 2
modified:
intList:
- 2
- 3
current:
intList:
- 1
- 2
threeWay:
intList:
- 2
- 3
result:
intList:
- 2
- 3
- description: replace list of scalars with conflict
original:
intList:
- 1
- 2
modified:
intList:
- 2
- 3
current:
intList:
- 1
- 4
threeWay:
intList:
- 2
- 3
result:
intList:
- 2
- 3
- description: patch with different scalar type
original:
foo: 1
modified:
foo: true
current:
foo: 1
bar: 2
threeWay:
foo: true
result:
foo: true
bar: 2
- description: patch from scalar to list
original:
foo: 0
modified:
foo:
- 1
- 2
current:
foo: 0
bar: 2
threeWay:
foo:
- 1
- 2
result:
foo:
- 1
- 2
bar: 2
- description: patch from list to scalar
original:
foo:
- 1
- 2
modified:
foo: 0
current:
foo:
- 1
- 2
bar: 2
threeWay:
foo: 0
result:
foo: 0
bar: 2
- description: patch from scalar to map
original:
foo: 0
modified:
foo:
baz: 1
current:
foo: 0
bar: 2
threeWay:
foo:
baz: 1
result:
foo:
baz: 1
bar: 2
- description: patch from map to scalar
original:
foo:
baz: 1
modified:
foo: 0
current:
foo:
baz: 1
bar: 2
threeWay:
foo: 0
result:
foo: 0
bar: 2
- description: patch from map to list
original:
foo:
baz: 1
modified:
foo:
- 1
- 2
current:
foo:
baz: 1
bar: 2
threeWay:
foo:
- 1
- 2
result:
foo:
- 1
- 2
bar: 2
- description: patch from list to map
original:
foo:
- 1
- 2
modified:
foo:
baz: 0
current:
foo:
- 1
- 2
bar: 2
threeWay:
foo:
baz: 0
result:
foo:
baz: 0
bar: 2
- description: patch with different nested types
original:
foo:
- a: true
- 2
- false
modified:
foo:
- 1
- false
- b: 1
current:
foo:
- a: true
- 2
- false
bar: 0
threeWay:
foo:
- 1
- false
- b: 1
result:
foo:
- 1
- false
- b: 1
bar: 0
`)
func TestCreateThreeWayJSONMergePatch(t *testing.T) {
tc := JSONMergePatchTestCases{}
err := yaml.Unmarshal(createJSONMergePatchTestCaseData, &tc)
if err != nil {
t.Errorf("can't unmarshal test cases: %s\n", err)
return
}
for _, c := range tc.TestCases {
testThreeWayPatch(t, c)
}
}
func testThreeWayPatch(t *testing.T, c JSONMergePatchTestCase) {
original, modified, current, expected, result := threeWayTestCaseToJSONOrFail(t, c)
actual, err := CreateThreeWayJSONMergePatch(original, modified, current)
if err != nil {
t.Fatalf("error: %s", err)
}
testPatchCreation(t, expected, actual, c.Description)
testPatchApplication(t, current, actual, result, c.Description)
}
func testPatchCreation(t *testing.T, expected, actual []byte, description string) {
if !reflect.DeepEqual(actual, expected) {
t.Errorf("error in test case: %s\nexpected patch:\n%s\ngot:\n%s\n",
description, jsonToYAMLOrError(expected), jsonToYAMLOrError(actual))
return
}
}
func testPatchApplication(t *testing.T, original, patch, expected []byte, description string) {
result, err := jsonpatch.MergePatch(original, patch)
if err != nil {
t.Errorf("error: %s\nin test case: %s\ncannot apply patch:\n%s\nto original:\n%s\n",
err, description, jsonToYAMLOrError(patch), jsonToYAMLOrError(original))
return
}
if !reflect.DeepEqual(result, expected) {
format := "error in test case: %s\npatch application failed:\noriginal:\n%s\npatch:\n%s\nexpected:\n%s\ngot:\n%s\n"
t.Errorf(format, description,
jsonToYAMLOrError(original), jsonToYAMLOrError(patch),
jsonToYAMLOrError(expected), jsonToYAMLOrError(result))
return
}
}
func threeWayTestCaseToJSONOrFail(t *testing.T, c JSONMergePatchTestCase) ([]byte, []byte, []byte, []byte, []byte) {
return testObjectToJSONOrFail(t, c.Original),
testObjectToJSONOrFail(t, c.Modified),
testObjectToJSONOrFail(t, c.Current),
testObjectToJSONOrFail(t, c.ThreeWay),
testObjectToJSONOrFail(t, c.Result)
}
func testObjectToJSONOrFail(t *testing.T, o map[string]interface{}) []byte {
if o == nil {
return nil
}
j, err := toJSON(o)
if err != nil {
t.Error(err)
}
return j
}
func jsonToYAMLOrError(j []byte) string {
y, err := jsonToYAML(j)
if err != nil {
return err.Error()
}
return string(y)
}
func toJSON(v interface{}) ([]byte, error) {
j, err := json.Marshal(v)
if err != nil {
return nil, fmt.Errorf("json marshal failed: %v\n%v\n", err, spew.Sdump(v))
}
return j, nil
}
func jsonToYAML(j []byte) ([]byte, error) {
y, err := yaml.JSONToYAML(j)
if err != nil {
return nil, fmt.Errorf("json to yaml failed: %v\n%v\n", err, j)
}
return y, nil
}

View File

@ -1,5 +0,0 @@
approvers:
- pwittrock
reviewers:
- mengqiy
- apelisse

View File

@ -1,102 +0,0 @@
/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package mergepatch
import (
"errors"
"fmt"
"reflect"
)
var (
ErrBadJSONDoc = errors.New("invalid JSON document")
ErrNoListOfLists = errors.New("lists of lists are not supported")
ErrBadPatchFormatForPrimitiveList = errors.New("invalid patch format of primitive list")
ErrBadPatchFormatForRetainKeys = errors.New("invalid patch format of retainKeys")
ErrBadPatchFormatForSetElementOrderList = errors.New("invalid patch format of setElementOrder list")
ErrPatchContentNotMatchRetainKeys = errors.New("patch content doesn't match retainKeys list")
ErrUnsupportedStrategicMergePatchFormat = errors.New("strategic merge patch format is not supported")
)
func ErrNoMergeKey(m map[string]interface{}, k string) error {
return fmt.Errorf("map: %v does not contain declared merge key: %s", m, k)
}
func ErrBadArgType(expected, actual interface{}) error {
return fmt.Errorf("expected a %s, but received a %s",
reflect.TypeOf(expected),
reflect.TypeOf(actual))
}
func ErrBadArgKind(expected, actual interface{}) error {
var expectedKindString, actualKindString string
if expected == nil {
expectedKindString = "nil"
} else {
expectedKindString = reflect.TypeOf(expected).Kind().String()
}
if actual == nil {
actualKindString = "nil"
} else {
actualKindString = reflect.TypeOf(actual).Kind().String()
}
return fmt.Errorf("expected a %s, but received a %s", expectedKindString, actualKindString)
}
func ErrBadPatchType(t interface{}, m map[string]interface{}) error {
return fmt.Errorf("unknown patch type: %s in map: %v", t, m)
}
// IsPreconditionFailed returns true if the provided error indicates
// a precondition failed.
func IsPreconditionFailed(err error) bool {
_, ok := err.(ErrPreconditionFailed)
return ok
}
type ErrPreconditionFailed struct {
message string
}
func NewErrPreconditionFailed(target map[string]interface{}) ErrPreconditionFailed {
s := fmt.Sprintf("precondition failed for: %v", target)
return ErrPreconditionFailed{s}
}
func (err ErrPreconditionFailed) Error() string {
return err.message
}
type ErrConflict struct {
message string
}
func NewErrConflict(patch, current string) ErrConflict {
s := fmt.Sprintf("patch:\n%s\nconflicts with changes made from original to current:\n%s\n", patch, current)
return ErrConflict{s}
}
func (err ErrConflict) Error() string {
return err.message
}
// IsConflict returns true if the provided error indicates
// a conflict between the patch and the current configuration.
func IsConflict(err error) bool {
_, ok := err.(ErrConflict)
return ok
}

View File

@ -1,133 +0,0 @@
/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package mergepatch
import (
"fmt"
"reflect"
"github.com/davecgh/go-spew/spew"
"sigs.k8s.io/yaml"
)
// PreconditionFunc asserts that an incompatible change is not present within a patch.
type PreconditionFunc func(interface{}) bool
// RequireKeyUnchanged returns a precondition function that fails if the provided key
// is present in the patch (indicating that its value has changed).
func RequireKeyUnchanged(key string) PreconditionFunc {
return func(patch interface{}) bool {
patchMap, ok := patch.(map[string]interface{})
if !ok {
return true
}
// The presence of key means that its value has been changed, so the test fails.
_, ok = patchMap[key]
return !ok
}
}
// RequireMetadataKeyUnchanged creates a precondition function that fails
// if the metadata.key is present in the patch (indicating its value
// has changed).
func RequireMetadataKeyUnchanged(key string) PreconditionFunc {
return func(patch interface{}) bool {
patchMap, ok := patch.(map[string]interface{})
if !ok {
return true
}
patchMap1, ok := patchMap["metadata"]
if !ok {
return true
}
patchMap2, ok := patchMap1.(map[string]interface{})
if !ok {
return true
}
_, ok = patchMap2[key]
return !ok
}
}
func ToYAMLOrError(v interface{}) string {
y, err := toYAML(v)
if err != nil {
return err.Error()
}
return y
}
func toYAML(v interface{}) (string, error) {
y, err := yaml.Marshal(v)
if err != nil {
return "", fmt.Errorf("yaml marshal failed:%v\n%v\n", err, spew.Sdump(v))
}
return string(y), nil
}
// HasConflicts returns true if the left and right JSON interface objects overlap with
// different values in any key. All keys are required to be strings. Since patches of the
// same Type have congruent keys, this is valid for multiple patch types. This method
// supports JSON merge patch semantics.
//
// NOTE: Numbers with different types (e.g. int(0) vs int64(0)) will be detected as conflicts.
// Make sure the unmarshaling of left and right are consistent (e.g. use the same library).
func HasConflicts(left, right interface{}) (bool, error) {
switch typedLeft := left.(type) {
case map[string]interface{}:
switch typedRight := right.(type) {
case map[string]interface{}:
for key, leftValue := range typedLeft {
rightValue, ok := typedRight[key]
if !ok {
continue
}
if conflict, err := HasConflicts(leftValue, rightValue); err != nil || conflict {
return conflict, err
}
}
return false, nil
default:
return true, nil
}
case []interface{}:
switch typedRight := right.(type) {
case []interface{}:
if len(typedLeft) != len(typedRight) {
return true, nil
}
for i := range typedLeft {
if conflict, err := HasConflicts(typedLeft[i], typedRight[i]); err != nil || conflict {
return conflict, err
}
}
return false, nil
default:
return true, nil
}
case string, float64, bool, int64, nil:
return !reflect.DeepEqual(left, right), nil
default:
return true, fmt.Errorf("unknown type: %v", reflect.TypeOf(left))
}
}

View File

@ -1,136 +0,0 @@
/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package mergepatch
import (
"fmt"
"testing"
)
func TestHasConflicts(t *testing.T) {
testCases := []struct {
A interface{}
B interface{}
Ret bool
}{
{A: "hello", B: "hello", Ret: false},
{A: "hello", B: "hell", Ret: true},
{A: "hello", B: nil, Ret: true},
{A: "hello", B: int64(1), Ret: true},
{A: "hello", B: float64(1.0), Ret: true},
{A: "hello", B: false, Ret: true},
{A: int64(1), B: int64(1), Ret: false},
{A: nil, B: nil, Ret: false},
{A: false, B: false, Ret: false},
{A: float64(3), B: float64(3), Ret: false},
{A: "hello", B: []interface{}{}, Ret: true},
{A: []interface{}{int64(1)}, B: []interface{}{}, Ret: true},
{A: []interface{}{}, B: []interface{}{}, Ret: false},
{A: []interface{}{int64(1)}, B: []interface{}{int64(1)}, Ret: false},
{A: map[string]interface{}{}, B: []interface{}{int64(1)}, Ret: true},
{A: map[string]interface{}{}, B: map[string]interface{}{"a": int64(1)}, Ret: false},
{A: map[string]interface{}{"a": int64(1)}, B: map[string]interface{}{"a": int64(1)}, Ret: false},
{A: map[string]interface{}{"a": int64(1)}, B: map[string]interface{}{"a": int64(2)}, Ret: true},
{A: map[string]interface{}{"a": int64(1)}, B: map[string]interface{}{"b": int64(2)}, Ret: false},
{
A: map[string]interface{}{"a": []interface{}{int64(1)}},
B: map[string]interface{}{"a": []interface{}{int64(1)}},
Ret: false,
},
{
A: map[string]interface{}{"a": []interface{}{int64(1)}},
B: map[string]interface{}{"a": []interface{}{}},
Ret: true,
},
{
A: map[string]interface{}{"a": []interface{}{int64(1)}},
B: map[string]interface{}{"a": int64(1)},
Ret: true,
},
// Maps and lists with multiple entries.
{
A: map[string]interface{}{"a": int64(1), "b": int64(2)},
B: map[string]interface{}{"a": int64(1), "b": int64(0)},
Ret: true,
},
{
A: map[string]interface{}{"a": int64(1), "b": int64(2)},
B: map[string]interface{}{"a": int64(1), "b": int64(2)},
Ret: false,
},
{
A: map[string]interface{}{"a": int64(1), "b": int64(2)},
B: map[string]interface{}{"a": int64(1), "b": int64(0), "c": int64(3)},
Ret: true,
},
{
A: map[string]interface{}{"a": int64(1), "b": int64(2)},
B: map[string]interface{}{"a": int64(1), "b": int64(2), "c": int64(3)},
Ret: false,
},
{
A: map[string]interface{}{"a": []interface{}{int64(1), int64(2)}},
B: map[string]interface{}{"a": []interface{}{int64(1), int64(0)}},
Ret: true,
},
{
A: map[string]interface{}{"a": []interface{}{int64(1), int64(2)}},
B: map[string]interface{}{"a": []interface{}{int64(1), int64(2)}},
Ret: false,
},
// Numeric types are not interchangeable.
// Callers are expected to ensure numeric types are consistent in 'left' and 'right'.
{A: int64(0), B: float64(0), Ret: true},
// Other types are not interchangeable.
{A: int64(0), B: "0", Ret: true},
{A: int64(0), B: nil, Ret: true},
{A: int64(0), B: false, Ret: true},
{A: "true", B: true, Ret: true},
{A: "null", B: nil, Ret: true},
}
for _, testCase := range testCases {
testStr := fmt.Sprintf("A = %#v, B = %#v", testCase.A, testCase.B)
// Run each test case multiple times if it passes because HasConflicts()
// uses map iteration, which returns keys in nondeterministic order.
for try := 0; try < 10; try++ {
out, err := HasConflicts(testCase.A, testCase.B)
if err != nil {
t.Errorf("%v: unexpected error: %v", testStr, err)
break
}
if out != testCase.Ret {
t.Errorf("%v: expected %t got %t", testStr, testCase.Ret, out)
break
}
out, err = HasConflicts(testCase.B, testCase.A)
if err != nil {
t.Errorf("%v: unexpected error: %v", testStr, err)
break
}
if out != testCase.Ret {
t.Errorf("%v: expected reversed %t got %t", testStr, testCase.Ret, out)
break
}
}
}
}

View File

@ -1,56 +0,0 @@
/*
Copyright 2018 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package naming
import "testing"
func TestGetNameFromCallsite(t *testing.T) {
tests := []struct {
name string
ignoredPackages []string
expected string
}{
{
name: "simple",
expected: "k8s.io/apimachinery/pkg/util/naming/from_stack_test.go:50",
},
{
name: "ignore-package",
ignoredPackages: []string{"k8s.io/apimachinery/pkg/util/naming"},
expected: "testing/testing.go:827",
},
{
name: "ignore-file",
ignoredPackages: []string{"k8s.io/apimachinery/pkg/util/naming/from_stack_test.go"},
expected: "testing/testing.go:827",
},
{
name: "ignore-multiple",
ignoredPackages: []string{"k8s.io/apimachinery/pkg/util/naming/from_stack_test.go", "testing/testing.go"},
expected: "????",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
actual := GetNameFromCallsite(tc.ignoredPackages...)
if tc.expected != actual {
t.Fatalf("expected %q, got %q", tc.expected, actual)
}
})
}
}

View File

@ -1,441 +0,0 @@
// +build go1.8
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package net
import (
"bufio"
"bytes"
"crypto/tls"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"reflect"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"k8s.io/apimachinery/pkg/util/wait"
)
func TestGetClientIP(t *testing.T) {
ipString := "10.0.0.1"
ip := net.ParseIP(ipString)
invalidIPString := "invalidIPString"
testCases := []struct {
Request http.Request
ExpectedIP net.IP
}{
{
Request: http.Request{},
},
{
Request: http.Request{
Header: map[string][]string{
"X-Real-Ip": {ipString},
},
},
ExpectedIP: ip,
},
{
Request: http.Request{
Header: map[string][]string{
"X-Real-Ip": {invalidIPString},
},
},
},
{
Request: http.Request{
Header: map[string][]string{
"X-Forwarded-For": {ipString},
},
},
ExpectedIP: ip,
},
{
Request: http.Request{
Header: map[string][]string{
"X-Forwarded-For": {invalidIPString},
},
},
},
{
Request: http.Request{
Header: map[string][]string{
"X-Forwarded-For": {invalidIPString + "," + ipString},
},
},
ExpectedIP: ip,
},
{
Request: http.Request{
// RemoteAddr is in the form host:port
RemoteAddr: ipString + ":1234",
},
ExpectedIP: ip,
},
{
Request: http.Request{
RemoteAddr: invalidIPString,
},
},
{
Request: http.Request{
Header: map[string][]string{
"X-Forwarded-For": {invalidIPString},
},
// RemoteAddr is in the form host:port
RemoteAddr: ipString,
},
ExpectedIP: ip,
},
}
for i, test := range testCases {
if a, e := GetClientIP(&test.Request), test.ExpectedIP; reflect.DeepEqual(e, a) != true {
t.Fatalf("test case %d failed. expected: %v, actual: %v", i, e, a)
}
}
}
func TestAppendForwardedForHeader(t *testing.T) {
testCases := []struct {
addr, forwarded, expected string
}{
{"1.2.3.4:8000", "", "1.2.3.4"},
{"1.2.3.4:8000", "8.8.8.8", "8.8.8.8, 1.2.3.4"},
{"1.2.3.4:8000", "8.8.8.8, 1.2.3.4", "8.8.8.8, 1.2.3.4, 1.2.3.4"},
{"1.2.3.4:8000", "foo,bar", "foo,bar, 1.2.3.4"},
}
for i, test := range testCases {
req := &http.Request{
RemoteAddr: test.addr,
Header: make(http.Header),
}
if test.forwarded != "" {
req.Header.Set("X-Forwarded-For", test.forwarded)
}
AppendForwardedForHeader(req)
actual := req.Header.Get("X-Forwarded-For")
if actual != test.expected {
t.Errorf("[%d] Expected %q, Got %q", i, test.expected, actual)
}
}
}
func TestProxierWithNoProxyCIDR(t *testing.T) {
testCases := []struct {
name string
noProxy string
url string
expectedDelegated bool
}{
{
name: "no env",
url: "https://192.168.143.1/api",
expectedDelegated: true,
},
{
name: "no cidr",
noProxy: "192.168.63.1",
url: "https://192.168.143.1/api",
expectedDelegated: true,
},
{
name: "hostname",
noProxy: "192.168.63.0/24,192.168.143.0/24",
url: "https://my-hostname/api",
expectedDelegated: true,
},
{
name: "match second cidr",
noProxy: "192.168.63.0/24,192.168.143.0/24",
url: "https://192.168.143.1/api",
expectedDelegated: false,
},
{
name: "match second cidr with host:port",
noProxy: "192.168.63.0/24,192.168.143.0/24",
url: "https://192.168.143.1:8443/api",
expectedDelegated: false,
},
{
name: "IPv6 cidr",
noProxy: "2001:db8::/48",
url: "https://[2001:db8::1]/api",
expectedDelegated: false,
},
{
name: "IPv6+port cidr",
noProxy: "2001:db8::/48",
url: "https://[2001:db8::1]:8443/api",
expectedDelegated: false,
},
{
name: "IPv6, not matching cidr",
noProxy: "2001:db8::/48",
url: "https://[2001:db8:1::1]/api",
expectedDelegated: true,
},
{
name: "IPv6+port, not matching cidr",
noProxy: "2001:db8::/48",
url: "https://[2001:db8:1::1]:8443/api",
expectedDelegated: true,
},
}
for _, test := range testCases {
os.Setenv("NO_PROXY", test.noProxy)
actualDelegated := false
proxyFunc := NewProxierWithNoProxyCIDR(func(req *http.Request) (*url.URL, error) {
actualDelegated = true
return nil, nil
})
req, err := http.NewRequest("GET", test.url, nil)
if err != nil {
t.Errorf("%s: unexpected err: %v", test.name, err)
continue
}
if _, err := proxyFunc(req); err != nil {
t.Errorf("%s: unexpected err: %v", test.name, err)
continue
}
if test.expectedDelegated != actualDelegated {
t.Errorf("%s: expected %v, got %v", test.name, test.expectedDelegated, actualDelegated)
continue
}
}
}
type fakeTLSClientConfigHolder struct {
called bool
}
func (f *fakeTLSClientConfigHolder) TLSClientConfig() *tls.Config {
f.called = true
return nil
}
func (f *fakeTLSClientConfigHolder) RoundTrip(*http.Request) (*http.Response, error) {
return nil, nil
}
func TestTLSClientConfigHolder(t *testing.T) {
rt := &fakeTLSClientConfigHolder{}
TLSClientConfig(rt)
if !rt.called {
t.Errorf("didn't find tls config")
}
}
func TestJoinPreservingTrailingSlash(t *testing.T) {
tests := []struct {
a string
b string
want string
}{
// All empty
{"", "", ""},
// Empty a
{"", "/", "/"},
{"", "foo", "foo"},
{"", "/foo", "/foo"},
{"", "/foo/", "/foo/"},
// Empty b
{"/", "", "/"},
{"foo", "", "foo"},
{"/foo", "", "/foo"},
{"/foo/", "", "/foo/"},
// Both populated
{"/", "/", "/"},
{"foo", "foo", "foo/foo"},
{"/foo", "/foo", "/foo/foo"},
{"/foo/", "/foo/", "/foo/foo/"},
}
for _, tt := range tests {
name := fmt.Sprintf("%q+%q=%q", tt.a, tt.b, tt.want)
t.Run(name, func(t *testing.T) {
if got := JoinPreservingTrailingSlash(tt.a, tt.b); got != tt.want {
t.Errorf("JoinPreservingTrailingSlash() = %v, want %v", got, tt.want)
}
})
}
}
func TestConnectWithRedirects(t *testing.T) {
tests := []struct {
desc string
redirects []string
method string // initial request method, empty == GET
expectError bool
expectedRedirects int
newPort bool // special case different port test
}{{
desc: "relative redirects allowed",
redirects: []string{"/ok"},
expectedRedirects: 1,
}, {
desc: "redirects to the same host are allowed",
redirects: []string{"http://HOST/ok"}, // HOST replaced with server address in test
expectedRedirects: 1,
}, {
desc: "POST redirects to GET",
method: http.MethodPost,
redirects: []string{"/ok"},
expectedRedirects: 1,
}, {
desc: "PUT redirects to GET",
method: http.MethodPut,
redirects: []string{"/ok"},
expectedRedirects: 1,
}, {
desc: "DELETE redirects to GET",
method: http.MethodDelete,
redirects: []string{"/ok"},
expectedRedirects: 1,
}, {
desc: "9 redirects are allowed",
redirects: []string{"/1", "/2", "/3", "/4", "/5", "/6", "/7", "/8", "/9"},
expectedRedirects: 9,
}, {
desc: "10 redirects are forbidden",
redirects: []string{"/1", "/2", "/3", "/4", "/5", "/6", "/7", "/8", "/9", "/10"},
expectError: true,
}, {
desc: "redirect to different host are prevented",
redirects: []string{"http://example.com/foo"},
expectedRedirects: 0,
}, {
desc: "multiple redirect to different host forbidden",
redirects: []string{"/1", "/2", "/3", "http://example.com/foo"},
expectedRedirects: 3,
}, {
desc: "redirect to different port is allowed",
redirects: []string{"http://HOST/foo"},
expectedRedirects: 1,
newPort: true,
}}
const resultString = "Test output"
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
redirectCount := 0
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
// Verify redirect request.
if redirectCount > 0 {
expectedURL, err := url.Parse(test.redirects[redirectCount-1])
require.NoError(t, err, "test URL error")
assert.Equal(t, req.URL.Path, expectedURL.Path, "unknown redirect path")
assert.Equal(t, http.MethodGet, req.Method, "redirects must always be GET")
}
if redirectCount < len(test.redirects) {
http.Redirect(w, req, test.redirects[redirectCount], http.StatusFound)
redirectCount++
} else if redirectCount == len(test.redirects) {
w.Write([]byte(resultString))
} else {
t.Errorf("unexpected number of redirects %d to %s", redirectCount, req.URL.String())
}
}))
defer s.Close()
u, err := url.Parse(s.URL)
require.NoError(t, err, "Error parsing server URL")
host := u.Host
// Special case new-port test with a secondary server.
if test.newPort {
s2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte(resultString))
}))
defer s2.Close()
u2, err := url.Parse(s2.URL)
require.NoError(t, err, "Error parsing secondary server URL")
// Sanity check: secondary server uses same hostname, different port.
require.Equal(t, u.Hostname(), u2.Hostname(), "sanity check: same hostname")
require.NotEqual(t, u.Port(), u2.Port(), "sanity check: different port")
// Redirect to the secondary server.
host = u2.Host
}
// Update redirect URLs with actual host.
for i := range test.redirects {
test.redirects[i] = strings.Replace(test.redirects[i], "HOST", host, 1)
}
method := test.method
if method == "" {
method = http.MethodGet
}
netdialer := &net.Dialer{
Timeout: wait.ForeverTestTimeout,
KeepAlive: wait.ForeverTestTimeout,
}
dialer := DialerFunc(func(req *http.Request) (net.Conn, error) {
conn, err := netdialer.Dial("tcp", req.URL.Host)
if err != nil {
return conn, err
}
if err = req.Write(conn); err != nil {
require.NoError(t, conn.Close())
return nil, fmt.Errorf("error sending request: %v", err)
}
return conn, err
})
conn, rawResponse, err := ConnectWithRedirects(method, u, http.Header{} /*body*/, nil, dialer, true)
if test.expectError {
require.Error(t, err, "expected request error")
return
}
require.NoError(t, err, "unexpected request error")
assert.NoError(t, conn.Close(), "error closing connection")
resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(rawResponse)), nil)
require.NoError(t, err, "unexpected request error")
result, err := ioutil.ReadAll(resp.Body)
require.NoError(t, resp.Body.Close())
if test.expectedRedirects < len(test.redirects) {
// Expect the last redirect to be returned.
assert.Equal(t, http.StatusFound, resp.StatusCode, "Final response is not a redirect")
assert.Equal(t, test.redirects[len(test.redirects)-1], resp.Header.Get("Location"))
assert.NotEqual(t, resultString, string(result), "wrong content")
} else {
assert.Equal(t, resultString, string(result), "stream content does not match")
}
})
}
}

View File

@ -1,725 +0,0 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package net
import (
"fmt"
"io/ioutil"
"net"
"os"
"strings"
"testing"
)
const gatewayfirst = `Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth3 00000000 0100FE0A 0003 0 0 1024 00000000 0 0 0
eth3 0000FE0A 00000000 0001 0 0 0 0080FFFF 0 0 0
docker0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0
virbr0 007AA8C0 00000000 0001 0 0 0 00FFFFFF 0 0 0
`
const gatewaylast = `Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
docker0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0
virbr0 007AA8C0 00000000 0001 0 0 0 00FFFFFF 0 0 0
eth3 0000FE0A 00000000 0001 0 0 0 0080FFFF 0 0 0
eth3 00000000 0100FE0A 0003 0 0 1024 00000000 0 0 0
`
const gatewaymiddle = `Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth3 0000FE0A 00000000 0001 0 0 0 0080FFFF 0 0 0
docker0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0
eth3 00000000 0100FE0A 0003 0 0 1024 00000000 0 0 0
virbr0 007AA8C0 00000000 0001 0 0 0 00FFFFFF 0 0 0
`
const noInternetConnection = `Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
docker0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0
virbr0 007AA8C0 00000000 0001 0 0 0 00FFFFFF 0 0 0
`
const nothing = `Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
`
const badDestination = `Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth3 00000000 0100FE0A 0003 0 0 1024 00000000 0 0 0
eth3 0000FE0AA1 00000000 0001 0 0 0 0080FFFF 0 0 0
docker0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0
virbr0 007AA8C0 00000000 0001 0 0 0 00FFFFFF 0 0 0
`
const badGateway = `Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth3 00000000 0100FE0AA1 0003 0 0 1024 00000000 0 0 0
eth3 0000FE0A 00000000 0001 0 0 0 0080FFFF 0 0 0
docker0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0
virbr0 007AA8C0 00000000 0001 0 0 0 00FFFFFF 0 0 0
`
const route_Invalidhex = `Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth3 00000000 0100FE0AA 0003 0 0 1024 00000000 0 0 0
eth3 0000FE0A 00000000 0001 0 0 0 0080FFFF 0 0 0
docker0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0
virbr0 007AA8C0 00000000 0001 0 0 0 00FFFFFF 0 0 0
`
const v6gatewayfirst = `00000000000000000000000000000000 00 00000000000000000000000000000000 00 20010001000000000000000000000001 00000064 00000000 00000000 00000003 eth3
20010002000000000000000000000000 40 00000000000000000000000000000000 00 00000000000000000000000000000000 00000100 00000000 00000000 00000001 eth3
00000000000000000000000000000000 60 00000000000000000000000000000000 00 00000000000000000000000000000000 00000400 00000000 00000000 00200200 lo
`
const v6gatewaylast = `20010002000000000000000000000000 40 00000000000000000000000000000000 00 00000000000000000000000000000000 00000100 00000000 00000000 00000001 eth3
00000000000000000000000000000000 60 00000000000000000000000000000000 00 00000000000000000000000000000000 00000400 00000000 00000000 00200200 lo
00000000000000000000000000000000 00 00000000000000000000000000000000 00 20010001000000000000000000000001 00000064 00000000 00000000 00000003 eth3
`
const v6gatewaymiddle = `20010002000000000000000000000000 40 00000000000000000000000000000000 00 00000000000000000000000000000000 00000100 00000000 00000000 00000001 eth3
00000000000000000000000000000000 00 00000000000000000000000000000000 00 20010001000000000000000000000001 00000064 00000000 00000000 00000003 eth3
00000000000000000000000000000000 60 00000000000000000000000000000000 00 00000000000000000000000000000000 00000400 00000000 00000000 00200200 lo
`
const v6noDefaultRoutes = `00000000000000000000000000000000 60 00000000000000000000000000000000 00 00000000000000000000000000000000 00000400 00000000 00000000 00200200 lo
20010001000000000000000000000000 40 00000000000000000000000000000000 00 00000000000000000000000000000000 00000400 00000000 00000000 00000001 docker0
20010002000000000000000000000000 40 00000000000000000000000000000000 00 00000000000000000000000000000000 00000100 00000000 00000000 00000001 eth3
fe800000000000000000000000000000 40 00000000000000000000000000000000 00 00000000000000000000000000000000 00000100 00000000 00000000 00000001 eth3
`
const v6nothing = ``
const v6badDestination = `2001000200000000 7a 00000000000000000000000000000000 00 00000000000000000000000000000000 00000400 00000000 00000000 00200200 lo
`
const v6badGateway = `00000000000000000000000000000000 00 00000000000000000000000000000000 00 200100010000000000000000000000000012 00000064 00000000 00000000 00000003 eth3
`
const v6route_Invalidhex = `000000000000000000000000000000000 00 00000000000000000000000000000000 00 fe80000000000000021fcafffea0ec00 00000064 00000000 00000000 00000003 enp1s0f0
`
const (
flagUp = net.FlagUp | net.FlagBroadcast | net.FlagMulticast
flagDown = net.FlagBroadcast | net.FlagMulticast
flagLoopback = net.FlagUp | net.FlagLoopback
flagP2P = net.FlagUp | net.FlagPointToPoint
)
func makeIntf(index int, name string, flags net.Flags) net.Interface {
mac := net.HardwareAddr{0, 0x32, 0x7d, 0x69, 0xf7, byte(0x30 + index)}
return net.Interface{
Index: index,
MTU: 1500,
Name: name,
HardwareAddr: mac,
Flags: flags}
}
var (
downIntf = makeIntf(1, "eth3", flagDown)
loopbackIntf = makeIntf(1, "lo", flagLoopback)
p2pIntf = makeIntf(1, "lo", flagP2P)
upIntf = makeIntf(1, "eth3", flagUp)
)
var (
ipv4Route = Route{Interface: "eth3", Destination: net.ParseIP("0.0.0.0"), Gateway: net.ParseIP("10.254.0.1"), Family: familyIPv4}
ipv6Route = Route{Interface: "eth3", Destination: net.ParseIP("::"), Gateway: net.ParseIP("2001:1::1"), Family: familyIPv6}
)
var (
noRoutes = []Route{}
routeV4 = []Route{ipv4Route}
routeV6 = []Route{ipv6Route}
bothRoutes = []Route{ipv4Route, ipv6Route}
)
func TestGetIPv4Routes(t *testing.T) {
testCases := []struct {
tcase string
route string
count int
expected *Route
errStrFrag string
}{
{"gatewayfirst", gatewayfirst, 1, &ipv4Route, ""},
{"gatewaymiddle", gatewaymiddle, 1, &ipv4Route, ""},
{"gatewaylast", gatewaylast, 1, &ipv4Route, ""},
{"no routes", nothing, 0, nil, ""},
{"badDestination", badDestination, 0, nil, "invalid IPv4"},
{"badGateway", badGateway, 0, nil, "invalid IPv4"},
{"route_Invalidhex", route_Invalidhex, 0, nil, "odd length hex string"},
{"no default routes", noInternetConnection, 0, nil, ""},
}
for _, tc := range testCases {
r := strings.NewReader(tc.route)
routes, err := getIPv4DefaultRoutes(r)
if err != nil {
if !strings.Contains(err.Error(), tc.errStrFrag) {
t.Errorf("case[%s]: Error string %q does not contain %q", tc.tcase, err, tc.errStrFrag)
}
} else if tc.errStrFrag != "" {
t.Errorf("case[%s]: Error %q expected, but not seen", tc.tcase, tc.errStrFrag)
} else {
if tc.count != len(routes) {
t.Errorf("case[%s]: expected %d routes, have %v", tc.tcase, tc.count, routes)
} else if tc.count == 1 {
if !tc.expected.Gateway.Equal(routes[0].Gateway) {
t.Errorf("case[%s]: expected %v, got %v .err : %v", tc.tcase, tc.expected, routes, err)
}
if !routes[0].Destination.Equal(net.IPv4zero) {
t.Errorf("case[%s}: destination is not for default route (not zero)", tc.tcase)
}
}
}
}
}
func TestGetIPv6Routes(t *testing.T) {
testCases := []struct {
tcase string
route string
count int
expected *Route
errStrFrag string
}{
{"v6 gatewayfirst", v6gatewayfirst, 1, &ipv6Route, ""},
{"v6 gatewaymiddle", v6gatewaymiddle, 1, &ipv6Route, ""},
{"v6 gatewaylast", v6gatewaylast, 1, &ipv6Route, ""},
{"v6 no routes", v6nothing, 0, nil, ""},
{"v6 badDestination", v6badDestination, 0, nil, "invalid IPv6"},
{"v6 badGateway", v6badGateway, 0, nil, "invalid IPv6"},
{"v6 route_Invalidhex", v6route_Invalidhex, 0, nil, "odd length hex string"},
{"v6 no default routes", v6noDefaultRoutes, 0, nil, ""},
}
for _, tc := range testCases {
r := strings.NewReader(tc.route)
routes, err := getIPv6DefaultRoutes(r)
if err != nil {
if !strings.Contains(err.Error(), tc.errStrFrag) {
t.Errorf("case[%s]: Error string %q does not contain %q", tc.tcase, err, tc.errStrFrag)
}
} else if tc.errStrFrag != "" {
t.Errorf("case[%s]: Error %q expected, but not seen", tc.tcase, tc.errStrFrag)
} else {
if tc.count != len(routes) {
t.Errorf("case[%s]: expected %d routes, have %v", tc.tcase, tc.count, routes)
} else if tc.count == 1 {
if !tc.expected.Gateway.Equal(routes[0].Gateway) {
t.Errorf("case[%s]: expected %v, got %v .err : %v", tc.tcase, tc.expected, routes, err)
}
if !routes[0].Destination.Equal(net.IPv6zero) {
t.Errorf("case[%s}: destination is not for default route (not zero)", tc.tcase)
}
}
}
}
}
func TestParseIP(t *testing.T) {
testCases := []struct {
tcase string
ip string
family AddressFamily
success bool
expected net.IP
}{
{"empty", "", familyIPv4, false, nil},
{"too short", "AA", familyIPv4, false, nil},
{"too long", "0011223344", familyIPv4, false, nil},
{"invalid", "invalid!", familyIPv4, false, nil},
{"zero", "00000000", familyIPv4, true, net.IP{0, 0, 0, 0}},
{"ffff", "FFFFFFFF", familyIPv4, true, net.IP{0xff, 0xff, 0xff, 0xff}},
{"valid v4", "12345678", familyIPv4, true, net.IP{120, 86, 52, 18}},
{"valid v6", "fe800000000000000000000000000000", familyIPv6, true, net.IP{0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}},
{"v6 too short", "fe80000000000000021fcafffea0ec0", familyIPv6, false, nil},
{"v6 too long", "fe80000000000000021fcafffea0ec002", familyIPv6, false, nil},
}
for _, tc := range testCases {
ip, err := parseIP(tc.ip, tc.family)
if !ip.Equal(tc.expected) {
t.Errorf("case[%v]: expected %q, got %q . err : %v", tc.tcase, tc.expected, ip, err)
}
}
}
func TestIsInterfaceUp(t *testing.T) {
testCases := []struct {
tcase string
intf *net.Interface
expected bool
}{
{"up", &net.Interface{Index: 0, MTU: 0, Name: "eth3", HardwareAddr: nil, Flags: net.FlagUp}, true},
{"down", &net.Interface{Index: 0, MTU: 0, Name: "eth3", HardwareAddr: nil, Flags: 0}, false},
{"no interface", nil, false},
}
for _, tc := range testCases {
it := isInterfaceUp(tc.intf)
if it != tc.expected {
t.Errorf("case[%v]: expected %v, got %v .", tc.tcase, tc.expected, it)
}
}
}
type addrStruct struct{ val string }
func (a addrStruct) Network() string {
return a.val
}
func (a addrStruct) String() string {
return a.val
}
func TestFinalIP(t *testing.T) {
testCases := []struct {
tcase string
addr []net.Addr
family AddressFamily
expected net.IP
}{
{"no ipv4", []net.Addr{addrStruct{val: "2001::5/64"}}, familyIPv4, nil},
{"no ipv6", []net.Addr{addrStruct{val: "10.128.0.4/32"}}, familyIPv6, nil},
{"invalidV4CIDR", []net.Addr{addrStruct{val: "10.20.30.40.50/24"}}, familyIPv4, nil},
{"invalidV6CIDR", []net.Addr{addrStruct{val: "fe80::2f7:67fff:fe6e:2956/64"}}, familyIPv6, nil},
{"loopback", []net.Addr{addrStruct{val: "127.0.0.1/24"}}, familyIPv4, nil},
{"loopbackv6", []net.Addr{addrStruct{val: "::1/128"}}, familyIPv6, nil},
{"link local v4", []net.Addr{addrStruct{val: "169.254.1.10/16"}}, familyIPv4, nil},
{"link local v6", []net.Addr{addrStruct{val: "fe80::2f7:6fff:fe6e:2956/64"}}, familyIPv6, nil},
{"ip4", []net.Addr{addrStruct{val: "10.254.12.132/17"}}, familyIPv4, net.ParseIP("10.254.12.132")},
{"ip6", []net.Addr{addrStruct{val: "2001::5/64"}}, familyIPv6, net.ParseIP("2001::5")},
{"no addresses", []net.Addr{}, familyIPv4, nil},
}
for _, tc := range testCases {
ip, err := getMatchingGlobalIP(tc.addr, tc.family)
if !ip.Equal(tc.expected) {
t.Errorf("case[%v]: expected %v, got %v .err : %v", tc.tcase, tc.expected, ip, err)
}
}
}
func TestAddrs(t *testing.T) {
var nw networkInterfacer = validNetworkInterface{}
intf := net.Interface{Index: 0, MTU: 0, Name: "eth3", HardwareAddr: nil, Flags: 0}
addrs, err := nw.Addrs(&intf)
if err != nil {
t.Errorf("expected no error got : %v", err)
}
if len(addrs) != 2 {
t.Errorf("expected addrs: 2 got null")
}
}
// Has a valid IPv4 address (IPv6 is LLA)
type validNetworkInterface struct {
}
func (_ validNetworkInterface) InterfaceByName(intfName string) (*net.Interface, error) {
return &upIntf, nil
}
func (_ validNetworkInterface) Addrs(intf *net.Interface) ([]net.Addr, error) {
var ifat []net.Addr
ifat = []net.Addr{
addrStruct{val: "fe80::2f7:6fff:fe6e:2956/64"}, addrStruct{val: "10.254.71.145/17"}}
return ifat, nil
}
func (_ validNetworkInterface) Interfaces() ([]net.Interface, error) {
return []net.Interface{upIntf}, nil
}
// Both IPv4 and IPv6 addresses (expecting IPv4 to be used)
type v4v6NetworkInterface struct {
}
func (_ v4v6NetworkInterface) InterfaceByName(intfName string) (*net.Interface, error) {
return &upIntf, nil
}
func (_ v4v6NetworkInterface) Addrs(intf *net.Interface) ([]net.Addr, error) {
var ifat []net.Addr
ifat = []net.Addr{
addrStruct{val: "2001::10/64"}, addrStruct{val: "10.254.71.145/17"}}
return ifat, nil
}
func (_ v4v6NetworkInterface) Interfaces() ([]net.Interface, error) {
return []net.Interface{upIntf}, nil
}
// Interface with only IPv6 address
type ipv6NetworkInterface struct {
}
func (_ ipv6NetworkInterface) InterfaceByName(intfName string) (*net.Interface, error) {
return &upIntf, nil
}
func (_ ipv6NetworkInterface) Addrs(intf *net.Interface) ([]net.Addr, error) {
var ifat []net.Addr
ifat = []net.Addr{addrStruct{val: "2001::200/64"}}
return ifat, nil
}
func (_ ipv6NetworkInterface) Interfaces() ([]net.Interface, error) {
return []net.Interface{upIntf}, nil
}
// Only with link local addresses
type networkInterfaceWithOnlyLinkLocals struct {
}
func (_ networkInterfaceWithOnlyLinkLocals) InterfaceByName(intfName string) (*net.Interface, error) {
return &upIntf, nil
}
func (_ networkInterfaceWithOnlyLinkLocals) Addrs(intf *net.Interface) ([]net.Addr, error) {
var ifat []net.Addr
ifat = []net.Addr{addrStruct{val: "169.254.162.166/16"}, addrStruct{val: "fe80::200/10"}}
return ifat, nil
}
func (_ networkInterfaceWithOnlyLinkLocals) Interfaces() ([]net.Interface, error) {
return []net.Interface{upIntf}, nil
}
// Unable to get interface(s)
type failGettingNetworkInterface struct {
}
func (_ failGettingNetworkInterface) InterfaceByName(intfName string) (*net.Interface, error) {
return nil, fmt.Errorf("unable get Interface")
}
func (_ failGettingNetworkInterface) Addrs(intf *net.Interface) ([]net.Addr, error) {
return nil, nil
}
func (_ failGettingNetworkInterface) Interfaces() ([]net.Interface, error) {
return nil, fmt.Errorf("mock failed getting all interfaces")
}
// No interfaces
type noNetworkInterface struct {
}
func (_ noNetworkInterface) InterfaceByName(intfName string) (*net.Interface, error) {
return nil, fmt.Errorf("no such network interface")
}
func (_ noNetworkInterface) Addrs(intf *net.Interface) ([]net.Addr, error) {
return nil, nil
}
func (_ noNetworkInterface) Interfaces() ([]net.Interface, error) {
return []net.Interface{}, nil
}
// Interface is down
type downNetworkInterface struct {
}
func (_ downNetworkInterface) InterfaceByName(intfName string) (*net.Interface, error) {
return &downIntf, nil
}
func (_ downNetworkInterface) Addrs(intf *net.Interface) ([]net.Addr, error) {
var ifat []net.Addr
ifat = []net.Addr{
addrStruct{val: "fe80::2f7:6fff:fe6e:2956/64"}, addrStruct{val: "10.254.71.145/17"}}
return ifat, nil
}
func (_ downNetworkInterface) Interfaces() ([]net.Interface, error) {
return []net.Interface{downIntf}, nil
}
// Loopback interface
type loopbackNetworkInterface struct {
}
func (_ loopbackNetworkInterface) InterfaceByName(intfName string) (*net.Interface, error) {
return &loopbackIntf, nil
}
func (_ loopbackNetworkInterface) Addrs(intf *net.Interface) ([]net.Addr, error) {
var ifat []net.Addr
ifat = []net.Addr{
addrStruct{val: "::1/128"}, addrStruct{val: "127.0.0.1/8"}}
return ifat, nil
}
func (_ loopbackNetworkInterface) Interfaces() ([]net.Interface, error) {
return []net.Interface{loopbackIntf}, nil
}
// Point to point interface
type p2pNetworkInterface struct {
}
func (_ p2pNetworkInterface) InterfaceByName(intfName string) (*net.Interface, error) {
return &p2pIntf, nil
}
func (_ p2pNetworkInterface) Addrs(intf *net.Interface) ([]net.Addr, error) {
var ifat []net.Addr
ifat = []net.Addr{
addrStruct{val: "::1/128"}, addrStruct{val: "127.0.0.1/8"}}
return ifat, nil
}
func (_ p2pNetworkInterface) Interfaces() ([]net.Interface, error) {
return []net.Interface{p2pIntf}, nil
}
// Unable to get IP addresses for interface
type networkInterfaceFailGetAddrs struct {
}
func (_ networkInterfaceFailGetAddrs) InterfaceByName(intfName string) (*net.Interface, error) {
return &upIntf, nil
}
func (_ networkInterfaceFailGetAddrs) Addrs(intf *net.Interface) ([]net.Addr, error) {
return nil, fmt.Errorf("unable to get Addrs")
}
func (_ networkInterfaceFailGetAddrs) Interfaces() ([]net.Interface, error) {
return []net.Interface{upIntf}, nil
}
// No addresses for interface
type networkInterfaceWithNoAddrs struct {
}
func (_ networkInterfaceWithNoAddrs) InterfaceByName(intfName string) (*net.Interface, error) {
return &upIntf, nil
}
func (_ networkInterfaceWithNoAddrs) Addrs(intf *net.Interface) ([]net.Addr, error) {
ifat := []net.Addr{}
return ifat, nil
}
func (_ networkInterfaceWithNoAddrs) Interfaces() ([]net.Interface, error) {
return []net.Interface{upIntf}, nil
}
// Invalid addresses for interface
type networkInterfaceWithInvalidAddr struct {
}
func (_ networkInterfaceWithInvalidAddr) InterfaceByName(intfName string) (*net.Interface, error) {
return &upIntf, nil
}
func (_ networkInterfaceWithInvalidAddr) Addrs(intf *net.Interface) ([]net.Addr, error) {
var ifat []net.Addr
ifat = []net.Addr{addrStruct{val: "10.20.30.40.50/24"}}
return ifat, nil
}
func (_ networkInterfaceWithInvalidAddr) Interfaces() ([]net.Interface, error) {
return []net.Interface{upIntf}, nil
}
func TestGetIPFromInterface(t *testing.T) {
testCases := []struct {
tcase string
nwname string
family AddressFamily
nw networkInterfacer
expected net.IP
errStrFrag string
}{
{"ipv4", "eth3", familyIPv4, validNetworkInterface{}, net.ParseIP("10.254.71.145"), ""},
{"ipv6", "eth3", familyIPv6, ipv6NetworkInterface{}, net.ParseIP("2001::200"), ""},
{"no ipv4", "eth3", familyIPv4, ipv6NetworkInterface{}, nil, ""},
{"no ipv6", "eth3", familyIPv6, validNetworkInterface{}, nil, ""},
{"I/F down", "eth3", familyIPv4, downNetworkInterface{}, nil, ""},
{"I/F get fail", "eth3", familyIPv4, noNetworkInterface{}, nil, "no such network interface"},
{"fail get addr", "eth3", familyIPv4, networkInterfaceFailGetAddrs{}, nil, "unable to get Addrs"},
{"bad addr", "eth3", familyIPv4, networkInterfaceWithInvalidAddr{}, nil, "invalid CIDR"},
}
for _, tc := range testCases {
ip, err := getIPFromInterface(tc.nwname, tc.family, tc.nw)
if err != nil {
if !strings.Contains(err.Error(), tc.errStrFrag) {
t.Errorf("case[%s]: Error string %q does not contain %q", tc.tcase, err, tc.errStrFrag)
}
} else if tc.errStrFrag != "" {
t.Errorf("case[%s]: Error %q expected, but not seen", tc.tcase, tc.errStrFrag)
} else if !ip.Equal(tc.expected) {
t.Errorf("case[%v]: expected %v, got %+v .err : %v", tc.tcase, tc.expected, ip, err)
}
}
}
func TestChooseHostInterfaceFromRoute(t *testing.T) {
testCases := []struct {
tcase string
routes []Route
nw networkInterfacer
expected net.IP
}{
{"ipv4", routeV4, validNetworkInterface{}, net.ParseIP("10.254.71.145")},
{"ipv6", routeV6, ipv6NetworkInterface{}, net.ParseIP("2001::200")},
{"prefer ipv4", bothRoutes, v4v6NetworkInterface{}, net.ParseIP("10.254.71.145")},
{"all LLA", routeV4, networkInterfaceWithOnlyLinkLocals{}, nil},
{"no routes", noRoutes, validNetworkInterface{}, nil},
{"fail get IP", routeV4, networkInterfaceFailGetAddrs{}, nil},
}
for _, tc := range testCases {
ip, err := chooseHostInterfaceFromRoute(tc.routes, tc.nw)
if !ip.Equal(tc.expected) {
t.Errorf("case[%v]: expected %v, got %+v .err : %v", tc.tcase, tc.expected, ip, err)
}
}
}
func TestMemberOf(t *testing.T) {
testCases := []struct {
tcase string
ip net.IP
family AddressFamily
expected bool
}{
{"ipv4 is 4", net.ParseIP("10.20.30.40"), familyIPv4, true},
{"ipv4 is 6", net.ParseIP("10.10.10.10"), familyIPv6, false},
{"ipv6 is 4", net.ParseIP("2001::100"), familyIPv4, false},
{"ipv6 is 6", net.ParseIP("2001::100"), familyIPv6, true},
}
for _, tc := range testCases {
if memberOf(tc.ip, tc.family) != tc.expected {
t.Errorf("case[%s]: expected %+v", tc.tcase, tc.expected)
}
}
}
func TestGetIPFromHostInterfaces(t *testing.T) {
testCases := []struct {
tcase string
nw networkInterfacer
expected net.IP
errStrFrag string
}{
{"fail get I/Fs", failGettingNetworkInterface{}, nil, "failed getting all interfaces"},
{"no interfaces", noNetworkInterface{}, nil, "no interfaces"},
{"I/F not up", downNetworkInterface{}, nil, "no acceptable"},
{"loopback only", loopbackNetworkInterface{}, nil, "no acceptable"},
{"P2P I/F only", p2pNetworkInterface{}, nil, "no acceptable"},
{"fail get addrs", networkInterfaceFailGetAddrs{}, nil, "unable to get Addrs"},
{"no addresses", networkInterfaceWithNoAddrs{}, nil, "no acceptable"},
{"invalid addr", networkInterfaceWithInvalidAddr{}, nil, "invalid CIDR"},
{"no matches", networkInterfaceWithOnlyLinkLocals{}, nil, "no acceptable"},
{"ipv4", validNetworkInterface{}, net.ParseIP("10.254.71.145"), ""},
{"ipv6", ipv6NetworkInterface{}, net.ParseIP("2001::200"), ""},
}
for _, tc := range testCases {
ip, err := chooseIPFromHostInterfaces(tc.nw)
if !ip.Equal(tc.expected) {
t.Errorf("case[%s]: expected %+v, got %+v with err : %v", tc.tcase, tc.expected, ip, err)
}
if err != nil && !strings.Contains(err.Error(), tc.errStrFrag) {
t.Errorf("case[%s]: unable to find %q in error string %q", tc.tcase, tc.errStrFrag, err.Error())
}
}
}
func makeRouteFile(content string, t *testing.T) (*os.File, error) {
routeFile, err := ioutil.TempFile("", "route")
if err != nil {
return nil, err
}
if _, err := routeFile.Write([]byte(content)); err != nil {
return routeFile, err
}
err = routeFile.Close()
return routeFile, err
}
func TestFailGettingIPv4Routes(t *testing.T) {
defer func() { v4File.name = ipv4RouteFile }()
// Try failure to open file (should not occur, as caller ensures we have IPv4 route file, but being thorough)
v4File.name = "no-such-file"
errStrFrag := "no such file"
_, err := v4File.extract()
if err == nil {
t.Errorf("Expected error trying to read non-existent v4 route file")
}
if !strings.Contains(err.Error(), errStrFrag) {
t.Errorf("Unable to find %q in error string %q", errStrFrag, err.Error())
}
}
func TestFailGettingIPv6Routes(t *testing.T) {
defer func() { v6File.name = ipv6RouteFile }()
// Try failure to open file (this would be ignored by caller)
v6File.name = "no-such-file"
errStrFrag := "no such file"
_, err := v6File.extract()
if err == nil {
t.Errorf("Expected error trying to read non-existent v6 route file")
}
if !strings.Contains(err.Error(), errStrFrag) {
t.Errorf("Unable to find %q in error string %q", errStrFrag, err.Error())
}
}
func TestGetAllDefaultRoutesFailNoV4RouteFile(t *testing.T) {
defer func() { v4File.name = ipv4RouteFile }()
// Should not occur, as caller ensures we have IPv4 route file, but being thorough
v4File.name = "no-such-file"
errStrFrag := "no such file"
_, err := getAllDefaultRoutes()
if err == nil {
t.Errorf("Expected error trying to read non-existent v4 route file")
}
if !strings.Contains(err.Error(), errStrFrag) {
t.Errorf("Unable to find %q in error string %q", errStrFrag, err.Error())
}
}
func TestGetAllDefaultRoutes(t *testing.T) {
testCases := []struct {
tcase string
v4Info string
v6Info string
count int
expected []Route
errStrFrag string
}{
{"no routes", noInternetConnection, v6noDefaultRoutes, 0, nil, "no default routes"},
{"only v4 route", gatewayfirst, v6noDefaultRoutes, 1, routeV4, ""},
{"only v6 route", noInternetConnection, v6gatewayfirst, 1, routeV6, ""},
{"v4 and v6 routes", gatewayfirst, v6gatewayfirst, 2, bothRoutes, ""},
}
defer func() {
v4File.name = ipv4RouteFile
v6File.name = ipv6RouteFile
}()
for _, tc := range testCases {
routeFile, err := makeRouteFile(tc.v4Info, t)
if routeFile != nil {
defer os.Remove(routeFile.Name())
}
if err != nil {
t.Errorf("case[%s]: test setup failure for IPv4 route file: %v", tc.tcase, err)
}
v4File.name = routeFile.Name()
v6routeFile, err := makeRouteFile(tc.v6Info, t)
if v6routeFile != nil {
defer os.Remove(v6routeFile.Name())
}
if err != nil {
t.Errorf("case[%s]: test setup failure for IPv6 route file: %v", tc.tcase, err)
}
v6File.name = v6routeFile.Name()
routes, err := getAllDefaultRoutes()
if err != nil {
if !strings.Contains(err.Error(), tc.errStrFrag) {
t.Errorf("case[%s]: Error string %q does not contain %q", tc.tcase, err, tc.errStrFrag)
}
} else if tc.errStrFrag != "" {
t.Errorf("case[%s]: Error %q expected, but not seen", tc.tcase, tc.errStrFrag)
} else {
if tc.count != len(routes) {
t.Errorf("case[%s]: expected %d routes, have %v", tc.tcase, tc.count, routes)
}
for i, expected := range tc.expected {
if !expected.Gateway.Equal(routes[i].Gateway) {
t.Errorf("case[%s]: at %d expected %v, got %v .err : %v", tc.tcase, i, tc.expected, routes, err)
}
zeroIP := net.IPv4zero
if expected.Family == familyIPv6 {
zeroIP = net.IPv6zero
}
if !routes[i].Destination.Equal(zeroIP) {
t.Errorf("case[%s}: at %d destination is not for default route (not %v)", tc.tcase, i, zeroIP)
}
}
}
}
}

View File

@ -1,77 +0,0 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package net
import (
"testing"
flag "github.com/spf13/pflag"
)
func TestPortRange(t *testing.T) {
testCases := []struct {
input string
success bool
expected string
included int
excluded int
}{
{"100-200", true, "100-200", 200, 201},
{" 100-200 ", true, "100-200", 200, 201},
{"0-0", true, "0-0", 0, 1},
{"", true, "", -1, 0},
{"100", true, "100-100", 100, 101},
{"100 - 200", false, "", -1, -1},
{"-100", false, "", -1, -1},
{"100-", false, "", -1, -1},
{"200-100", false, "", -1, -1},
{"60000-70000", false, "", -1, -1},
{"70000-80000", false, "", -1, -1},
{"70000+80000", false, "", -1, -1},
{"1+0", true, "1-1", 1, 2},
{"0+0", true, "0-0", 0, 1},
{"1+-1", false, "", -1, -1},
{"1-+1", false, "", -1, -1},
{"100+200", true, "100-300", 300, 301},
{"1+65535", false, "", -1, -1},
{"0+65535", true, "0-65535", 65535, 65536},
}
for i := range testCases {
tc := &testCases[i]
pr := &PortRange{}
var f flag.Value = pr
err := f.Set(tc.input)
if err != nil && tc.success == true {
t.Errorf("expected success, got %q", err)
continue
} else if err == nil && tc.success == false {
t.Errorf("expected failure %#v", testCases[i])
continue
} else if tc.success {
if f.String() != tc.expected {
t.Errorf("expected %q, got %q", tc.expected, f.String())
}
if tc.included >= 0 && !pr.Contains(tc.included) {
t.Errorf("expected %q to include %d", f.String(), tc.included)
}
if tc.excluded >= 0 && pr.Contains(tc.excluded) {
t.Errorf("expected %q to exclude %d", f.String(), tc.excluded)
}
}
}
}

View File

@ -1,121 +0,0 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package net
import (
"testing"
)
func TestSplitSchemeNamePort(t *testing.T) {
table := []struct {
in string
name, port, scheme string
valid bool
normalized bool
}{
{
in: "aoeu:asdf",
name: "aoeu",
port: "asdf",
valid: true,
normalized: true,
}, {
in: "http:aoeu:asdf",
scheme: "http",
name: "aoeu",
port: "asdf",
valid: true,
normalized: true,
}, {
in: "https:aoeu:",
scheme: "https",
name: "aoeu",
port: "",
valid: true,
normalized: false,
}, {
in: "https:aoeu:asdf",
scheme: "https",
name: "aoeu",
port: "asdf",
valid: true,
normalized: true,
}, {
in: "aoeu:",
name: "aoeu",
valid: true,
normalized: false,
}, {
in: "aoeu",
name: "aoeu",
valid: true,
normalized: true,
}, {
in: ":asdf",
valid: false,
}, {
in: "aoeu:asdf:htns",
valid: false,
}, {
in: "http::asdf",
valid: false,
}, {
in: "http::",
valid: false,
}, {
in: "",
valid: false,
},
}
for _, item := range table {
scheme, name, port, valid := SplitSchemeNamePort(item.in)
if e, a := item.scheme, scheme; e != a {
t.Errorf("%q: Wanted %q, got %q", item.in, e, a)
}
if e, a := item.name, name; e != a {
t.Errorf("%q: Wanted %q, got %q", item.in, e, a)
}
if e, a := item.port, port; e != a {
t.Errorf("%q: Wanted %q, got %q", item.in, e, a)
}
if e, a := item.valid, valid; e != a {
t.Errorf("%q: Wanted %t, got %t", item.in, e, a)
}
// Make sure valid items round trip through JoinSchemeNamePort
if item.valid {
out := JoinSchemeNamePort(scheme, name, port)
if item.normalized && out != item.in {
t.Errorf("%q: Wanted %s, got %s", item.in, item.in, out)
}
scheme, name, port, valid := SplitSchemeNamePort(out)
if e, a := item.scheme, scheme; e != a {
t.Errorf("%q: Wanted %q, got %q", item.in, e, a)
}
if e, a := item.name, name; e != a {
t.Errorf("%q: Wanted %q, got %q", item.in, e, a)
}
if e, a := item.port, port; e != a {
t.Errorf("%q: Wanted %q, got %q", item.in, e, a)
}
if e, a := item.valid, valid; e != a {
t.Errorf("%q: Wanted %t, got %t", item.in, e, a)
}
}
}
}

View File

@ -1,68 +0,0 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package net
import (
"net"
"testing"
)
func getIPNet(cidr string) *net.IPNet {
_, ipnet, _ := net.ParseCIDR(cidr)
return ipnet
}
func TestIPNetEqual(t *testing.T) {
testCases := []struct {
ipnet1 *net.IPNet
ipnet2 *net.IPNet
expect bool
}{
//null case
{
getIPNet("10.0.0.1/24"),
getIPNet(""),
false,
},
{
getIPNet("10.0.0.0/24"),
getIPNet("10.0.0.0/24"),
true,
},
{
getIPNet("10.0.0.0/24"),
getIPNet("10.0.0.1/24"),
true,
},
{
getIPNet("10.0.0.0/25"),
getIPNet("10.0.0.0/24"),
false,
},
{
getIPNet("10.0.1.0/24"),
getIPNet("10.0.0.0/24"),
false,
},
}
for _, tc := range testCases {
if tc.expect != IPNetEqual(tc.ipnet1, tc.ipnet2) {
t.Errorf("Expect equality of %s and %s be to %v", tc.ipnet1.String(), tc.ipnet2.String(), tc.expect)
}
}
}

View File

@ -1,117 +0,0 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package proxy
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/url"
"k8s.io/klog"
utilnet "k8s.io/apimachinery/pkg/util/net"
"k8s.io/apimachinery/third_party/forked/golang/netutil"
)
func DialURL(ctx context.Context, url *url.URL, transport http.RoundTripper) (net.Conn, error) {
dialAddr := netutil.CanonicalAddr(url)
dialer, err := utilnet.DialerFor(transport)
if err != nil {
klog.V(5).Infof("Unable to unwrap transport %T to get dialer: %v", transport, err)
}
switch url.Scheme {
case "http":
if dialer != nil {
return dialer(ctx, "tcp", dialAddr)
}
var d net.Dialer
return d.DialContext(ctx, "tcp", dialAddr)
case "https":
// Get the tls config from the transport if we recognize it
var tlsConfig *tls.Config
var tlsConn *tls.Conn
var err error
tlsConfig, err = utilnet.TLSClientConfig(transport)
if err != nil {
klog.V(5).Infof("Unable to unwrap transport %T to get at TLS config: %v", transport, err)
}
if dialer != nil {
// We have a dialer; use it to open the connection, then
// create a tls client using the connection.
netConn, err := dialer(ctx, "tcp", dialAddr)
if err != nil {
return nil, err
}
if tlsConfig == nil {
// tls.Client requires non-nil config
klog.Warningf("using custom dialer with no TLSClientConfig. Defaulting to InsecureSkipVerify")
// tls.Handshake() requires ServerName or InsecureSkipVerify
tlsConfig = &tls.Config{
InsecureSkipVerify: true,
}
} else if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify {
// tls.Handshake() requires ServerName or InsecureSkipVerify
// infer the ServerName from the hostname we're connecting to.
inferredHost := dialAddr
if host, _, err := net.SplitHostPort(dialAddr); err == nil {
inferredHost = host
}
// Make a copy to avoid polluting the provided config
tlsConfigCopy := tlsConfig.Clone()
tlsConfigCopy.ServerName = inferredHost
tlsConfig = tlsConfigCopy
}
tlsConn = tls.Client(netConn, tlsConfig)
if err := tlsConn.Handshake(); err != nil {
netConn.Close()
return nil, err
}
} else {
// Dial. This Dial method does not allow to pass a context unfortunately
tlsConn, err = tls.Dial("tcp", dialAddr, tlsConfig)
if err != nil {
return nil, err
}
}
// Return if we were configured to skip validation
if tlsConfig != nil && tlsConfig.InsecureSkipVerify {
return tlsConn, nil
}
// Verify
host, _, _ := net.SplitHostPort(dialAddr)
if tlsConfig != nil && len(tlsConfig.ServerName) > 0 {
host = tlsConfig.ServerName
}
if err := tlsConn.VerifyHostname(host); err != nil {
tlsConn.Close()
return nil, err
}
return tlsConn, nil
default:
return nil, fmt.Errorf("Unknown scheme: %s", url.Scheme)
}
}

View File

@ -1,178 +0,0 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package proxy
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strings"
"testing"
"k8s.io/apimachinery/pkg/util/diff"
utilnet "k8s.io/apimachinery/pkg/util/net"
)
func TestDialURL(t *testing.T) {
roots := x509.NewCertPool()
if !roots.AppendCertsFromPEM(localhostCert) {
t.Fatal("error setting up localhostCert pool")
}
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
if err != nil {
t.Fatal(err)
}
var d net.Dialer
testcases := map[string]struct {
TLSConfig *tls.Config
Dial utilnet.DialFunc
ExpectError string
}{
"insecure": {
TLSConfig: &tls.Config{InsecureSkipVerify: true},
},
"secure, no roots": {
TLSConfig: &tls.Config{InsecureSkipVerify: false},
ExpectError: "unknown authority",
},
"secure with roots": {
TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots},
},
"secure with mismatched server": {
TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots, ServerName: "bogus.com"},
ExpectError: "not bogus.com",
},
"secure with matched server": {
TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots, ServerName: "example.com"},
},
"insecure, custom dial": {
TLSConfig: &tls.Config{InsecureSkipVerify: true},
Dial: d.DialContext,
},
"secure, no roots, custom dial": {
TLSConfig: &tls.Config{InsecureSkipVerify: false},
Dial: d.DialContext,
ExpectError: "unknown authority",
},
"secure with roots, custom dial": {
TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots},
Dial: d.DialContext,
},
"secure with mismatched server, custom dial": {
TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots, ServerName: "bogus.com"},
Dial: d.DialContext,
ExpectError: "not bogus.com",
},
"secure with matched server, custom dial": {
TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots, ServerName: "example.com"},
Dial: d.DialContext,
},
}
for k, tc := range testcases {
func() {
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {}))
defer ts.Close()
ts.TLS = &tls.Config{Certificates: []tls.Certificate{cert}}
ts.StartTLS()
// Make a copy of the config
tlsConfigCopy := tc.TLSConfig.Clone()
// Clone() mutates the receiver (!), so also call it on the copy
tlsConfigCopy.Clone()
transport := &http.Transport{
DialContext: tc.Dial,
TLSClientConfig: tlsConfigCopy,
}
extractedDial, err := utilnet.DialerFor(transport)
if err != nil {
t.Fatal(err)
}
if fmt.Sprintf("%p", extractedDial) != fmt.Sprintf("%p", tc.Dial) {
t.Fatalf("%s: Unexpected dial", k)
}
extractedTLSConfig, err := utilnet.TLSClientConfig(transport)
if err != nil {
t.Fatal(err)
}
if extractedTLSConfig == nil {
t.Fatalf("%s: Expected tlsConfig", k)
}
u, _ := url.Parse(ts.URL)
_, p, _ := net.SplitHostPort(u.Host)
u.Host = net.JoinHostPort("127.0.0.1", p)
conn, err := DialURL(context.Background(), u, transport)
// Make sure dialing doesn't mutate the transport's TLSConfig
if !reflect.DeepEqual(tc.TLSConfig, tlsConfigCopy) {
t.Errorf("%s: transport's copy of TLSConfig was mutated\n%s", k, diff.ObjectReflectDiff(tc.TLSConfig, tlsConfigCopy))
}
if err != nil {
if tc.ExpectError == "" {
t.Errorf("%s: expected no error, got %q", k, err.Error())
}
if !strings.Contains(err.Error(), tc.ExpectError) {
t.Errorf("%s: expected error containing %q, got %q", k, tc.ExpectError, err.Error())
}
return
}
conn.Close()
if tc.ExpectError != "" {
t.Errorf("%s: expected error %q, got none", k, tc.ExpectError)
}
}()
}
}
// localhostCert was generated from crypto/tls/generate_cert.go with the following command:
// go run generate_cert.go --rsa-bits 512 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
var localhostCert = []byte(`-----BEGIN CERTIFICATE-----
MIIBjzCCATmgAwIBAgIRAKpi2WmTcFrVjxrl5n5YDUEwDQYJKoZIhvcNAQELBQAw
EjEQMA4GA1UEChMHQWNtZSBDbzAgFw03MDAxMDEwMDAwMDBaGA8yMDg0MDEyOTE2
MDAwMFowEjEQMA4GA1UEChMHQWNtZSBDbzBcMA0GCSqGSIb3DQEBAQUAA0sAMEgC
QQC9fEbRszP3t14Gr4oahV7zFObBI4TfA5i7YnlMXeLinb7MnvT4bkfOJzE6zktn
59zP7UiHs3l4YOuqrjiwM413AgMBAAGjaDBmMA4GA1UdDwEB/wQEAwICpDATBgNV
HSUEDDAKBggrBgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MC4GA1UdEQQnMCWCC2V4
YW1wbGUuY29thwR/AAABhxAAAAAAAAAAAAAAAAAAAAABMA0GCSqGSIb3DQEBCwUA
A0EAUsVE6KMnza/ZbodLlyeMzdo7EM/5nb5ywyOxgIOCf0OOLHsPS9ueGLQX9HEG
//yjTXuhNcUugExIjM/AIwAZPQ==
-----END CERTIFICATE-----`)
// localhostKey is the private key for localhostCert.
var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIBOwIBAAJBAL18RtGzM/e3XgavihqFXvMU5sEjhN8DmLtieUxd4uKdvsye9Phu
R84nMTrOS2fn3M/tSIezeXhg66quOLAzjXcCAwEAAQJBAKcRxH9wuglYLBdI/0OT
BLzfWPZCEw1vZmMR2FF1Fm8nkNOVDPleeVGTWoOEcYYlQbpTmkGSxJ6ya+hqRi6x
goECIQDx3+X49fwpL6B5qpJIJMyZBSCuMhH4B7JevhGGFENi3wIhAMiNJN5Q3UkL
IuSvv03kaPR5XVQ99/UeEetUgGvBcABpAiBJSBzVITIVCGkGc7d+RCf49KTCIklv
bGWObufAR8Ni4QIgWpILjW8dkGg8GOUZ0zaNA6Nvt6TIv2UWGJ4v5PoV98kCIQDx
rIiZs5QbKdycsv9gQJzwQAogC8o04X3Zz3dsoX+h4A==
-----END RSA PRIVATE KEY-----`)

View File

@ -1,18 +0,0 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package proxy provides transport and upgrade support for proxies.
package proxy // import "k8s.io/apimachinery/pkg/util/proxy"

View File

@ -1,259 +0,0 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package proxy
import (
"bytes"
"compress/gzip"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"path"
"strings"
"golang.org/x/net/html"
"golang.org/x/net/html/atom"
"k8s.io/klog"
"k8s.io/apimachinery/pkg/util/net"
"k8s.io/apimachinery/pkg/util/sets"
)
// atomsToAttrs states which attributes of which tags require URL substitution.
// Sources: http://www.w3.org/TR/REC-html40/index/attributes.html
// http://www.w3.org/html/wg/drafts/html/master/index.html#attributes-1
var atomsToAttrs = map[atom.Atom]sets.String{
atom.A: sets.NewString("href"),
atom.Applet: sets.NewString("codebase"),
atom.Area: sets.NewString("href"),
atom.Audio: sets.NewString("src"),
atom.Base: sets.NewString("href"),
atom.Blockquote: sets.NewString("cite"),
atom.Body: sets.NewString("background"),
atom.Button: sets.NewString("formaction"),
atom.Command: sets.NewString("icon"),
atom.Del: sets.NewString("cite"),
atom.Embed: sets.NewString("src"),
atom.Form: sets.NewString("action"),
atom.Frame: sets.NewString("longdesc", "src"),
atom.Head: sets.NewString("profile"),
atom.Html: sets.NewString("manifest"),
atom.Iframe: sets.NewString("longdesc", "src"),
atom.Img: sets.NewString("longdesc", "src", "usemap"),
atom.Input: sets.NewString("src", "usemap", "formaction"),
atom.Ins: sets.NewString("cite"),
atom.Link: sets.NewString("href"),
atom.Object: sets.NewString("classid", "codebase", "data", "usemap"),
atom.Q: sets.NewString("cite"),
atom.Script: sets.NewString("src"),
atom.Source: sets.NewString("src"),
atom.Video: sets.NewString("poster", "src"),
// TODO: css URLs hidden in style elements.
}
// Transport is a transport for text/html content that replaces URLs in html
// content with the prefix of the proxy server
type Transport struct {
Scheme string
Host string
PathPrepend string
http.RoundTripper
}
// RoundTrip implements the http.RoundTripper interface
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
// Add reverse proxy headers.
forwardedURI := path.Join(t.PathPrepend, req.URL.Path)
if strings.HasSuffix(req.URL.Path, "/") {
forwardedURI = forwardedURI + "/"
}
req.Header.Set("X-Forwarded-Uri", forwardedURI)
if len(t.Host) > 0 {
req.Header.Set("X-Forwarded-Host", t.Host)
}
if len(t.Scheme) > 0 {
req.Header.Set("X-Forwarded-Proto", t.Scheme)
}
rt := t.RoundTripper
if rt == nil {
rt = http.DefaultTransport
}
resp, err := rt.RoundTrip(req)
if err != nil {
message := fmt.Sprintf("Error: '%s'\nTrying to reach: '%v'", err.Error(), req.URL.String())
resp = &http.Response{
StatusCode: http.StatusServiceUnavailable,
Body: ioutil.NopCloser(strings.NewReader(message)),
}
return resp, nil
}
if redirect := resp.Header.Get("Location"); redirect != "" {
resp.Header.Set("Location", t.rewriteURL(redirect, req.URL, req.Host))
return resp, nil
}
cType := resp.Header.Get("Content-Type")
cType = strings.TrimSpace(strings.SplitN(cType, ";", 2)[0])
if cType != "text/html" {
// Do nothing, simply pass through
return resp, nil
}
return t.rewriteResponse(req, resp)
}
var _ = net.RoundTripperWrapper(&Transport{})
func (rt *Transport) WrappedRoundTripper() http.RoundTripper {
return rt.RoundTripper
}
// rewriteURL rewrites a single URL to go through the proxy, if the URL refers
// to the same host as sourceURL, which is the page on which the target URL
// occurred, or if the URL matches the sourceRequestHost. If any error occurs (e.g.
// parsing), it returns targetURL.
func (t *Transport) rewriteURL(targetURL string, sourceURL *url.URL, sourceRequestHost string) string {
url, err := url.Parse(targetURL)
if err != nil {
return targetURL
}
// Example:
// When API server processes a proxy request to a service (e.g. /api/v1/namespace/foo/service/bar/proxy/),
// the sourceURL.Host (i.e. req.URL.Host) is the endpoint IP address of the service. The
// sourceRequestHost (i.e. req.Host) is the Host header that specifies the host on which the
// URL is sought, which can be different from sourceURL.Host. For example, if user sends the
// request through "kubectl proxy" locally (i.e. localhost:8001/api/v1/namespace/foo/service/bar/proxy/),
// sourceRequestHost is "localhost:8001".
//
// If the service's response URL contains non-empty host, and url.Host is equal to either sourceURL.Host
// or sourceRequestHost, we should not consider the returned URL to be a completely different host.
// It's the API server's responsibility to rewrite a same-host-and-absolute-path URL and append the
// necessary URL prefix (i.e. /api/v1/namespace/foo/service/bar/proxy/).
isDifferentHost := url.Host != "" && url.Host != sourceURL.Host && url.Host != sourceRequestHost
isRelative := !strings.HasPrefix(url.Path, "/")
if isDifferentHost || isRelative {
return targetURL
}
// Do not rewrite scheme and host if the Transport has empty scheme and host
// when targetURL already contains the sourceRequestHost
if !(url.Host == sourceRequestHost && t.Scheme == "" && t.Host == "") {
url.Scheme = t.Scheme
url.Host = t.Host
}
origPath := url.Path
// Do not rewrite URL if the sourceURL already contains the necessary prefix.
if strings.HasPrefix(url.Path, t.PathPrepend) {
return url.String()
}
url.Path = path.Join(t.PathPrepend, url.Path)
if strings.HasSuffix(origPath, "/") {
// Add back the trailing slash, which was stripped by path.Join().
url.Path += "/"
}
return url.String()
}
// rewriteHTML scans the HTML for tags with url-valued attributes, and updates
// those values with the urlRewriter function. The updated HTML is output to the
// writer.
func rewriteHTML(reader io.Reader, writer io.Writer, urlRewriter func(string) string) error {
// Note: This assumes the content is UTF-8.
tokenizer := html.NewTokenizer(reader)
var err error
for err == nil {
tokenType := tokenizer.Next()
switch tokenType {
case html.ErrorToken:
err = tokenizer.Err()
case html.StartTagToken, html.SelfClosingTagToken:
token := tokenizer.Token()
if urlAttrs, ok := atomsToAttrs[token.DataAtom]; ok {
for i, attr := range token.Attr {
if urlAttrs.Has(attr.Key) {
token.Attr[i].Val = urlRewriter(attr.Val)
}
}
}
_, err = writer.Write([]byte(token.String()))
default:
_, err = writer.Write(tokenizer.Raw())
}
}
if err != io.EOF {
return err
}
return nil
}
// rewriteResponse modifies an HTML response by updating absolute links referring
// to the original host to instead refer to the proxy transport.
func (t *Transport) rewriteResponse(req *http.Request, resp *http.Response) (*http.Response, error) {
origBody := resp.Body
defer origBody.Close()
newContent := &bytes.Buffer{}
var reader io.Reader = origBody
var writer io.Writer = newContent
encoding := resp.Header.Get("Content-Encoding")
switch encoding {
case "gzip":
var err error
reader, err = gzip.NewReader(reader)
if err != nil {
return nil, fmt.Errorf("errorf making gzip reader: %v", err)
}
gzw := gzip.NewWriter(writer)
defer gzw.Close()
writer = gzw
// TODO: support flate, other encodings.
case "":
// This is fine
default:
// Some encoding we don't understand-- don't try to parse this
klog.Errorf("Proxy encountered encoding %v for text/html; can't understand this so not fixing links.", encoding)
return resp, nil
}
urlRewriter := func(targetUrl string) string {
return t.rewriteURL(targetUrl, req.URL, req.Host)
}
err := rewriteHTML(reader, writer, urlRewriter)
if err != nil {
klog.Errorf("Failed to rewrite URLs: %v", err)
return resp, err
}
resp.Body = ioutil.NopCloser(newContent)
// Update header node with new content-length
// TODO: Remove any hash/signature headers here?
resp.Header.Del("Content-Length")
resp.ContentLength = int64(newContent.Len())
return resp, err
}

View File

@ -1,276 +0,0 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package proxy
import (
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
func parseURLOrDie(inURL string) *url.URL {
parsed, err := url.Parse(inURL)
if err != nil {
panic(err)
}
return parsed
}
func TestProxyTransport(t *testing.T) {
testTransport := &Transport{
Scheme: "http",
Host: "foo.com",
PathPrepend: "/proxy/node/node1:10250",
}
testTransport2 := &Transport{
Scheme: "https",
Host: "foo.com",
PathPrepend: "/proxy/node/node1:8080",
}
emptyHostTransport := &Transport{
Scheme: "https",
PathPrepend: "/proxy/node/node1:10250",
}
emptySchemeTransport := &Transport{
Host: "foo.com",
PathPrepend: "/proxy/node/node1:10250",
}
emptyHostAndSchemeTransport := &Transport{
PathPrepend: "/proxy/node/node1:10250",
}
type Item struct {
input string
sourceURL string
transport *Transport
output string
contentType string
forwardedURI string
redirect string
redirectWant string
reqHost string
}
table := map[string]Item{
"normal": {
input: `<pre><a href="kubelet.log">kubelet.log</a><a href="/google.log">google.log</a></pre>`,
sourceURL: "http://mynode.com/logs/log.log",
transport: testTransport,
output: `<pre><a href="kubelet.log">kubelet.log</a><a href="http://foo.com/proxy/node/node1:10250/google.log">google.log</a></pre>`,
contentType: "text/html",
forwardedURI: "/proxy/node/node1:10250/logs/log.log",
},
"full document": {
input: `<html><header></header><body><pre><a href="kubelet.log">kubelet.log</a><a href="/google.log">google.log</a></pre></body></html>`,
sourceURL: "http://mynode.com/logs/log.log",
transport: testTransport,
output: `<html><header></header><body><pre><a href="kubelet.log">kubelet.log</a><a href="http://foo.com/proxy/node/node1:10250/google.log">google.log</a></pre></body></html>`,
contentType: "text/html",
forwardedURI: "/proxy/node/node1:10250/logs/log.log",
},
"trailing slash": {
input: `<pre><a href="kubelet.log">kubelet.log</a><a href="/google.log/">google.log</a></pre>`,
sourceURL: "http://mynode.com/logs/log.log",
transport: testTransport,
output: `<pre><a href="kubelet.log">kubelet.log</a><a href="http://foo.com/proxy/node/node1:10250/google.log/">google.log</a></pre>`,
contentType: "text/html",
forwardedURI: "/proxy/node/node1:10250/logs/log.log",
},
"content-type charset": {
input: `<pre><a href="kubelet.log">kubelet.log</a><a href="/google.log">google.log</a></pre>`,
sourceURL: "http://mynode.com/logs/log.log",
transport: testTransport,
output: `<pre><a href="kubelet.log">kubelet.log</a><a href="http://foo.com/proxy/node/node1:10250/google.log">google.log</a></pre>`,
contentType: "text/html; charset=utf-8",
forwardedURI: "/proxy/node/node1:10250/logs/log.log",
},
"content-type passthrough": {
input: `<pre><a href="kubelet.log">kubelet.log</a><a href="/google.log">google.log</a></pre>`,
sourceURL: "http://mynode.com/logs/log.log",
transport: testTransport,
output: `<pre><a href="kubelet.log">kubelet.log</a><a href="/google.log">google.log</a></pre>`,
contentType: "text/plain",
forwardedURI: "/proxy/node/node1:10250/logs/log.log",
},
"subdir": {
input: `<a href="kubelet.log">kubelet.log</a><a href="/google.log">google.log</a>`,
sourceURL: "http://mynode.com/whatever/apt/somelog.log",
transport: testTransport2,
output: `<a href="kubelet.log">kubelet.log</a><a href="https://foo.com/proxy/node/node1:8080/google.log">google.log</a>`,
contentType: "text/html",
forwardedURI: "/proxy/node/node1:8080/whatever/apt/somelog.log",
},
"image": {
input: `<pre><img src="kubernetes.jpg"/><img src="/kubernetes_abs.jpg"/></pre>`,
sourceURL: "http://mynode.com/",
transport: testTransport,
output: `<pre><img src="kubernetes.jpg"/><img src="http://foo.com/proxy/node/node1:10250/kubernetes_abs.jpg"/></pre>`,
contentType: "text/html",
forwardedURI: "/proxy/node/node1:10250/",
},
"abs": {
input: `<script src="http://google.com/kubernetes.js"/>`,
sourceURL: "http://mynode.com/any/path/",
transport: testTransport,
output: `<script src="http://google.com/kubernetes.js"/>`,
contentType: "text/html",
forwardedURI: "/proxy/node/node1:10250/any/path/",
},
"abs but same host": {
input: `<script src="http://mynode.com/kubernetes.js"/>`,
sourceURL: "http://mynode.com/any/path/",
transport: testTransport,
output: `<script src="http://foo.com/proxy/node/node1:10250/kubernetes.js"/>`,
contentType: "text/html",
forwardedURI: "/proxy/node/node1:10250/any/path/",
},
"redirect rel": {
sourceURL: "http://mynode.com/redirect",
transport: testTransport,
redirect: "/redirected/target/",
redirectWant: "http://foo.com/proxy/node/node1:10250/redirected/target/",
forwardedURI: "/proxy/node/node1:10250/redirect",
},
"redirect abs same host": {
sourceURL: "http://mynode.com/redirect",
transport: testTransport,
redirect: "http://mynode.com/redirected/target/",
redirectWant: "http://foo.com/proxy/node/node1:10250/redirected/target/",
forwardedURI: "/proxy/node/node1:10250/redirect",
},
"redirect abs other host": {
sourceURL: "http://mynode.com/redirect",
transport: testTransport,
redirect: "http://example.com/redirected/target/",
redirectWant: "http://example.com/redirected/target/",
forwardedURI: "/proxy/node/node1:10250/redirect",
},
"redirect abs use reqHost no host no scheme": {
sourceURL: "http://mynode.com/redirect",
transport: emptyHostAndSchemeTransport,
redirect: "http://10.0.0.1:8001/redirected/target/",
redirectWant: "http://10.0.0.1:8001/proxy/node/node1:10250/redirected/target/",
forwardedURI: "/proxy/node/node1:10250/redirect",
reqHost: "10.0.0.1:8001",
},
"source contains the redirect already": {
input: `<pre><a href="kubelet.log">kubelet.log</a><a href="http://foo.com/proxy/node/node1:10250/google.log">google.log</a></pre>`,
sourceURL: "http://foo.com/logs/log.log",
transport: testTransport,
output: `<pre><a href="kubelet.log">kubelet.log</a><a href="http://foo.com/proxy/node/node1:10250/google.log">google.log</a></pre>`,
contentType: "text/html",
forwardedURI: "/proxy/node/node1:10250/logs/log.log",
},
"no host": {
input: "<html></html>",
sourceURL: "http://mynode.com/logs/log.log",
transport: emptyHostTransport,
output: "<html></html>",
contentType: "text/html",
forwardedURI: "/proxy/node/node1:10250/logs/log.log",
},
"no scheme": {
input: "<html></html>",
sourceURL: "http://mynode.com/logs/log.log",
transport: emptySchemeTransport,
output: "<html></html>",
contentType: "text/html",
forwardedURI: "/proxy/node/node1:10250/logs/log.log",
},
}
testItem := func(name string, item *Item) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check request headers.
if got, want := r.Header.Get("X-Forwarded-Uri"), item.forwardedURI; got != want {
t.Errorf("%v: X-Forwarded-Uri = %q, want %q", name, got, want)
}
if len(item.transport.Host) == 0 {
_, present := r.Header["X-Forwarded-Host"]
if present {
t.Errorf("%v: X-Forwarded-Host header should not be present", name)
}
} else {
if got, want := r.Header.Get("X-Forwarded-Host"), item.transport.Host; got != want {
t.Errorf("%v: X-Forwarded-Host = %q, want %q", name, got, want)
}
}
if len(item.transport.Scheme) == 0 {
_, present := r.Header["X-Forwarded-Proto"]
if present {
t.Errorf("%v: X-Forwarded-Proto header should not be present", name)
}
} else {
if got, want := r.Header.Get("X-Forwarded-Proto"), item.transport.Scheme; got != want {
t.Errorf("%v: X-Forwarded-Proto = %q, want %q", name, got, want)
}
}
// Send response.
if item.redirect != "" {
http.Redirect(w, r, item.redirect, http.StatusMovedPermanently)
return
}
w.Header().Set("Content-Type", item.contentType)
fmt.Fprint(w, item.input)
}))
defer server.Close()
// Replace source URL with our test server address.
sourceURL := parseURLOrDie(item.sourceURL)
serverURL := parseURLOrDie(server.URL)
item.input = strings.Replace(item.input, sourceURL.Host, serverURL.Host, -1)
item.redirect = strings.Replace(item.redirect, sourceURL.Host, serverURL.Host, -1)
sourceURL.Host = serverURL.Host
req, err := http.NewRequest("GET", sourceURL.String(), nil)
if err != nil {
t.Errorf("%v: Unexpected error: %v", name, err)
return
}
if item.reqHost != "" {
req.Host = item.reqHost
}
resp, err := item.transport.RoundTrip(req)
if err != nil {
t.Errorf("%v: Unexpected error: %v", name, err)
return
}
if item.redirect != "" {
// Check that redirect URLs get rewritten properly.
if got, want := resp.Header.Get("Location"), item.redirectWant; got != want {
t.Errorf("%v: Location header = %q, want %q", name, got, want)
}
return
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Errorf("%v: Unexpected error: %v", name, err)
return
}
if e, a := item.output, string(body); e != a {
t.Errorf("%v: expected %v, but got %v", name, e, a)
}
}
for name, item := range table {
testItem(name, &item)
}
}

View File

@ -1,466 +0,0 @@
/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package proxy
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"time"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/util/httpstream"
utilnet "k8s.io/apimachinery/pkg/util/net"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
"github.com/mxk/go-flowrate/flowrate"
"k8s.io/klog"
)
// UpgradeRequestRoundTripper provides an additional method to decorate a request
// with any authentication or other protocol level information prior to performing
// an upgrade on the server. Any response will be handled by the intercepting
// proxy.
type UpgradeRequestRoundTripper interface {
http.RoundTripper
// WrapRequest takes a valid HTTP request and returns a suitably altered version
// of request with any HTTP level values required to complete the request half of
// an upgrade on the server. It does not get a chance to see the response and
// should bypass any request side logic that expects to see the response.
WrapRequest(*http.Request) (*http.Request, error)
}
// UpgradeAwareHandler is a handler for proxy requests that may require an upgrade
type UpgradeAwareHandler struct {
// UpgradeRequired will reject non-upgrade connections if true.
UpgradeRequired bool
// Location is the location of the upstream proxy. It is used as the location to Dial on the upstream server
// for upgrade requests unless UseRequestLocationOnUpgrade is true.
Location *url.URL
// Transport provides an optional round tripper to use to proxy. If nil, the default proxy transport is used
Transport http.RoundTripper
// UpgradeTransport, if specified, will be used as the backend transport when upgrade requests are provided.
// This allows clients to disable HTTP/2.
UpgradeTransport UpgradeRequestRoundTripper
// WrapTransport indicates whether the provided Transport should be wrapped with default proxy transport behavior (URL rewriting, X-Forwarded-* header setting)
WrapTransport bool
// InterceptRedirects determines whether the proxy should sniff backend responses for redirects,
// following them as necessary.
InterceptRedirects bool
// RequireSameHostRedirects only allows redirects to the same host. It is only used if InterceptRedirects=true.
RequireSameHostRedirects bool
// UseRequestLocation will use the incoming request URL when talking to the backend server.
UseRequestLocation bool
// FlushInterval controls how often the standard HTTP proxy will flush content from the upstream.
FlushInterval time.Duration
// MaxBytesPerSec controls the maximum rate for an upstream connection. No rate is imposed if the value is zero.
MaxBytesPerSec int64
// Responder is passed errors that occur while setting up proxying.
Responder ErrorResponder
}
const defaultFlushInterval = 200 * time.Millisecond
// ErrorResponder abstracts error reporting to the proxy handler to remove the need to hardcode a particular
// error format.
type ErrorResponder interface {
Error(w http.ResponseWriter, req *http.Request, err error)
}
// SimpleErrorResponder is the legacy implementation of ErrorResponder for callers that only
// service a single request/response per proxy.
type SimpleErrorResponder interface {
Error(err error)
}
func NewErrorResponder(r SimpleErrorResponder) ErrorResponder {
return simpleResponder{r}
}
type simpleResponder struct {
responder SimpleErrorResponder
}
func (r simpleResponder) Error(w http.ResponseWriter, req *http.Request, err error) {
r.responder.Error(err)
}
// upgradeRequestRoundTripper implements proxy.UpgradeRequestRoundTripper.
type upgradeRequestRoundTripper struct {
http.RoundTripper
upgrader http.RoundTripper
}
var (
_ UpgradeRequestRoundTripper = &upgradeRequestRoundTripper{}
_ utilnet.RoundTripperWrapper = &upgradeRequestRoundTripper{}
)
// WrappedRoundTripper returns the round tripper that a caller would use.
func (rt *upgradeRequestRoundTripper) WrappedRoundTripper() http.RoundTripper {
return rt.RoundTripper
}
// WriteToRequest calls the nested upgrader and then copies the returned request
// fields onto the passed request.
func (rt *upgradeRequestRoundTripper) WrapRequest(req *http.Request) (*http.Request, error) {
resp, err := rt.upgrader.RoundTrip(req)
if err != nil {
return nil, err
}
return resp.Request, nil
}
// onewayRoundTripper captures the provided request - which is assumed to have
// been modified by other round trippers - and then returns a fake response.
type onewayRoundTripper struct{}
// RoundTrip returns a simple 200 OK response that captures the provided request.
func (onewayRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return &http.Response{
Status: "200 OK",
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(&bytes.Buffer{}),
Request: req,
}, nil
}
// MirrorRequest is a round tripper that can be called to get back the calling request as
// the core round tripper in a chain.
var MirrorRequest http.RoundTripper = onewayRoundTripper{}
// NewUpgradeRequestRoundTripper takes two round trippers - one for the underlying TCP connection, and
// one that is able to write headers to an HTTP request. The request rt is used to set the request headers
// and that is written to the underlying connection rt.
func NewUpgradeRequestRoundTripper(connection, request http.RoundTripper) UpgradeRequestRoundTripper {
return &upgradeRequestRoundTripper{
RoundTripper: connection,
upgrader: request,
}
}
// normalizeLocation returns the result of parsing the full URL, with scheme set to http if missing
func normalizeLocation(location *url.URL) *url.URL {
normalized, _ := url.Parse(location.String())
if len(normalized.Scheme) == 0 {
normalized.Scheme = "http"
}
return normalized
}
// NewUpgradeAwareHandler creates a new proxy handler with a default flush interval. Responder is required for returning
// errors to the caller.
func NewUpgradeAwareHandler(location *url.URL, transport http.RoundTripper, wrapTransport, upgradeRequired bool, responder ErrorResponder) *UpgradeAwareHandler {
return &UpgradeAwareHandler{
Location: normalizeLocation(location),
Transport: transport,
WrapTransport: wrapTransport,
UpgradeRequired: upgradeRequired,
FlushInterval: defaultFlushInterval,
Responder: responder,
}
}
// ServeHTTP handles the proxy request
func (h *UpgradeAwareHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if h.tryUpgrade(w, req) {
return
}
if h.UpgradeRequired {
h.Responder.Error(w, req, errors.NewBadRequest("Upgrade request required"))
return
}
loc := *h.Location
loc.RawQuery = req.URL.RawQuery
// If original request URL ended in '/', append a '/' at the end of the
// of the proxy URL
if !strings.HasSuffix(loc.Path, "/") && strings.HasSuffix(req.URL.Path, "/") {
loc.Path += "/"
}
// From pkg/genericapiserver/endpoints/handlers/proxy.go#ServeHTTP:
// Redirect requests with an empty path to a location that ends with a '/'
// This is essentially a hack for http://issue.k8s.io/4958.
// Note: Keep this code after tryUpgrade to not break that flow.
if len(loc.Path) == 0 {
var queryPart string
if len(req.URL.RawQuery) > 0 {
queryPart = "?" + req.URL.RawQuery
}
w.Header().Set("Location", req.URL.Path+"/"+queryPart)
w.WriteHeader(http.StatusMovedPermanently)
return
}
if h.Transport == nil || h.WrapTransport {
h.Transport = h.defaultProxyTransport(req.URL, h.Transport)
}
// WithContext creates a shallow clone of the request with the new context.
newReq := req.WithContext(context.Background())
newReq.Header = utilnet.CloneHeader(req.Header)
if !h.UseRequestLocation {
newReq.URL = &loc
}
proxy := httputil.NewSingleHostReverseProxy(&url.URL{Scheme: h.Location.Scheme, Host: h.Location.Host})
proxy.Transport = h.Transport
proxy.FlushInterval = h.FlushInterval
proxy.ServeHTTP(w, newReq)
}
// tryUpgrade returns true if the request was handled.
func (h *UpgradeAwareHandler) tryUpgrade(w http.ResponseWriter, req *http.Request) bool {
if !httpstream.IsUpgradeRequest(req) {
klog.V(6).Infof("Request was not an upgrade")
return false
}
var (
backendConn net.Conn
rawResponse []byte
err error
)
location := *h.Location
if h.UseRequestLocation {
location = *req.URL
location.Scheme = h.Location.Scheme
location.Host = h.Location.Host
}
clone := utilnet.CloneRequest(req)
// Only append X-Forwarded-For in the upgrade path, since httputil.NewSingleHostReverseProxy
// handles this in the non-upgrade path.
utilnet.AppendForwardedForHeader(clone)
if h.InterceptRedirects {
klog.V(6).Infof("Connecting to backend proxy (intercepting redirects) %s\n Headers: %v", &location, clone.Header)
backendConn, rawResponse, err = utilnet.ConnectWithRedirects(req.Method, &location, clone.Header, req.Body, utilnet.DialerFunc(h.DialForUpgrade), h.RequireSameHostRedirects)
} else {
klog.V(6).Infof("Connecting to backend proxy (direct dial) %s\n Headers: %v", &location, clone.Header)
clone.URL = &location
backendConn, err = h.DialForUpgrade(clone)
}
if err != nil {
klog.V(6).Infof("Proxy connection error: %v", err)
h.Responder.Error(w, req, err)
return true
}
defer backendConn.Close()
// determine the http response code from the backend by reading from rawResponse+backendConn
backendHTTPResponse, headerBytes, err := getResponse(io.MultiReader(bytes.NewReader(rawResponse), backendConn))
if err != nil {
klog.V(6).Infof("Proxy connection error: %v", err)
h.Responder.Error(w, req, err)
return true
}
if len(headerBytes) > len(rawResponse) {
// we read beyond the bytes stored in rawResponse, update rawResponse to the full set of bytes read from the backend
rawResponse = headerBytes
}
// Once the connection is hijacked, the ErrorResponder will no longer work, so
// hijacking should be the last step in the upgrade.
requestHijacker, ok := w.(http.Hijacker)
if !ok {
klog.V(6).Infof("Unable to hijack response writer: %T", w)
h.Responder.Error(w, req, fmt.Errorf("request connection cannot be hijacked: %T", w))
return true
}
requestHijackedConn, _, err := requestHijacker.Hijack()
if err != nil {
klog.V(6).Infof("Unable to hijack response: %v", err)
h.Responder.Error(w, req, fmt.Errorf("error hijacking connection: %v", err))
return true
}
defer requestHijackedConn.Close()
if backendHTTPResponse.StatusCode != http.StatusSwitchingProtocols {
// If the backend did not upgrade the request, echo the response from the backend to the client and return, closing the connection.
klog.V(6).Infof("Proxy upgrade error, status code %d", backendHTTPResponse.StatusCode)
// set read/write deadlines
deadline := time.Now().Add(10 * time.Second)
backendConn.SetReadDeadline(deadline)
requestHijackedConn.SetWriteDeadline(deadline)
// write the response to the client
err := backendHTTPResponse.Write(requestHijackedConn)
if err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
klog.Errorf("Error proxying data from backend to client: %v", err)
}
// Indicate we handled the request
return true
}
// Forward raw response bytes back to client.
if len(rawResponse) > 0 {
klog.V(6).Infof("Writing %d bytes to hijacked connection", len(rawResponse))
if _, err = requestHijackedConn.Write(rawResponse); err != nil {
utilruntime.HandleError(fmt.Errorf("Error proxying response from backend to client: %v", err))
}
}
// Proxy the connection. This is bidirectional, so we need a goroutine
// to copy in each direction. Once one side of the connection exits, we
// exit the function which performs cleanup and in the process closes
// the other half of the connection in the defer.
writerComplete := make(chan struct{})
readerComplete := make(chan struct{})
go func() {
var writer io.WriteCloser
if h.MaxBytesPerSec > 0 {
writer = flowrate.NewWriter(backendConn, h.MaxBytesPerSec)
} else {
writer = backendConn
}
_, err := io.Copy(writer, requestHijackedConn)
if err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
klog.Errorf("Error proxying data from client to backend: %v", err)
}
close(writerComplete)
}()
go func() {
var reader io.ReadCloser
if h.MaxBytesPerSec > 0 {
reader = flowrate.NewReader(backendConn, h.MaxBytesPerSec)
} else {
reader = backendConn
}
_, err := io.Copy(requestHijackedConn, reader)
if err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
klog.Errorf("Error proxying data from backend to client: %v", err)
}
close(readerComplete)
}()
// Wait for one half the connection to exit. Once it does the defer will
// clean up the other half of the connection.
select {
case <-writerComplete:
case <-readerComplete:
}
klog.V(6).Infof("Disconnecting from backend proxy %s\n Headers: %v", &location, clone.Header)
return true
}
func (h *UpgradeAwareHandler) Dial(req *http.Request) (net.Conn, error) {
return dial(req, h.Transport)
}
func (h *UpgradeAwareHandler) DialForUpgrade(req *http.Request) (net.Conn, error) {
if h.UpgradeTransport == nil {
return dial(req, h.Transport)
}
updatedReq, err := h.UpgradeTransport.WrapRequest(req)
if err != nil {
return nil, err
}
return dial(updatedReq, h.UpgradeTransport)
}
// getResponseCode reads a http response from the given reader, returns the response,
// the bytes read from the reader, and any error encountered
func getResponse(r io.Reader) (*http.Response, []byte, error) {
rawResponse := bytes.NewBuffer(make([]byte, 0, 256))
// Save the bytes read while reading the response headers into the rawResponse buffer
resp, err := http.ReadResponse(bufio.NewReader(io.TeeReader(r, rawResponse)), nil)
if err != nil {
return nil, nil, err
}
// return the http response and the raw bytes consumed from the reader in the process
return resp, rawResponse.Bytes(), nil
}
// dial dials the backend at req.URL and writes req to it.
func dial(req *http.Request, transport http.RoundTripper) (net.Conn, error) {
conn, err := DialURL(req.Context(), req.URL, transport)
if err != nil {
return nil, fmt.Errorf("error dialing backend: %v", err)
}
if err = req.Write(conn); err != nil {
conn.Close()
return nil, fmt.Errorf("error sending request: %v", err)
}
return conn, err
}
var _ utilnet.Dialer = &UpgradeAwareHandler{}
func (h *UpgradeAwareHandler) defaultProxyTransport(url *url.URL, internalTransport http.RoundTripper) http.RoundTripper {
scheme := url.Scheme
host := url.Host
suffix := h.Location.Path
if strings.HasSuffix(url.Path, "/") && !strings.HasSuffix(suffix, "/") {
suffix += "/"
}
pathPrepend := strings.TrimSuffix(url.Path, suffix)
rewritingTransport := &Transport{
Scheme: scheme,
Host: host,
PathPrepend: pathPrepend,
RoundTripper: internalTransport,
}
return &corsRemovingTransport{
RoundTripper: rewritingTransport,
}
}
// corsRemovingTransport is a wrapper for an internal transport. It removes CORS headers
// from the internal response.
// Implements pkg/util/net.RoundTripperWrapper
type corsRemovingTransport struct {
http.RoundTripper
}
var _ = utilnet.RoundTripperWrapper(&corsRemovingTransport{})
func (rt *corsRemovingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
resp, err := rt.RoundTripper.RoundTrip(req)
if err != nil {
return nil, err
}
removeCORSHeaders(resp)
return resp, nil
}
func (rt *corsRemovingTransport) WrappedRoundTripper() http.RoundTripper {
return rt.RoundTripper
}
// removeCORSHeaders strip CORS headers sent from the backend
// This should be called on all responses before returning
func removeCORSHeaders(resp *http.Response) {
resp.Header.Del("Access-Control-Allow-Credentials")
resp.Header.Del("Access-Control-Allow-Headers")
resp.Header.Del("Access-Control-Allow-Methods")
resp.Header.Del("Access-Control-Allow-Origin")
}

View File

@ -1,826 +0,0 @@
/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package proxy
import (
"bytes"
"compress/gzip"
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"reflect"
"strconv"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/websocket"
"k8s.io/apimachinery/pkg/util/httpstream"
utilnet "k8s.io/apimachinery/pkg/util/net"
)
const fakeStatusCode = 567
type fakeResponder struct {
t *testing.T
called bool
err error
// called chan error
w http.ResponseWriter
}
func (r *fakeResponder) Error(w http.ResponseWriter, req *http.Request, err error) {
if r.called {
r.t.Errorf("Error responder called again!\nprevious error: %v\nnew error: %v", r.err, err)
}
w.WriteHeader(fakeStatusCode)
_, writeErr := w.Write([]byte(err.Error()))
assert.NoError(r.t, writeErr)
r.called = true
r.err = err
}
type fakeConn struct {
err error // The error to return when io is performed over the connection.
}
func (f *fakeConn) Read([]byte) (int, error) { return 0, f.err }
func (f *fakeConn) Write([]byte) (int, error) { return 0, f.err }
func (f *fakeConn) Close() error { return nil }
func (fakeConn) LocalAddr() net.Addr { return nil }
func (fakeConn) RemoteAddr() net.Addr { return nil }
func (fakeConn) SetDeadline(t time.Time) error { return nil }
func (fakeConn) SetReadDeadline(t time.Time) error { return nil }
func (fakeConn) SetWriteDeadline(t time.Time) error { return nil }
type SimpleBackendHandler struct {
requestURL url.URL
requestHeader http.Header
requestBody []byte
requestMethod string
responseBody string
responseHeader map[string]string
t *testing.T
}
func (s *SimpleBackendHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
s.requestURL = *req.URL
s.requestHeader = req.Header
s.requestMethod = req.Method
var err error
s.requestBody, err = ioutil.ReadAll(req.Body)
if err != nil {
s.t.Errorf("Unexpected error: %v", err)
return
}
if s.responseHeader != nil {
for k, v := range s.responseHeader {
w.Header().Add(k, v)
}
}
w.Write([]byte(s.responseBody))
}
func validateParameters(t *testing.T, name string, actual url.Values, expected map[string]string) {
for k, v := range expected {
actualValue, ok := actual[k]
if !ok {
t.Errorf("%s: Expected parameter %s not received", name, k)
continue
}
if actualValue[0] != v {
t.Errorf("%s: Parameter %s values don't match. Actual: %#v, Expected: %s",
name, k, actualValue, v)
}
}
}
func validateHeaders(t *testing.T, name string, actual http.Header, expected map[string]string, notExpected []string) {
for k, v := range expected {
actualValue, ok := actual[k]
if !ok {
t.Errorf("%s: Expected header %s not received", name, k)
continue
}
if actualValue[0] != v {
t.Errorf("%s: Header %s values don't match. Actual: %s, Expected: %s",
name, k, actualValue, v)
}
}
if notExpected == nil {
return
}
for _, h := range notExpected {
if _, present := actual[h]; present {
t.Errorf("%s: unexpected header: %s", name, h)
}
}
}
func TestServeHTTP(t *testing.T) {
tests := []struct {
name string
method string
requestPath string
expectedPath string
requestBody string
requestParams map[string]string
requestHeader map[string]string
responseHeader map[string]string
expectedRespHeader map[string]string
notExpectedRespHeader []string
upgradeRequired bool
expectError func(err error) bool
}{
{
name: "root path, simple get",
method: "GET",
requestPath: "/",
expectedPath: "/",
},
{
name: "no upgrade header sent",
method: "GET",
requestPath: "/",
upgradeRequired: true,
expectError: func(err error) bool {
return err != nil && strings.Contains(err.Error(), "Upgrade request required")
},
},
{
name: "simple path, get",
method: "GET",
requestPath: "/path/to/test",
expectedPath: "/path/to/test",
},
{
name: "request params",
method: "POST",
requestPath: "/some/path/",
expectedPath: "/some/path/",
requestParams: map[string]string{"param1": "value/1", "param2": "value%2"},
requestBody: "test request body",
},
{
name: "request headers",
method: "PUT",
requestPath: "/some/path",
expectedPath: "/some/path",
requestHeader: map[string]string{"Header1": "value1", "Header2": "value2"},
},
{
name: "empty path - slash should be added",
method: "GET",
requestPath: "",
expectedPath: "/",
},
{
name: "remove CORS headers",
method: "GET",
requestPath: "/some/path",
expectedPath: "/some/path",
responseHeader: map[string]string{
"Header1": "value1",
"Access-Control-Allow-Origin": "some.server",
"Access-Control-Allow-Methods": "GET"},
expectedRespHeader: map[string]string{
"Header1": "value1",
},
notExpectedRespHeader: []string{
"Access-Control-Allow-Origin",
"Access-Control-Allow-Methods",
},
},
}
for i, test := range tests {
func() {
backendResponse := "<html><head></head><body><a href=\"/test/path\">Hello</a></body></html>"
backendResponseHeader := test.responseHeader
// Test a simple header if not specified in the test
if backendResponseHeader == nil && test.expectedRespHeader == nil {
backendResponseHeader = map[string]string{"Content-Type": "text/html"}
test.expectedRespHeader = map[string]string{"Content-Type": "text/html"}
}
backendHandler := &SimpleBackendHandler{
responseBody: backendResponse,
responseHeader: backendResponseHeader,
}
backendServer := httptest.NewServer(backendHandler)
defer backendServer.Close()
responder := &fakeResponder{t: t}
backendURL, _ := url.Parse(backendServer.URL)
backendURL.Path = test.requestPath
proxyHandler := NewUpgradeAwareHandler(backendURL, nil, false, test.upgradeRequired, responder)
proxyServer := httptest.NewServer(proxyHandler)
defer proxyServer.Close()
proxyURL, _ := url.Parse(proxyServer.URL)
proxyURL.Path = test.requestPath
paramValues := url.Values{}
for k, v := range test.requestParams {
paramValues[k] = []string{v}
}
proxyURL.RawQuery = paramValues.Encode()
var requestBody io.Reader
if test.requestBody != "" {
requestBody = bytes.NewBufferString(test.requestBody)
}
req, err := http.NewRequest(test.method, proxyURL.String(), requestBody)
if test.requestHeader != nil {
header := http.Header{}
for k, v := range test.requestHeader {
header.Add(k, v)
}
req.Header = header
}
if err != nil {
t.Errorf("Error creating client request: %v", err)
}
client := &http.Client{}
res, err := client.Do(req)
if err != nil {
t.Errorf("Error from proxy request: %v", err)
}
if test.expectError != nil {
if !responder.called {
t.Errorf("%d: responder was not invoked", i)
return
}
if !test.expectError(responder.err) {
t.Errorf("%d: unexpected error: %v", i, responder.err)
}
return
}
// Validate backend request
// Method
if backendHandler.requestMethod != test.method {
t.Errorf("Unexpected request method: %s. Expected: %s",
backendHandler.requestMethod, test.method)
}
// Body
if string(backendHandler.requestBody) != test.requestBody {
t.Errorf("Unexpected request body: %s. Expected: %s",
string(backendHandler.requestBody), test.requestBody)
}
// Path
if backendHandler.requestURL.Path != test.expectedPath {
t.Errorf("Unexpected request path: %s", backendHandler.requestURL.Path)
}
// Parameters
validateParameters(t, test.name, backendHandler.requestURL.Query(), test.requestParams)
// Headers
validateHeaders(t, test.name+" backend request", backendHandler.requestHeader,
test.requestHeader, nil)
// Validate proxy response
// Response Headers
validateHeaders(t, test.name+" backend headers", res.Header, test.expectedRespHeader, test.notExpectedRespHeader)
// Validate Body
responseBody, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Errorf("Unexpected error reading response body: %v", err)
}
if rb := string(responseBody); rb != backendResponse {
t.Errorf("Did not get expected response body: %s. Expected: %s", rb, backendResponse)
}
// Error
if responder.called {
t.Errorf("Unexpected proxy handler error: %v", responder.err)
}
}()
}
}
type RoundTripperFunc func(req *http.Request) (*http.Response, error)
func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return fn(req)
}
func TestProxyUpgrade(t *testing.T) {
localhostPool := x509.NewCertPool()
if !localhostPool.AppendCertsFromPEM(localhostCert) {
t.Errorf("error setting up localhostCert pool")
}
var d net.Dialer
testcases := map[string]struct {
ServerFunc func(http.Handler) *httptest.Server
ProxyTransport http.RoundTripper
UpgradeTransport UpgradeRequestRoundTripper
ExpectedAuth string
}{
"http": {
ServerFunc: httptest.NewServer,
ProxyTransport: nil,
},
"https (invalid hostname + InsecureSkipVerify)": {
ServerFunc: func(h http.Handler) *httptest.Server {
cert, err := tls.X509KeyPair(exampleCert, exampleKey)
if err != nil {
t.Errorf("https (invalid hostname): proxy_test: %v", err)
}
ts := httptest.NewUnstartedServer(h)
ts.TLS = &tls.Config{
Certificates: []tls.Certificate{cert},
}
ts.StartTLS()
return ts
},
ProxyTransport: utilnet.SetTransportDefaults(&http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}),
},
"https (valid hostname + RootCAs)": {
ServerFunc: func(h http.Handler) *httptest.Server {
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
if err != nil {
t.Errorf("https (valid hostname): proxy_test: %v", err)
}
ts := httptest.NewUnstartedServer(h)
ts.TLS = &tls.Config{
Certificates: []tls.Certificate{cert},
}
ts.StartTLS()
return ts
},
ProxyTransport: utilnet.SetTransportDefaults(&http.Transport{TLSClientConfig: &tls.Config{RootCAs: localhostPool}}),
},
"https (valid hostname + RootCAs + custom dialer)": {
ServerFunc: func(h http.Handler) *httptest.Server {
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
if err != nil {
t.Errorf("https (valid hostname): proxy_test: %v", err)
}
ts := httptest.NewUnstartedServer(h)
ts.TLS = &tls.Config{
Certificates: []tls.Certificate{cert},
}
ts.StartTLS()
return ts
},
ProxyTransport: utilnet.SetTransportDefaults(&http.Transport{DialContext: d.DialContext, TLSClientConfig: &tls.Config{RootCAs: localhostPool}}),
},
"https (valid hostname + RootCAs + custom dialer + bearer token)": {
ServerFunc: func(h http.Handler) *httptest.Server {
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
if err != nil {
t.Errorf("https (valid hostname): proxy_test: %v", err)
}
ts := httptest.NewUnstartedServer(h)
ts.TLS = &tls.Config{
Certificates: []tls.Certificate{cert},
}
ts.StartTLS()
return ts
},
ProxyTransport: utilnet.SetTransportDefaults(&http.Transport{DialContext: d.DialContext, TLSClientConfig: &tls.Config{RootCAs: localhostPool}}),
UpgradeTransport: NewUpgradeRequestRoundTripper(
utilnet.SetOldTransportDefaults(&http.Transport{DialContext: d.DialContext, TLSClientConfig: &tls.Config{RootCAs: localhostPool}}),
RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
req = utilnet.CloneRequest(req)
req.Header.Set("Authorization", "Bearer 1234")
return MirrorRequest.RoundTrip(req)
}),
),
ExpectedAuth: "Bearer 1234",
},
}
for k, tc := range testcases {
for _, redirect := range []bool{false, true} {
tcName := k
backendPath := "/hello"
if redirect {
tcName += " with redirect"
backendPath = "/redirect"
}
func() { // Cleanup after each test case.
backend := http.NewServeMux()
backend.Handle("/hello", websocket.Handler(func(ws *websocket.Conn) {
if ws.Request().Header.Get("Authorization") != tc.ExpectedAuth {
t.Errorf("%s: unexpected headers on request: %v", k, ws.Request().Header)
defer ws.Close()
ws.Write([]byte("you failed"))
return
}
defer ws.Close()
body := make([]byte, 5)
ws.Read(body)
ws.Write([]byte("hello " + string(body)))
}))
backend.Handle("/redirect", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/hello", http.StatusFound)
}))
backendServer := tc.ServerFunc(backend)
defer backendServer.Close()
serverURL, _ := url.Parse(backendServer.URL)
serverURL.Path = backendPath
proxyHandler := NewUpgradeAwareHandler(serverURL, tc.ProxyTransport, false, false, &noErrorsAllowed{t: t})
proxyHandler.UpgradeTransport = tc.UpgradeTransport
proxyHandler.InterceptRedirects = redirect
proxy := httptest.NewServer(proxyHandler)
defer proxy.Close()
ws, err := websocket.Dial("ws://"+proxy.Listener.Addr().String()+"/some/path", "", "http://127.0.0.1/")
if err != nil {
t.Fatalf("%s: websocket dial err: %s", tcName, err)
}
defer ws.Close()
if _, err := ws.Write([]byte("world")); err != nil {
t.Fatalf("%s: write err: %s", tcName, err)
}
response := make([]byte, 20)
n, err := ws.Read(response)
if err != nil {
t.Fatalf("%s: read err: %s", tcName, err)
}
if e, a := "hello world", string(response[0:n]); e != a {
t.Fatalf("%s: expected '%#v', got '%#v'", tcName, e, a)
}
}()
}
}
}
type noErrorsAllowed struct {
t *testing.T
}
func (r *noErrorsAllowed) Error(w http.ResponseWriter, req *http.Request, err error) {
r.t.Error(err)
}
func TestProxyUpgradeErrorResponse(t *testing.T) {
var (
responder *fakeResponder
expectedErr = errors.New("EXPECTED")
)
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return &fakeConn{err: expectedErr}, nil
},
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
responder = &fakeResponder{t: t, w: w}
proxyHandler := NewUpgradeAwareHandler(
&url.URL{
Host: "fake-backend",
},
transport,
false,
true,
responder,
)
proxyHandler.ServeHTTP(w, r)
}))
defer proxy.Close()
// Send request to proxy server.
req, err := http.NewRequest("POST", "http://"+proxy.Listener.Addr().String()+"/some/path", nil)
require.NoError(t, err)
req.Header.Set(httpstream.HeaderConnection, httpstream.HeaderUpgrade)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
// Expect error response.
assert.True(t, responder.called)
assert.Equal(t, fakeStatusCode, resp.StatusCode)
msg, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
assert.Contains(t, string(msg), expectedErr.Error())
}
func TestDefaultProxyTransport(t *testing.T) {
tests := []struct {
name,
url,
location,
expectedScheme,
expectedHost,
expectedPathPrepend string
}{
{
name: "simple path",
url: "http://test.server:8080/a/test/location",
location: "http://localhost/location",
expectedScheme: "http",
expectedHost: "test.server:8080",
expectedPathPrepend: "/a/test",
},
{
name: "empty path",
url: "http://test.server:8080/a/test/",
location: "http://localhost",
expectedScheme: "http",
expectedHost: "test.server:8080",
expectedPathPrepend: "/a/test",
},
{
name: "location ending in slash",
url: "http://test.server:8080/a/test/",
location: "http://localhost/",
expectedScheme: "http",
expectedHost: "test.server:8080",
expectedPathPrepend: "/a/test",
},
}
for _, test := range tests {
locURL, _ := url.Parse(test.location)
URL, _ := url.Parse(test.url)
h := NewUpgradeAwareHandler(locURL, nil, false, false, nil)
result := h.defaultProxyTransport(URL, nil)
transport := result.(*corsRemovingTransport).RoundTripper.(*Transport)
if transport.Scheme != test.expectedScheme {
t.Errorf("%s: unexpected scheme. Actual: %s, Expected: %s", test.name, transport.Scheme, test.expectedScheme)
}
if transport.Host != test.expectedHost {
t.Errorf("%s: unexpected host. Actual: %s, Expected: %s", test.name, transport.Host, test.expectedHost)
}
if transport.PathPrepend != test.expectedPathPrepend {
t.Errorf("%s: unexpected path prepend. Actual: %s, Expected: %s", test.name, transport.PathPrepend, test.expectedPathPrepend)
}
}
}
func TestProxyRequestContentLengthAndTransferEncoding(t *testing.T) {
chunk := func(data []byte) []byte {
out := &bytes.Buffer{}
chunker := httputil.NewChunkedWriter(out)
for _, b := range data {
if _, err := chunker.Write([]byte{b}); err != nil {
panic(err)
}
}
chunker.Close()
out.Write([]byte("\r\n"))
return out.Bytes()
}
zip := func(data []byte) []byte {
out := &bytes.Buffer{}
zipper := gzip.NewWriter(out)
if _, err := zipper.Write(data); err != nil {
panic(err)
}
zipper.Close()
return out.Bytes()
}
sampleData := []byte("abcde")
table := map[string]struct {
reqHeaders http.Header
reqBody []byte
expectedHeaders http.Header
expectedBody []byte
}{
"content-length": {
reqHeaders: http.Header{
"Content-Length": []string{"5"},
},
reqBody: sampleData,
expectedHeaders: http.Header{
"Content-Length": []string{"5"},
"Content-Encoding": nil, // none set
"Transfer-Encoding": nil, // none set
},
expectedBody: sampleData,
},
"content-length + identity transfer-encoding": {
reqHeaders: http.Header{
"Content-Length": []string{"5"},
"Transfer-Encoding": []string{"identity"},
},
reqBody: sampleData,
expectedHeaders: http.Header{
"Content-Length": []string{"5"},
"Content-Encoding": nil, // none set
"Transfer-Encoding": nil, // gets removed
},
expectedBody: sampleData,
},
"content-length + gzip content-encoding": {
reqHeaders: http.Header{
"Content-Length": []string{strconv.Itoa(len(zip(sampleData)))},
"Content-Encoding": []string{"gzip"},
},
reqBody: zip(sampleData),
expectedHeaders: http.Header{
"Content-Length": []string{strconv.Itoa(len(zip(sampleData)))},
"Content-Encoding": []string{"gzip"},
"Transfer-Encoding": nil, // none set
},
expectedBody: zip(sampleData),
},
"chunked transfer-encoding": {
reqHeaders: http.Header{
"Transfer-Encoding": []string{"chunked"},
},
reqBody: chunk(sampleData),
expectedHeaders: http.Header{
"Content-Length": nil, // none set
"Content-Encoding": nil, // none set
"Transfer-Encoding": nil, // Transfer-Encoding gets removed
},
expectedBody: sampleData, // sample data is unchunked
},
"chunked transfer-encoding + gzip content-encoding": {
reqHeaders: http.Header{
"Content-Encoding": []string{"gzip"},
"Transfer-Encoding": []string{"chunked"},
},
reqBody: chunk(zip(sampleData)),
expectedHeaders: http.Header{
"Content-Length": nil, // none set
"Content-Encoding": []string{"gzip"},
"Transfer-Encoding": nil, // gets removed
},
expectedBody: zip(sampleData), // sample data is unchunked, but content-encoding is preserved
},
// "Transfer-Encoding: gzip" is not supported by go
// See http/transfer.go#fixTransferEncoding (https://golang.org/src/net/http/transfer.go#L427)
// Once it is supported, this test case should succeed
//
// "gzip+chunked transfer-encoding": {
// reqHeaders: http.Header{
// "Transfer-Encoding": []string{"chunked,gzip"},
// },
// reqBody: chunk(zip(sampleData)),
//
// expectedHeaders: http.Header{
// "Content-Length": nil, // no content-length headers
// "Transfer-Encoding": nil, // Transfer-Encoding gets removed
// },
// expectedBody: sampleData,
// },
}
successfulResponse := "backend passed tests"
for k, item := range table {
// Start the downstream server
downstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
// Verify headers
for header, v := range item.expectedHeaders {
if !reflect.DeepEqual(v, req.Header[header]) {
t.Errorf("%s: Expected headers for %s to be %v, got %v", k, header, v, req.Header[header])
}
}
// Read body
body, err := ioutil.ReadAll(req.Body)
if err != nil {
t.Errorf("%s: unexpected error %v", k, err)
}
req.Body.Close()
// Verify length
if req.ContentLength > 0 && req.ContentLength != int64(len(body)) {
t.Errorf("%s: ContentLength was %d, len(data) was %d", k, req.ContentLength, len(body))
}
// Verify content
if !bytes.Equal(item.expectedBody, body) {
t.Errorf("%s: Expected %q, got %q", k, string(item.expectedBody), string(body))
}
// Write successful response
w.Write([]byte(successfulResponse))
}))
defer downstreamServer.Close()
responder := &fakeResponder{t: t}
backendURL, _ := url.Parse(downstreamServer.URL)
proxyHandler := NewUpgradeAwareHandler(backendURL, nil, false, false, responder)
proxyServer := httptest.NewServer(proxyHandler)
defer proxyServer.Close()
// Dial the proxy server
conn, err := net.Dial(proxyServer.Listener.Addr().Network(), proxyServer.Listener.Addr().String())
if err != nil {
t.Errorf("unexpected error %v", err)
continue
}
defer conn.Close()
// Add standard http 1.1 headers
if item.reqHeaders == nil {
item.reqHeaders = http.Header{}
}
item.reqHeaders.Add("Connection", "close")
item.reqHeaders.Add("Host", proxyServer.Listener.Addr().String())
// Write the request headers
if _, err := fmt.Fprint(conn, "POST / HTTP/1.1\r\n"); err != nil {
t.Fatalf("%s unexpected error %v", k, err)
}
for header, values := range item.reqHeaders {
for _, value := range values {
if _, err := fmt.Fprintf(conn, "%s: %s\r\n", header, value); err != nil {
t.Fatalf("%s: unexpected error %v", k, err)
}
}
}
// Header separator
if _, err := fmt.Fprint(conn, "\r\n"); err != nil {
t.Fatalf("%s: unexpected error %v", k, err)
}
// Body
if _, err := conn.Write(item.reqBody); err != nil {
t.Fatalf("%s: unexpected error %v", k, err)
}
// Read response
response, err := ioutil.ReadAll(conn)
if err != nil {
t.Errorf("%s: unexpected error %v", k, err)
continue
}
if !strings.HasSuffix(string(response), successfulResponse) {
t.Errorf("%s: Did not get successful response: %s", k, string(response))
continue
}
}
}
// exampleCert was generated from crypto/tls/generate_cert.go with the following command:
// go run generate_cert.go --rsa-bits 512 --host example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
var exampleCert = []byte(`-----BEGIN CERTIFICATE-----
MIIBdzCCASGgAwIBAgIRAOVTAdPnfbS5V85mfS90TfIwDQYJKoZIhvcNAQELBQAw
EjEQMA4GA1UEChMHQWNtZSBDbzAgFw03MDAxMDEwMDAwMDBaGA8yMDg0MDEyOTE2
MDAwMFowEjEQMA4GA1UEChMHQWNtZSBDbzBcMA0GCSqGSIb3DQEBAQUAA0sAMEgC
QQCoVSqeu8TBvF+70T7Jm4340YQNhds6IxjRoifenYodAO1dnKGrcbF266DJGunh
nIjQH7B12tduhl0fLK4Ezf7/AgMBAAGjUDBOMA4GA1UdDwEB/wQEAwICpDATBgNV
HSUEDDAKBggrBgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MBYGA1UdEQQPMA2CC2V4
YW1wbGUuY29tMA0GCSqGSIb3DQEBCwUAA0EAk1kVa5uZ/AzwYDVcS9bpM/czwjjV
xq3VeSCfmNa2uNjbFvodmCRwZOHUvipAMGCUCV6j5vMrJ8eMj8tCQ36W9A==
-----END CERTIFICATE-----`)
var exampleKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIBOgIBAAJBAKhVKp67xMG8X7vRPsmbjfjRhA2F2zojGNGiJ96dih0A7V2coatx
sXbroMka6eGciNAfsHXa126GXR8srgTN/v8CAwEAAQJASdzdD7vKsUwMIejGCUb1
fAnLTPfAY3lFCa+CmR89nE22dAoRDv+5RbnBsZ58BazPNJHrsVPRlfXB3OQmSQr0
SQIhANoJhs+xOJE/i8nJv0uAbzKyiD1YkvRkta0GpUOULyAVAiEAxaQus3E/SuqD
P7y5NeJnE7X6XkyC35zrsJRkz7orE8MCIHdDjsI8pjyNDeGqwUCDWE/a6DrmIDwe
emHSqMN2YvChAiEAnxLCM9NWaenOsaIoP+J1rDuvw+4499nJKVqGuVrSCRkCIEqK
4KSchPMc3x8M/uhw9oWTtKFmjA/PPh0FsWCdKrEy
-----END RSA PRIVATE KEY-----`)

View File

@ -1,127 +0,0 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package rand provides utilities related to randomization.
package rand
import (
"math/rand"
"sync"
"time"
)
var rng = struct {
sync.Mutex
rand *rand.Rand
}{
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
}
// Int returns a non-negative pseudo-random int.
func Int() int {
rng.Lock()
defer rng.Unlock()
return rng.rand.Int()
}
// Intn generates an integer in range [0,max).
// By design this should panic if input is invalid, <= 0.
func Intn(max int) int {
rng.Lock()
defer rng.Unlock()
return rng.rand.Intn(max)
}
// IntnRange generates an integer in range [min,max).
// By design this should panic if input is invalid, <= 0.
func IntnRange(min, max int) int {
rng.Lock()
defer rng.Unlock()
return rng.rand.Intn(max-min) + min
}
// IntnRange generates an int64 integer in range [min,max).
// By design this should panic if input is invalid, <= 0.
func Int63nRange(min, max int64) int64 {
rng.Lock()
defer rng.Unlock()
return rng.rand.Int63n(max-min) + min
}
// Seed seeds the rng with the provided seed.
func Seed(seed int64) {
rng.Lock()
defer rng.Unlock()
rng.rand = rand.New(rand.NewSource(seed))
}
// Perm returns, as a slice of n ints, a pseudo-random permutation of the integers [0,n)
// from the default Source.
func Perm(n int) []int {
rng.Lock()
defer rng.Unlock()
return rng.rand.Perm(n)
}
const (
// We omit vowels from the set of available characters to reduce the chances
// of "bad words" being formed.
alphanums = "bcdfghjklmnpqrstvwxz2456789"
// No. of bits required to index into alphanums string.
alphanumsIdxBits = 5
// Mask used to extract last alphanumsIdxBits of an int.
alphanumsIdxMask = 1<<alphanumsIdxBits - 1
// No. of random letters we can extract from a single int63.
maxAlphanumsPerInt = 63 / alphanumsIdxBits
)
// String generates a random alphanumeric string, without vowels, which is n
// characters long. This will panic if n is less than zero.
// How the random string is created:
// - we generate random int63's
// - from each int63, we are extracting multiple random letters by bit-shifting and masking
// - if some index is out of range of alphanums we neglect it (unlikely to happen multiple times in a row)
func String(n int) string {
b := make([]byte, n)
rng.Lock()
defer rng.Unlock()
randomInt63 := rng.rand.Int63()
remaining := maxAlphanumsPerInt
for i := 0; i < n; {
if remaining == 0 {
randomInt63, remaining = rng.rand.Int63(), maxAlphanumsPerInt
}
if idx := int(randomInt63 & alphanumsIdxMask); idx < len(alphanums) {
b[i] = alphanums[idx]
i++
}
randomInt63 >>= alphanumsIdxBits
remaining--
}
return string(b)
}
// SafeEncodeString encodes s using the same characters as rand.String. This reduces the chances of bad words and
// ensures that strings generated from hash functions appear consistent throughout the API.
func SafeEncodeString(s string) string {
r := make([]byte, len(s))
for i, b := range []rune(s) {
r[i] = alphanums[(int(b) % len(alphanums))]
}
return string(r)
}

View File

@ -1,114 +0,0 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package rand
import (
"math/rand"
"strings"
"testing"
)
const (
maxRangeTestCount = 500
testStringLength = 32
)
func TestString(t *testing.T) {
valid := "0123456789abcdefghijklmnopqrstuvwxyz"
for _, l := range []int{0, 1, 2, 10, 123} {
s := String(l)
if len(s) != l {
t.Errorf("expected string of size %d, got %q", l, s)
}
for _, c := range s {
if !strings.ContainsRune(valid, c) {
t.Errorf("expected valid characters, got %v", c)
}
}
}
}
// Confirm that panic occurs on invalid input.
func TestRangePanic(t *testing.T) {
defer func() {
if err := recover(); err == nil {
t.Errorf("Panic didn't occur!")
}
}()
// Should result in an error...
Intn(0)
}
func TestIntn(t *testing.T) {
// 0 is invalid.
for _, max := range []int{1, 2, 10, 123} {
inrange := Intn(max)
if inrange < 0 || inrange > max {
t.Errorf("%v out of range (0,%v)", inrange, max)
}
}
}
func TestPerm(t *testing.T) {
Seed(5)
rand.Seed(5)
for i := 1; i < 20; i++ {
actual := Perm(i)
expected := rand.Perm(i)
for j := 0; j < i; j++ {
if actual[j] != expected[j] {
t.Errorf("Perm call result is unexpected")
}
}
}
}
func TestIntnRange(t *testing.T) {
// 0 is invalid.
for min, max := range map[int]int{1: 2, 10: 123, 100: 500} {
for i := 0; i < maxRangeTestCount; i++ {
inrange := IntnRange(min, max)
if inrange < min || inrange >= max {
t.Errorf("%v out of range (%v,%v)", inrange, min, max)
}
}
}
}
func TestInt63nRange(t *testing.T) {
// 0 is invalid.
for min, max := range map[int64]int64{1: 2, 10: 123, 100: 500} {
for i := 0; i < maxRangeTestCount; i++ {
inrange := Int63nRange(min, max)
if inrange < min || inrange >= max {
t.Errorf("%v out of range (%v,%v)", inrange, min, max)
}
}
}
}
func BenchmarkRandomStringGeneration(b *testing.B) {
b.ResetTimer()
var s string
for i := 0; i < b.N; i++ {
s = String(testStringLength)
}
b.StopTimer()
if len(s) == 0 {
b.Fatal(s)
}
}

View File

@ -1,53 +0,0 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package remotecommand
import (
"time"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
const (
DefaultStreamCreationTimeout = 30 * time.Second
// The SPDY subprotocol "channel.k8s.io" is used for remote command
// attachment/execution. This represents the initial unversioned subprotocol,
// which has the known bugs http://issues.k8s.io/13394 and
// http://issues.k8s.io/13395.
StreamProtocolV1Name = "channel.k8s.io"
// The SPDY subprotocol "v2.channel.k8s.io" is used for remote command
// attachment/execution. It is the second version of the subprotocol and
// resolves the issues present in the first version.
StreamProtocolV2Name = "v2.channel.k8s.io"
// The SPDY subprotocol "v3.channel.k8s.io" is used for remote command
// attachment/execution. It is the third version of the subprotocol and
// adds support for resizing container terminals.
StreamProtocolV3Name = "v3.channel.k8s.io"
// The SPDY subprotocol "v4.channel.k8s.io" is used for remote command
// attachment/execution. It is the 4th version of the subprotocol and
// adds support for exit codes.
StreamProtocolV4Name = "v4.channel.k8s.io"
NonZeroExitCodeReason = metav1.StatusReason("NonZeroExitCode")
ExitCodeCauseType = metav1.CauseType("ExitCode")
)
var SupportedStreamingProtocols = []string{StreamProtocolV4Name, StreamProtocolV3Name, StreamProtocolV2Name, StreamProtocolV1Name}

View File

@ -1,71 +0,0 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package runtime
import (
"fmt"
"testing"
)
func TestHandleCrash(t *testing.T) {
defer func() {
if x := recover(); x == nil {
t.Errorf("Expected a panic to recover from")
}
}()
defer HandleCrash()
panic("Test Panic")
}
func TestCustomHandleCrash(t *testing.T) {
old := PanicHandlers
defer func() { PanicHandlers = old }()
var result interface{}
PanicHandlers = []func(interface{}){
func(r interface{}) {
result = r
},
}
func() {
defer func() {
if x := recover(); x == nil {
t.Errorf("Expected a panic to recover from")
}
}()
defer HandleCrash()
panic("test")
}()
if result != "test" {
t.Errorf("did not receive custom handler")
}
}
func TestCustomHandleError(t *testing.T) {
old := ErrorHandlers
defer func() { ErrorHandlers = old }()
var result error
ErrorHandlers = []func(error){
func(err error) {
result = err
},
}
err := fmt.Errorf("test")
HandleError(err)
if result != err {
t.Errorf("did not receive custom handler")
}
}

View File

@ -1,270 +0,0 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package sets
import (
"reflect"
"testing"
)
func TestStringSet(t *testing.T) {
s := String{}
s2 := String{}
if len(s) != 0 {
t.Errorf("Expected len=0: %d", len(s))
}
s.Insert("a", "b")
if len(s) != 2 {
t.Errorf("Expected len=2: %d", len(s))
}
s.Insert("c")
if s.Has("d") {
t.Errorf("Unexpected contents: %#v", s)
}
if !s.Has("a") {
t.Errorf("Missing contents: %#v", s)
}
s.Delete("a")
if s.Has("a") {
t.Errorf("Unexpected contents: %#v", s)
}
s.Insert("a")
if s.HasAll("a", "b", "d") {
t.Errorf("Unexpected contents: %#v", s)
}
if !s.HasAll("a", "b") {
t.Errorf("Missing contents: %#v", s)
}
s2.Insert("a", "b", "d")
if s.IsSuperset(s2) {
t.Errorf("Unexpected contents: %#v", s)
}
s2.Delete("d")
if !s.IsSuperset(s2) {
t.Errorf("Missing contents: %#v", s)
}
}
func TestStringSetDeleteMultiples(t *testing.T) {
s := String{}
s.Insert("a", "b", "c")
if len(s) != 3 {
t.Errorf("Expected len=3: %d", len(s))
}
s.Delete("a", "c")
if len(s) != 1 {
t.Errorf("Expected len=1: %d", len(s))
}
if s.Has("a") {
t.Errorf("Unexpected contents: %#v", s)
}
if s.Has("c") {
t.Errorf("Unexpected contents: %#v", s)
}
if !s.Has("b") {
t.Errorf("Missing contents: %#v", s)
}
}
func TestNewStringSet(t *testing.T) {
s := NewString("a", "b", "c")
if len(s) != 3 {
t.Errorf("Expected len=3: %d", len(s))
}
if !s.Has("a") || !s.Has("b") || !s.Has("c") {
t.Errorf("Unexpected contents: %#v", s)
}
}
func TestStringSetList(t *testing.T) {
s := NewString("z", "y", "x", "a")
if !reflect.DeepEqual(s.List(), []string{"a", "x", "y", "z"}) {
t.Errorf("List gave unexpected result: %#v", s.List())
}
}
func TestStringSetDifference(t *testing.T) {
a := NewString("1", "2", "3")
b := NewString("1", "2", "4", "5")
c := a.Difference(b)
d := b.Difference(a)
if len(c) != 1 {
t.Errorf("Expected len=1: %d", len(c))
}
if !c.Has("3") {
t.Errorf("Unexpected contents: %#v", c.List())
}
if len(d) != 2 {
t.Errorf("Expected len=2: %d", len(d))
}
if !d.Has("4") || !d.Has("5") {
t.Errorf("Unexpected contents: %#v", d.List())
}
}
func TestStringSetHasAny(t *testing.T) {
a := NewString("1", "2", "3")
if !a.HasAny("1", "4") {
t.Errorf("expected true, got false")
}
if a.HasAny("0", "4") {
t.Errorf("expected false, got true")
}
}
func TestStringSetEquals(t *testing.T) {
// Simple case (order doesn't matter)
a := NewString("1", "2")
b := NewString("2", "1")
if !a.Equal(b) {
t.Errorf("Expected to be equal: %v vs %v", a, b)
}
// It is a set; duplicates are ignored
b = NewString("2", "2", "1")
if !a.Equal(b) {
t.Errorf("Expected to be equal: %v vs %v", a, b)
}
// Edge cases around empty sets / empty strings
a = NewString()
b = NewString()
if !a.Equal(b) {
t.Errorf("Expected to be equal: %v vs %v", a, b)
}
b = NewString("1", "2", "3")
if a.Equal(b) {
t.Errorf("Expected to be not-equal: %v vs %v", a, b)
}
b = NewString("1", "2", "")
if a.Equal(b) {
t.Errorf("Expected to be not-equal: %v vs %v", a, b)
}
// Check for equality after mutation
a = NewString()
a.Insert("1")
if a.Equal(b) {
t.Errorf("Expected to be not-equal: %v vs %v", a, b)
}
a.Insert("2")
if a.Equal(b) {
t.Errorf("Expected to be not-equal: %v vs %v", a, b)
}
a.Insert("")
if !a.Equal(b) {
t.Errorf("Expected to be equal: %v vs %v", a, b)
}
a.Delete("")
if a.Equal(b) {
t.Errorf("Expected to be not-equal: %v vs %v", a, b)
}
}
func TestStringUnion(t *testing.T) {
tests := []struct {
s1 String
s2 String
expected String
}{
{
NewString("1", "2", "3", "4"),
NewString("3", "4", "5", "6"),
NewString("1", "2", "3", "4", "5", "6"),
},
{
NewString("1", "2", "3", "4"),
NewString(),
NewString("1", "2", "3", "4"),
},
{
NewString(),
NewString("1", "2", "3", "4"),
NewString("1", "2", "3", "4"),
},
{
NewString(),
NewString(),
NewString(),
},
}
for _, test := range tests {
union := test.s1.Union(test.s2)
if union.Len() != test.expected.Len() {
t.Errorf("Expected union.Len()=%d but got %d", test.expected.Len(), union.Len())
}
if !union.Equal(test.expected) {
t.Errorf("Expected union.Equal(expected) but not true. union:%v expected:%v", union.List(), test.expected.List())
}
}
}
func TestStringIntersection(t *testing.T) {
tests := []struct {
s1 String
s2 String
expected String
}{
{
NewString("1", "2", "3", "4"),
NewString("3", "4", "5", "6"),
NewString("3", "4"),
},
{
NewString("1", "2", "3", "4"),
NewString("1", "2", "3", "4"),
NewString("1", "2", "3", "4"),
},
{
NewString("1", "2", "3", "4"),
NewString(),
NewString(),
},
{
NewString(),
NewString("1", "2", "3", "4"),
NewString(),
},
{
NewString(),
NewString(),
NewString(),
},
}
for _, test := range tests {
intersection := test.s1.Intersection(test.s2)
if intersection.Len() != test.expected.Len() {
t.Errorf("Expected intersection.Len()=%d but got %d", test.expected.Len(), intersection.Len())
}
if !intersection.Equal(test.expected) {
t.Errorf("Expected intersection.Equal(expected) but not true. intersection:%v expected:%v", intersection.List(), test.expected.List())
}
}
}

View File

@ -1,32 +0,0 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package types just provides input types to the set generator. It also
// contains a "go generate" block.
// (You must first `go install k8s.io/code-generator/cmd/set-gen`)
package types
//go:generate set-gen -i k8s.io/kubernetes/pkg/util/sets/types
type ReferenceSetTypes struct {
// These types all cause files to be generated.
// These types should be reflected in the output of
// the "//pkg/util/sets:set-gen" genrule.
a int64
b int
c byte
d string
}

View File

@ -1,6 +0,0 @@
approvers:
- pwittrock
- mengqiy
reviewers:
- mengqiy
- apelisse

View File

@ -1,49 +0,0 @@
/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package strategicpatch
import (
"fmt"
)
type LookupPatchMetaError struct {
Path string
Err error
}
func (e LookupPatchMetaError) Error() string {
return fmt.Sprintf("LookupPatchMetaError(%s): %v", e.Path, e.Err)
}
type FieldNotFoundError struct {
Path string
Field string
}
func (e FieldNotFoundError) Error() string {
return fmt.Sprintf("unable to find api field %q in %s", e.Field, e.Path)
}
type InvalidTypeError struct {
Path string
Expected string
Actual string
}
func (e InvalidTypeError) Error() string {
return fmt.Sprintf("invalid type for %s: got %q, expected %q", e.Path, e.Actual, e.Expected)
}

View File

@ -1,194 +0,0 @@
/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package strategicpatch
import (
"errors"
"fmt"
"reflect"
"k8s.io/apimachinery/pkg/util/mergepatch"
forkedjson "k8s.io/apimachinery/third_party/forked/golang/json"
openapi "k8s.io/kube-openapi/pkg/util/proto"
)
type PatchMeta struct {
patchStrategies []string
patchMergeKey string
}
func (pm PatchMeta) GetPatchStrategies() []string {
if pm.patchStrategies == nil {
return []string{}
}
return pm.patchStrategies
}
func (pm PatchMeta) SetPatchStrategies(ps []string) {
pm.patchStrategies = ps
}
func (pm PatchMeta) GetPatchMergeKey() string {
return pm.patchMergeKey
}
func (pm PatchMeta) SetPatchMergeKey(pmk string) {
pm.patchMergeKey = pmk
}
type LookupPatchMeta interface {
// LookupPatchMetadataForStruct gets subschema and the patch metadata (e.g. patch strategy and merge key) for map.
LookupPatchMetadataForStruct(key string) (LookupPatchMeta, PatchMeta, error)
// LookupPatchMetadataForSlice get subschema and the patch metadata for slice.
LookupPatchMetadataForSlice(key string) (LookupPatchMeta, PatchMeta, error)
// Get the type name of the field
Name() string
}
type PatchMetaFromStruct struct {
T reflect.Type
}
func NewPatchMetaFromStruct(dataStruct interface{}) (PatchMetaFromStruct, error) {
t, err := getTagStructType(dataStruct)
return PatchMetaFromStruct{T: t}, err
}
var _ LookupPatchMeta = PatchMetaFromStruct{}
func (s PatchMetaFromStruct) LookupPatchMetadataForStruct(key string) (LookupPatchMeta, PatchMeta, error) {
fieldType, fieldPatchStrategies, fieldPatchMergeKey, err := forkedjson.LookupPatchMetadataForStruct(s.T, key)
if err != nil {
return nil, PatchMeta{}, err
}
return PatchMetaFromStruct{T: fieldType},
PatchMeta{
patchStrategies: fieldPatchStrategies,
patchMergeKey: fieldPatchMergeKey,
}, nil
}
func (s PatchMetaFromStruct) LookupPatchMetadataForSlice(key string) (LookupPatchMeta, PatchMeta, error) {
subschema, patchMeta, err := s.LookupPatchMetadataForStruct(key)
if err != nil {
return nil, PatchMeta{}, err
}
elemPatchMetaFromStruct := subschema.(PatchMetaFromStruct)
t := elemPatchMetaFromStruct.T
var elemType reflect.Type
switch t.Kind() {
// If t is an array or a slice, get the element type.
// If element is still an array or a slice, return an error.
// Otherwise, return element type.
case reflect.Array, reflect.Slice:
elemType = t.Elem()
if elemType.Kind() == reflect.Array || elemType.Kind() == reflect.Slice {
return nil, PatchMeta{}, errors.New("unexpected slice of slice")
}
// If t is an pointer, get the underlying element.
// If the underlying element is neither an array nor a slice, the pointer is pointing to a slice,
// e.g. https://github.com/kubernetes/kubernetes/blob/bc22e206c79282487ea0bf5696d5ccec7e839a76/staging/src/k8s.io/apimachinery/pkg/util/strategicpatch/patch_test.go#L2782-L2822
// If the underlying element is either an array or a slice, return its element type.
case reflect.Ptr:
t = t.Elem()
if t.Kind() == reflect.Array || t.Kind() == reflect.Slice {
t = t.Elem()
}
elemType = t
default:
return nil, PatchMeta{}, fmt.Errorf("expected slice or array type, but got: %s", s.T.Kind().String())
}
return PatchMetaFromStruct{T: elemType}, patchMeta, nil
}
func (s PatchMetaFromStruct) Name() string {
return s.T.Kind().String()
}
func getTagStructType(dataStruct interface{}) (reflect.Type, error) {
if dataStruct == nil {
return nil, mergepatch.ErrBadArgKind(struct{}{}, nil)
}
t := reflect.TypeOf(dataStruct)
// Get the underlying type for pointers
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
if t.Kind() != reflect.Struct {
return nil, mergepatch.ErrBadArgKind(struct{}{}, dataStruct)
}
return t, nil
}
func GetTagStructTypeOrDie(dataStruct interface{}) reflect.Type {
t, err := getTagStructType(dataStruct)
if err != nil {
panic(err)
}
return t
}
type PatchMetaFromOpenAPI struct {
Schema openapi.Schema
}
func NewPatchMetaFromOpenAPI(s openapi.Schema) PatchMetaFromOpenAPI {
return PatchMetaFromOpenAPI{Schema: s}
}
var _ LookupPatchMeta = PatchMetaFromOpenAPI{}
func (s PatchMetaFromOpenAPI) LookupPatchMetadataForStruct(key string) (LookupPatchMeta, PatchMeta, error) {
if s.Schema == nil {
return nil, PatchMeta{}, nil
}
kindItem := NewKindItem(key, s.Schema.GetPath())
s.Schema.Accept(kindItem)
err := kindItem.Error()
if err != nil {
return nil, PatchMeta{}, err
}
return PatchMetaFromOpenAPI{Schema: kindItem.subschema},
kindItem.patchmeta, nil
}
func (s PatchMetaFromOpenAPI) LookupPatchMetadataForSlice(key string) (LookupPatchMeta, PatchMeta, error) {
if s.Schema == nil {
return nil, PatchMeta{}, nil
}
sliceItem := NewSliceItem(key, s.Schema.GetPath())
s.Schema.Accept(sliceItem)
err := sliceItem.Error()
if err != nil {
return nil, PatchMeta{}, err
}
return PatchMetaFromOpenAPI{Schema: sliceItem.subschema},
sliceItem.patchmeta, nil
}
func (s PatchMetaFromOpenAPI) Name() string {
schema := s.Schema
return schema.GetName()
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,170 +0,0 @@
{
"swagger": "2.0",
"info": {
"title": "StrategicMergePatchTestingMergeItem",
"version": "v1.9.0"
},
"paths": {},
"definitions": {
"mergeItem": {
"description": "MergeItem is type definition for testing strategic merge.",
"required": [],
"properties": {
"name": {
"description": "Name field.",
"type": "string"
},
"value": {
"description": "Value field.",
"type": "string"
},
"other": {
"description": "Other field.",
"type": "string"
},
"mergingList": {
"description": "MergingList field.",
"type": "array",
"items": {
"$ref": "#/definitions/mergeItem"
},
"x-kubernetes-patch-merge-key": "name",
"x-kubernetes-patch-strategy": "merge"
},
"nonMergingList": {
"description": "NonMergingList field.",
"type": "array",
"items": {
"$ref": "#/definitions/mergeItem"
}
},
"mergingIntList": {
"description": "MergingIntList field.",
"type": "array",
"items": {
"type": "integer",
"format": "int32"
},
"x-kubernetes-patch-strategy": "merge"
},
"nonMergingIntList": {
"description": "NonMergingIntList field.",
"type": "array",
"items": {
"type": "integer",
"format": "int32"
}
},
"mergeItemPtr": {
"description": "MergeItemPtr field.",
"$ref": "#/definitions/mergeItem",
"x-kubernetes-patch-merge-key": "name",
"x-kubernetes-patch-strategy": "merge"
},
"simpleMap": {
"description": "SimpleMap field.",
"type": "object",
"additionalProperties": {
"type": "string"
}
},
"replacingItem": {
"description": "ReplacingItem field.",
"$ref": "#/definitions/io.k8s.apimachinery.pkg.runtime.RawExtension",
"x-kubernetes-patch-strategy": "replace"
},
"retainKeysMap": {
"description": "RetainKeysMap field.",
"$ref": "#/definitions/retainKeysMergeItem",
"x-kubernetes-patch-strategy": "retainKeys"
},
"retainKeysMergingList": {
"description": "RetainKeysMergingList field.",
"type": "array",
"items": {
"$ref": "#/definitions/mergeItem"
},
"x-kubernetes-patch-merge-key": "name",
"x-kubernetes-patch-strategy": "merge,retainKeys"
}
},
"x-kubernetes-group-version-kind": [
{
"group": "fake-group",
"kind": "mergeItem",
"version": "some-version"
}
]
},
"retainKeysMergeItem": {
"description": "RetainKeysMergeItem is type definition for testing strategic merge.",
"required": [],
"properties": {
"name": {
"description": "Name field.",
"type": "string"
},
"value": {
"description": "Value field.",
"type": "string"
},
"other": {
"description": "Other field.",
"type": "string"
},
"simpleMap": {
"description": "SimpleMap field.",
"additionalProperties": "object",
"items": {
"type": "string"
}
},
"mergingList": {
"description": "MergingList field.",
"type": "array",
"items": {
"$ref": "#/definitions/mergeItem"
},
"x-kubernetes-patch-merge-key": "name",
"x-kubernetes-patch-strategy": "merge"
},
"nonMergingList": {
"description": "NonMergingList field.",
"type": "array",
"items": {
"$ref": "#/definitions/mergeItem"
}
},
"mergingIntList": {
"description": "MergingIntList field.",
"type": "array",
"items": {
"type": "integer",
"format": "int32"
},
"x-kubernetes-patch-strategy": "merge"
}
},
"x-kubernetes-group-version-kind": [
{
"group": "fake-group",
"kind": "retainKeysMergeItem",
"version": "some-version"
}
]
},
"io.k8s.apimachinery.pkg.runtime.RawExtension": {
"description": "RawExtension is used to hold extensions in external versions.",
"required": [
"Raw"
],
"properties": {
"Raw": {
"description": "Raw is the underlying serialization of this object.",
"type": "string",
"format": "byte"
}
}
}
}
}

View File

@ -1,47 +0,0 @@
{
"swagger": "2.0",
"info": {
"title": "StrategicMergePatchTestingPrecisionItem",
"version": "v1.9.0"
},
"paths": {},
"definitions": {
"precisionItem": {
"description": "PrecisionItem is type definition for testing strategic merge.",
"required": [],
"properties": {
"name": {
"description": "Name field.",
"type": "string"
},
"int32": {
"description": "Int32 field.",
"type": "integer",
"format": "int32"
},
"int64": {
"description": "Int64 field.",
"type": "integer",
"format": "int64"
},
"float32": {
"description": "Float32 field.",
"type": "number",
"format": "float32"
},
"float64": {
"description": "Float64 field.",
"type": "number",
"format": "float64"
}
},
"x-kubernetes-group-version-kind": [
{
"group": "fake-group",
"kind": "precisionItem",
"version": "some-version"
}
]
}
}
}

View File

@ -1,84 +0,0 @@
/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package testing
import (
"io/ioutil"
"os"
"sync"
"gopkg.in/yaml.v2"
"github.com/googleapis/gnostic/OpenAPIv2"
"github.com/googleapis/gnostic/compiler"
openapi "k8s.io/kube-openapi/pkg/util/proto"
)
// Fake opens and returns a openapi swagger from a file Path. It will
// parse only once and then return the same copy everytime.
type Fake struct {
Path string
once sync.Once
document *openapi_v2.Document
err error
}
// OpenAPISchema returns the openapi document and a potential error.
func (f *Fake) OpenAPISchema() (*openapi_v2.Document, error) {
f.once.Do(func() {
_, err := os.Stat(f.Path)
if err != nil {
f.err = err
return
}
spec, err := ioutil.ReadFile(f.Path)
if err != nil {
f.err = err
return
}
var info yaml.MapSlice
err = yaml.Unmarshal(spec, &info)
if err != nil {
f.err = err
return
}
f.document, f.err = openapi_v2.NewDocument(info, compiler.NewContext("$root", nil))
})
return f.document, f.err
}
func getSchema(f Fake, model string) (openapi.Schema, error) {
s, err := f.OpenAPISchema()
if err != nil {
return nil, err
}
m, err := openapi.NewOpenAPIData(s)
if err != nil {
return nil, err
}
return m.LookupModel(model), nil
}
// GetSchemaOrDie returns the openapi schema.
func GetSchemaOrDie(f Fake, model string) openapi.Schema {
s, err := getSchema(f, model)
if err != nil {
panic(err)
}
return s
}

View File

@ -1,193 +0,0 @@
/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package strategicpatch
import (
"errors"
"strings"
"k8s.io/apimachinery/pkg/util/mergepatch"
openapi "k8s.io/kube-openapi/pkg/util/proto"
)
const (
patchStrategyOpenapiextensionKey = "x-kubernetes-patch-strategy"
patchMergeKeyOpenapiextensionKey = "x-kubernetes-patch-merge-key"
)
type LookupPatchItem interface {
openapi.SchemaVisitor
Error() error
Path() *openapi.Path
}
type kindItem struct {
key string
path *openapi.Path
err error
patchmeta PatchMeta
subschema openapi.Schema
hasVisitKind bool
}
func NewKindItem(key string, path *openapi.Path) *kindItem {
return &kindItem{
key: key,
path: path,
}
}
var _ LookupPatchItem = &kindItem{}
func (item *kindItem) Error() error {
return item.err
}
func (item *kindItem) Path() *openapi.Path {
return item.path
}
func (item *kindItem) VisitPrimitive(schema *openapi.Primitive) {
item.err = errors.New("expected kind, but got primitive")
}
func (item *kindItem) VisitArray(schema *openapi.Array) {
item.err = errors.New("expected kind, but got slice")
}
func (item *kindItem) VisitMap(schema *openapi.Map) {
item.err = errors.New("expected kind, but got map")
}
func (item *kindItem) VisitReference(schema openapi.Reference) {
if !item.hasVisitKind {
schema.SubSchema().Accept(item)
}
}
func (item *kindItem) VisitKind(schema *openapi.Kind) {
subschema, ok := schema.Fields[item.key]
if !ok {
item.err = FieldNotFoundError{Path: schema.GetPath().String(), Field: item.key}
return
}
mergeKey, patchStrategies, err := parsePatchMetadata(subschema.GetExtensions())
if err != nil {
item.err = err
return
}
item.patchmeta = PatchMeta{
patchStrategies: patchStrategies,
patchMergeKey: mergeKey,
}
item.subschema = subschema
}
type sliceItem struct {
key string
path *openapi.Path
err error
patchmeta PatchMeta
subschema openapi.Schema
hasVisitKind bool
}
func NewSliceItem(key string, path *openapi.Path) *sliceItem {
return &sliceItem{
key: key,
path: path,
}
}
var _ LookupPatchItem = &sliceItem{}
func (item *sliceItem) Error() error {
return item.err
}
func (item *sliceItem) Path() *openapi.Path {
return item.path
}
func (item *sliceItem) VisitPrimitive(schema *openapi.Primitive) {
item.err = errors.New("expected slice, but got primitive")
}
func (item *sliceItem) VisitArray(schema *openapi.Array) {
if !item.hasVisitKind {
item.err = errors.New("expected visit kind first, then visit array")
}
subschema := schema.SubType
item.subschema = subschema
}
func (item *sliceItem) VisitMap(schema *openapi.Map) {
item.err = errors.New("expected slice, but got map")
}
func (item *sliceItem) VisitReference(schema openapi.Reference) {
if !item.hasVisitKind {
schema.SubSchema().Accept(item)
} else {
item.subschema = schema.SubSchema()
}
}
func (item *sliceItem) VisitKind(schema *openapi.Kind) {
subschema, ok := schema.Fields[item.key]
if !ok {
item.err = FieldNotFoundError{Path: schema.GetPath().String(), Field: item.key}
return
}
mergeKey, patchStrategies, err := parsePatchMetadata(subschema.GetExtensions())
if err != nil {
item.err = err
return
}
item.patchmeta = PatchMeta{
patchStrategies: patchStrategies,
patchMergeKey: mergeKey,
}
item.hasVisitKind = true
subschema.Accept(item)
}
func parsePatchMetadata(extensions map[string]interface{}) (string, []string, error) {
ps, foundPS := extensions[patchStrategyOpenapiextensionKey]
var patchStrategies []string
var mergeKey, patchStrategy string
var ok bool
if foundPS {
patchStrategy, ok = ps.(string)
if ok {
patchStrategies = strings.Split(patchStrategy, ",")
} else {
return "", nil, mergepatch.ErrBadArgType(patchStrategy, ps)
}
}
mk, foundMK := extensions[patchMergeKeyOpenapiextensionKey]
if foundMK {
mergeKey, ok = mk.(string)
if !ok {
return "", nil, mergepatch.ErrBadArgType(mergeKey, mk)
}
}
return mergeKey, patchStrategies, nil
}

View File

@ -1,43 +0,0 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package uuid
import (
"sync"
"github.com/pborman/uuid"
"k8s.io/apimachinery/pkg/types"
)
var uuidLock sync.Mutex
var lastUUID uuid.UUID
func NewUUID() types.UID {
uuidLock.Lock()
defer uuidLock.Unlock()
result := uuid.NewUUID()
// The UUID package is naive and can generate identical UUIDs if the
// time interval is quick enough.
// The UUID uses 100 ns increments so it's short enough to actively
// wait for a new value.
for uuid.Equal(lastUUID, result) == true {
result = uuid.NewUUID()
}
lastUUID = result
return types.UID(result.String())
}

View File

@ -1,175 +0,0 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package field
import (
"fmt"
"strings"
"testing"
)
func TestMakeFuncs(t *testing.T) {
testCases := []struct {
fn func() *Error
expected ErrorType
}{
{
func() *Error { return Invalid(NewPath("f"), "v", "d") },
ErrorTypeInvalid,
},
{
func() *Error { return NotSupported(NewPath("f"), "v", nil) },
ErrorTypeNotSupported,
},
{
func() *Error { return Duplicate(NewPath("f"), "v") },
ErrorTypeDuplicate,
},
{
func() *Error { return NotFound(NewPath("f"), "v") },
ErrorTypeNotFound,
},
{
func() *Error { return Required(NewPath("f"), "d") },
ErrorTypeRequired,
},
{
func() *Error { return InternalError(NewPath("f"), fmt.Errorf("e")) },
ErrorTypeInternal,
},
}
for _, testCase := range testCases {
err := testCase.fn()
if err.Type != testCase.expected {
t.Errorf("expected Type %q, got %q", testCase.expected, err.Type)
}
}
}
func TestErrorUsefulMessage(t *testing.T) {
{
s := Invalid(nil, nil, "").Error()
t.Logf("message: %v", s)
if !strings.Contains(s, "null") {
t.Errorf("error message did not contain 'null': %s", s)
}
}
s := Invalid(NewPath("foo"), "bar", "deet").Error()
t.Logf("message: %v", s)
for _, part := range []string{"foo", "bar", "deet", ErrorTypeInvalid.String()} {
if !strings.Contains(s, part) {
t.Errorf("error message did not contain expected part '%v'", part)
}
}
type complicated struct {
Baz int
Qux string
Inner interface{}
KV map[string]int
}
s = Invalid(
NewPath("foo"),
&complicated{
Baz: 1,
Qux: "aoeu",
Inner: &complicated{Qux: "asdf"},
KV: map[string]int{"Billy": 2},
},
"detail",
).Error()
t.Logf("message: %v", s)
for _, part := range []string{
"foo", ErrorTypeInvalid.String(),
"Baz", "Qux", "Inner", "KV", "detail",
"1", "aoeu", "Billy", "2",
// "asdf", TODO: re-enable once we have a better nested printer
} {
if !strings.Contains(s, part) {
t.Errorf("error message did not contain expected part '%v'", part)
}
}
}
func TestToAggregate(t *testing.T) {
testCases := struct {
ErrList []ErrorList
NumExpectedErrs []int
}{
[]ErrorList{
nil,
{},
{Invalid(NewPath("f"), "v", "d")},
{Invalid(NewPath("f"), "v", "d"), Invalid(NewPath("f"), "v", "d")},
{Invalid(NewPath("f"), "v", "d"), InternalError(NewPath(""), fmt.Errorf("e"))},
},
[]int{
0,
0,
1,
1,
2,
},
}
if len(testCases.ErrList) != len(testCases.NumExpectedErrs) {
t.Errorf("Mismatch: length of NumExpectedErrs does not match length of ErrList")
}
for i, tc := range testCases.ErrList {
agg := tc.ToAggregate()
numErrs := 0
if agg != nil {
numErrs = len(agg.Errors())
}
if numErrs != testCases.NumExpectedErrs[i] {
t.Errorf("[%d] Expected %d, got %d", i, testCases.NumExpectedErrs[i], numErrs)
}
if len(tc) == 0 {
if agg != nil {
t.Errorf("[%d] Expected nil, got %#v", i, agg)
}
} else if agg == nil {
t.Errorf("[%d] Expected non-nil", i)
}
}
}
func TestErrListFilter(t *testing.T) {
list := ErrorList{
Invalid(NewPath("test.field"), "", ""),
Invalid(NewPath("field.test"), "", ""),
Duplicate(NewPath("test"), "value"),
}
if len(list.Filter(NewErrorTypeMatcher(ErrorTypeDuplicate))) != 2 {
t.Errorf("should not filter")
}
if len(list.Filter(NewErrorTypeMatcher(ErrorTypeInvalid))) != 1 {
t.Errorf("should filter")
}
}
func TestNotSupported(t *testing.T) {
notSupported := NotSupported(NewPath("f"), "v", []string{"a", "b", "c"})
expected := `Unsupported value: "v": supported values: "a", "b", "c"`
if notSupported.ErrorBody() != expected {
t.Errorf("Expected: %s\n, but got: %s\n", expected, notSupported.ErrorBody())
}
}

View File

@ -1,123 +0,0 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package field
import "testing"
func TestPath(t *testing.T) {
testCases := []struct {
op func(*Path) *Path
expected string
}{
{
func(p *Path) *Path { return p },
"root",
},
{
func(p *Path) *Path { return p.Child("first") },
"root.first",
},
{
func(p *Path) *Path { return p.Child("second") },
"root.first.second",
},
{
func(p *Path) *Path { return p.Index(0) },
"root.first.second[0]",
},
{
func(p *Path) *Path { return p.Child("third") },
"root.first.second[0].third",
},
{
func(p *Path) *Path { return p.Index(93) },
"root.first.second[0].third[93]",
},
{
func(p *Path) *Path { return p.parent },
"root.first.second[0].third",
},
{
func(p *Path) *Path { return p.parent },
"root.first.second[0]",
},
{
func(p *Path) *Path { return p.Key("key") },
"root.first.second[0][key]",
},
}
root := NewPath("root")
p := root
for i, tc := range testCases {
p = tc.op(p)
if p.String() != tc.expected {
t.Errorf("[%d] Expected %q, got %q", i, tc.expected, p.String())
}
if p.Root() != root {
t.Errorf("[%d] Wrong root: %#v", i, p.Root())
}
}
}
func TestPathMultiArg(t *testing.T) {
testCases := []struct {
op func(*Path) *Path
expected string
}{
{
func(p *Path) *Path { return p },
"root.first",
},
{
func(p *Path) *Path { return p.Child("second", "third") },
"root.first.second.third",
},
{
func(p *Path) *Path { return p.Index(0) },
"root.first.second.third[0]",
},
{
func(p *Path) *Path { return p.parent },
"root.first.second.third",
},
{
func(p *Path) *Path { return p.parent },
"root.first.second",
},
{
func(p *Path) *Path { return p.parent },
"root.first",
},
{
func(p *Path) *Path { return p.parent },
"root",
},
}
root := NewPath("root", "first")
p := root
for i, tc := range testCases {
p = tc.op(p)
if p.String() != tc.expected {
t.Errorf("[%d] Expected %q, got %q", i, tc.expected, p.String())
}
if p.Root() != root.Root() {
t.Errorf("[%d] Wrong root: %#v", i, p.Root())
}
}
}

View File

@ -1,541 +0,0 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package validation
import (
"strings"
"testing"
"k8s.io/apimachinery/pkg/util/validation/field"
)
func TestIsDNS1123Label(t *testing.T) {
goodValues := []string{
"a", "ab", "abc", "a1", "a-1", "a--1--2--b",
"0", "01", "012", "1a", "1-a", "1--a--b--2",
strings.Repeat("a", 63),
}
for _, val := range goodValues {
if msgs := IsDNS1123Label(val); len(msgs) != 0 {
t.Errorf("expected true for '%s': %v", val, msgs)
}
}
badValues := []string{
"", "A", "ABC", "aBc", "A1", "A-1", "1-A",
"-", "a-", "-a", "1-", "-1",
"_", "a_", "_a", "a_b", "1_", "_1", "1_2",
".", "a.", ".a", "a.b", "1.", ".1", "1.2",
" ", "a ", " a", "a b", "1 ", " 1", "1 2",
strings.Repeat("a", 64),
}
for _, val := range badValues {
if msgs := IsDNS1123Label(val); len(msgs) == 0 {
t.Errorf("expected false for '%s'", val)
}
}
}
func TestIsDNS1123Subdomain(t *testing.T) {
goodValues := []string{
"a", "ab", "abc", "a1", "a-1", "a--1--2--b",
"0", "01", "012", "1a", "1-a", "1--a--b--2",
"a.a", "ab.a", "abc.a", "a1.a", "a-1.a", "a--1--2--b.a",
"a.1", "ab.1", "abc.1", "a1.1", "a-1.1", "a--1--2--b.1",
"0.a", "01.a", "012.a", "1a.a", "1-a.a", "1--a--b--2",
"0.1", "01.1", "012.1", "1a.1", "1-a.1", "1--a--b--2.1",
"a.b.c.d.e", "aa.bb.cc.dd.ee", "1.2.3.4.5", "11.22.33.44.55",
strings.Repeat("a", 253),
}
for _, val := range goodValues {
if msgs := IsDNS1123Subdomain(val); len(msgs) != 0 {
t.Errorf("expected true for '%s': %v", val, msgs)
}
}
badValues := []string{
"", "A", "ABC", "aBc", "A1", "A-1", "1-A",
"-", "a-", "-a", "1-", "-1",
"_", "a_", "_a", "a_b", "1_", "_1", "1_2",
".", "a.", ".a", "a..b", "1.", ".1", "1..2",
" ", "a ", " a", "a b", "1 ", " 1", "1 2",
"A.a", "aB.a", "ab.A", "A1.a", "a1.A",
"A.1", "aB.1", "A1.1", "1A.1",
"0.A", "01.A", "012.A", "1A.a", "1a.A",
"A.B.C.D.E", "AA.BB.CC.DD.EE", "a.B.c.d.e", "aa.bB.cc.dd.ee",
"a@b", "a,b", "a_b", "a;b",
"a:b", "a%b", "a?b", "a$b",
strings.Repeat("a", 254),
}
for _, val := range badValues {
if msgs := IsDNS1123Subdomain(val); len(msgs) == 0 {
t.Errorf("expected false for '%s'", val)
}
}
}
func TestIsDNS1035Label(t *testing.T) {
goodValues := []string{
"a", "ab", "abc", "a1", "a-1", "a--1--2--b",
strings.Repeat("a", 63),
}
for _, val := range goodValues {
if msgs := IsDNS1035Label(val); len(msgs) != 0 {
t.Errorf("expected true for '%s': %v", val, msgs)
}
}
badValues := []string{
"0", "01", "012", "1a", "1-a", "1--a--b--2",
"", "A", "ABC", "aBc", "A1", "A-1", "1-A",
"-", "a-", "-a", "1-", "-1",
"_", "a_", "_a", "a_b", "1_", "_1", "1_2",
".", "a.", ".a", "a.b", "1.", ".1", "1.2",
" ", "a ", " a", "a b", "1 ", " 1", "1 2",
strings.Repeat("a", 64),
}
for _, val := range badValues {
if msgs := IsDNS1035Label(val); len(msgs) == 0 {
t.Errorf("expected false for '%s'", val)
}
}
}
func TestIsCIdentifier(t *testing.T) {
goodValues := []string{
"a", "ab", "abc", "a1", "_a", "a_", "a_b", "a_1", "a__1__2__b", "__abc_123",
"A", "AB", "AbC", "A1", "_A", "A_", "A_B", "A_1", "A__1__2__B", "__123_ABC",
}
for _, val := range goodValues {
if msgs := IsCIdentifier(val); len(msgs) != 0 {
t.Errorf("expected true for '%s': %v", val, msgs)
}
}
badValues := []string{
"", "1", "123", "1a",
"-", "a-", "-a", "1-", "-1", "1_", "1_2",
".", "a.", ".a", "a.b", "1.", ".1", "1.2",
" ", "a ", " a", "a b", "1 ", " 1", "1 2",
"#a#",
}
for _, val := range badValues {
if msgs := IsCIdentifier(val); len(msgs) == 0 {
t.Errorf("expected false for '%s'", val)
}
}
}
func TestIsValidPortNum(t *testing.T) {
goodValues := []int{1, 2, 1000, 16384, 32768, 65535}
for _, val := range goodValues {
if msgs := IsValidPortNum(val); len(msgs) != 0 {
t.Errorf("expected true for %d, got %v", val, msgs)
}
}
badValues := []int{0, -1, 65536, 100000}
for _, val := range badValues {
if msgs := IsValidPortNum(val); len(msgs) == 0 {
t.Errorf("expected false for %d", val)
}
}
}
func TestIsInRange(t *testing.T) {
goodValues := []struct {
value int
min int
max int
}{{1, 0, 10}, {5, 5, 20}, {25, 10, 25}}
for _, val := range goodValues {
if msgs := IsInRange(val.value, val.min, val.max); len(msgs) > 0 {
t.Errorf("expected no errors for %#v, but got %v", val, msgs)
}
}
badValues := []struct {
value int
min int
max int
}{{1, 2, 10}, {5, -4, 2}, {25, 100, 120}}
for _, val := range badValues {
if msgs := IsInRange(val.value, val.min, val.max); len(msgs) == 0 {
t.Errorf("expected errors for %#v", val)
}
}
}
func createGroupIDs(ids ...int64) []int64 {
var output []int64
for _, id := range ids {
output = append(output, int64(id))
}
return output
}
func createUserIDs(ids ...int64) []int64 {
var output []int64
for _, id := range ids {
output = append(output, int64(id))
}
return output
}
func TestIsValidGroupID(t *testing.T) {
goodValues := createGroupIDs(0, 1, 1000, 65535, 2147483647)
for _, val := range goodValues {
if msgs := IsValidGroupID(val); len(msgs) != 0 {
t.Errorf("expected true for '%d': %v", val, msgs)
}
}
badValues := createGroupIDs(-1, -1003, 2147483648, 4147483647)
for _, val := range badValues {
if msgs := IsValidGroupID(val); len(msgs) == 0 {
t.Errorf("expected false for '%d'", val)
}
}
}
func TestIsValidUserID(t *testing.T) {
goodValues := createUserIDs(0, 1, 1000, 65535, 2147483647)
for _, val := range goodValues {
if msgs := IsValidUserID(val); len(msgs) != 0 {
t.Errorf("expected true for '%d': %v", val, msgs)
}
}
badValues := createUserIDs(-1, -1003, 2147483648, 4147483647)
for _, val := range badValues {
if msgs := IsValidUserID(val); len(msgs) == 0 {
t.Errorf("expected false for '%d'", val)
}
}
}
func TestIsValidPortName(t *testing.T) {
goodValues := []string{"telnet", "re-mail-ck", "pop3", "a", "a-1", "1-a", "a-1-b-2-c", "1-a-2-b-3"}
for _, val := range goodValues {
if msgs := IsValidPortName(val); len(msgs) != 0 {
t.Errorf("expected true for %q: %v", val, msgs)
}
}
badValues := []string{"longerthan15characters", "", strings.Repeat("a", 16), "12345", "1-2-3-4", "-begin", "end-", "two--hyphens", "whois++"}
for _, val := range badValues {
if msgs := IsValidPortName(val); len(msgs) == 0 {
t.Errorf("expected false for %q", val)
}
}
}
func TestIsQualifiedName(t *testing.T) {
successCases := []string{
"simple",
"now-with-dashes",
"1-starts-with-num",
"1234",
"simple/simple",
"now-with-dashes/simple",
"now-with-dashes/now-with-dashes",
"now.with.dots/simple",
"now-with.dashes-and.dots/simple",
"1-num.2-num/3-num",
"1234/5678",
"1.2.3.4/5678",
"Uppercase_Is_OK_123",
"example.com/Uppercase_Is_OK_123",
"requests.storage-foo",
strings.Repeat("a", 63),
strings.Repeat("a", 253) + "/" + strings.Repeat("b", 63),
}
for i := range successCases {
if errs := IsQualifiedName(successCases[i]); len(errs) != 0 {
t.Errorf("case[%d]: %q: expected success: %v", i, successCases[i], errs)
}
}
errorCases := []string{
"nospecialchars%^=@",
"cantendwithadash-",
"-cantstartwithadash-",
"only/one/slash",
"Example.com/abc",
"example_com/abc",
"example.com/",
"/simple",
strings.Repeat("a", 64),
strings.Repeat("a", 254) + "/abc",
}
for i := range errorCases {
if errs := IsQualifiedName(errorCases[i]); len(errs) == 0 {
t.Errorf("case[%d]: %q: expected failure", i, errorCases[i])
}
}
}
func TestIsValidLabelValue(t *testing.T) {
successCases := []string{
"simple",
"now-with-dashes",
"1-starts-with-num",
"end-with-num-1",
"1234", // only num
strings.Repeat("a", 63), // to the limit
"", // empty value
}
for i := range successCases {
if errs := IsValidLabelValue(successCases[i]); len(errs) != 0 {
t.Errorf("case %s expected success: %v", successCases[i], errs)
}
}
errorCases := []string{
"nospecialchars%^=@",
"Tama-nui-te-rā.is.Māori.sun",
"\\backslashes\\are\\bad",
"-starts-with-dash",
"ends-with-dash-",
".starts.with.dot",
"ends.with.dot.",
strings.Repeat("a", 64), // over the limit
}
for i := range errorCases {
if errs := IsValidLabelValue(errorCases[i]); len(errs) == 0 {
t.Errorf("case[%d] expected failure", i)
}
}
}
func TestIsValidIP(t *testing.T) {
goodValues := []string{
"::1",
"2a00:79e0:2:0:f1c3:e797:93c1:df80",
"::",
"2001:4860:4860::8888",
"::fff:1.1.1.1",
"1.1.1.1",
"1.1.1.01",
"255.0.0.1",
"1.0.0.0",
"0.0.0.0",
}
for _, val := range goodValues {
if msgs := IsValidIP(val); len(msgs) != 0 {
t.Errorf("expected true for %q: %v", val, msgs)
}
}
badValues := []string{
"[2001:db8:0:1]:80",
"myhost.mydomain",
"-1.0.0.0",
"[2001:db8:0:1]",
"a",
}
for _, val := range badValues {
if msgs := IsValidIP(val); len(msgs) == 0 {
t.Errorf("expected false for %q", val)
}
}
}
func TestIsHTTPHeaderName(t *testing.T) {
goodValues := []string{
// Common ones
"Accept-Encoding", "Host", "If-Modified-Since", "X-Forwarded-For",
// Weirdo, but still conforming names
"a", "ab", "abc", "a1", "-a", "a-", "a-b", "a-1", "a--1--2--b", "--abc-123",
"A", "AB", "AbC", "A1", "-A", "A-", "A-B", "A-1", "A--1--2--B", "--123-ABC",
}
for _, val := range goodValues {
if msgs := IsHTTPHeaderName(val); len(msgs) != 0 {
t.Errorf("expected true for '%s': %v", val, msgs)
}
}
badValues := []string{
"Host:", "X-Forwarded-For:", "X-@Home",
"", "_", "a_", "_a", "1_", "1_2", ".", "a.", ".a", "a.b", "1.", ".1", "1.2",
" ", "a ", " a", "a b", "1 ", " 1", "1 2", "#a#", "^", ",", ";", "=", "<",
"?", "@", "{",
}
for _, val := range badValues {
if msgs := IsHTTPHeaderName(val); len(msgs) == 0 {
t.Errorf("expected false for '%s'", val)
}
}
}
func TestIsValidPercent(t *testing.T) {
goodValues := []string{
"0%",
"00000%",
"1%",
"01%",
"99%",
"100%",
"101%",
}
for _, val := range goodValues {
if msgs := IsValidPercent(val); len(msgs) != 0 {
t.Errorf("expected true for %q: %v", val, msgs)
}
}
badValues := []string{
"",
"0",
"100",
"0.0%",
"99.9%",
"hundred",
" 1%",
"1% ",
"-0%",
"-1%",
"+1%",
}
for _, val := range badValues {
if msgs := IsValidPercent(val); len(msgs) == 0 {
t.Errorf("expected false for %q", val)
}
}
}
func TestIsConfigMapKey(t *testing.T) {
successCases := []string{
"a",
"good",
"good-good",
"still.good",
"this.is.also.good",
".so.is.this",
"THIS_IS_GOOD",
"so_is_this_17",
}
for i := range successCases {
if errs := IsConfigMapKey(successCases[i]); len(errs) != 0 {
t.Errorf("[%d] expected success: %v", i, errs)
}
}
failureCases := []string{
".",
"..",
"..bad",
"b*d",
"bad!&bad",
}
for i := range failureCases {
if errs := IsConfigMapKey(failureCases[i]); len(errs) == 0 {
t.Errorf("[%d] expected failure", i)
}
}
}
func TestIsWildcardDNS1123Subdomain(t *testing.T) {
goodValues := []string{
"*.example.com",
"*.bar.com",
"*.foo.bar.com",
}
for _, val := range goodValues {
if errs := IsWildcardDNS1123Subdomain(val); len(errs) != 0 {
t.Errorf("expected no errors for %q: %v", val, errs)
}
}
badValues := []string{
"*.*.bar.com",
"*.foo.*.com",
"*bar.com",
"f*.bar.com",
"*",
}
for _, val := range badValues {
if errs := IsWildcardDNS1123Subdomain(val); len(errs) == 0 {
t.Errorf("expected errors for %q", val)
}
}
}
func TestIsFullyQualifiedName(t *testing.T) {
tests := []struct {
name string
targetName string
err string
}{
{
name: "name needs to be fully qualified, i.e., contains at least 2 dots",
targetName: "k8s.io",
err: "should be a domain with at least three segments separated by dots",
},
{
name: "name cannot be empty",
targetName: "",
err: "Required value",
},
{
name: "name must conform to RFC 1123",
targetName: "A.B.C",
err: "a DNS-1123 subdomain must consist of lower case alphanumeric characters",
},
}
for _, tc := range tests {
err := IsFullyQualifiedName(field.NewPath(""), tc.targetName).ToAggregate()
switch {
case tc.err == "" && err != nil:
t.Errorf("%q: unexpected error: %v", tc.name, err)
case tc.err != "" && err == nil:
t.Errorf("%q: unexpected no error, expected %s", tc.name, tc.err)
case tc.err != "" && err != nil && !strings.Contains(err.Error(), tc.err):
t.Errorf("%q: expected %s, got %v", tc.name, tc.err, err)
}
}
}
func TestIsValidSocketAddr(t *testing.T) {
goodValues := []string{
"0.0.0.0:10254",
"127.0.0.1:8888",
"[2001:db8:1f70::999:de8:7648:6e8]:10254",
"[::]:10254",
}
for _, val := range goodValues {
if errs := IsValidSocketAddr(val); len(errs) != 0 {
t.Errorf("expected no errors for %q: %v", val, errs)
}
}
badValues := []string{
"0.0.0.0.0:2020",
"0.0.0.0",
"6.6.6.6:909090",
"2001:db8:1f70::999:de8:7648:6e8:87567:102545",
"",
"*",
}
for _, val := range badValues {
if errs := IsValidSocketAddr(val); len(errs) == 0 {
t.Errorf("expected errors for %q", val)
}
}
}

View File

@ -1,18 +0,0 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package version provides utilities for version number comparisons
package version // import "k8s.io/apimachinery/pkg/util/version"

View File

@ -1,285 +0,0 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package version
import (
"bytes"
"fmt"
"regexp"
"strconv"
"strings"
)
// Version is an opqaue representation of a version number
type Version struct {
components []uint
semver bool
preRelease string
buildMetadata string
}
var (
// versionMatchRE splits a version string into numeric and "extra" parts
versionMatchRE = regexp.MustCompile(`^\s*v?([0-9]+(?:\.[0-9]+)*)(.*)*$`)
// extraMatchRE splits the "extra" part of versionMatchRE into semver pre-release and build metadata; it does not validate the "no leading zeroes" constraint for pre-release
extraMatchRE = regexp.MustCompile(`^(?:-([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?\s*$`)
)
func parse(str string, semver bool) (*Version, error) {
parts := versionMatchRE.FindStringSubmatch(str)
if parts == nil {
return nil, fmt.Errorf("could not parse %q as version", str)
}
numbers, extra := parts[1], parts[2]
components := strings.Split(numbers, ".")
if (semver && len(components) != 3) || (!semver && len(components) < 2) {
return nil, fmt.Errorf("illegal version string %q", str)
}
v := &Version{
components: make([]uint, len(components)),
semver: semver,
}
for i, comp := range components {
if (i == 0 || semver) && strings.HasPrefix(comp, "0") && comp != "0" {
return nil, fmt.Errorf("illegal zero-prefixed version component %q in %q", comp, str)
}
num, err := strconv.ParseUint(comp, 10, 0)
if err != nil {
return nil, fmt.Errorf("illegal non-numeric version component %q in %q: %v", comp, str, err)
}
v.components[i] = uint(num)
}
if semver && extra != "" {
extraParts := extraMatchRE.FindStringSubmatch(extra)
if extraParts == nil {
return nil, fmt.Errorf("could not parse pre-release/metadata (%s) in version %q", extra, str)
}
v.preRelease, v.buildMetadata = extraParts[1], extraParts[2]
for _, comp := range strings.Split(v.preRelease, ".") {
if _, err := strconv.ParseUint(comp, 10, 0); err == nil {
if strings.HasPrefix(comp, "0") && comp != "0" {
return nil, fmt.Errorf("illegal zero-prefixed version component %q in %q", comp, str)
}
}
}
}
return v, nil
}
// ParseGeneric parses a "generic" version string. The version string must consist of two
// or more dot-separated numeric fields (the first of which can't have leading zeroes),
// followed by arbitrary uninterpreted data (which need not be separated from the final
// numeric field by punctuation). For convenience, leading and trailing whitespace is
// ignored, and the version can be preceded by the letter "v". See also ParseSemantic.
func ParseGeneric(str string) (*Version, error) {
return parse(str, false)
}
// MustParseGeneric is like ParseGeneric except that it panics on error
func MustParseGeneric(str string) *Version {
v, err := ParseGeneric(str)
if err != nil {
panic(err)
}
return v
}
// ParseSemantic parses a version string that exactly obeys the syntax and semantics of
// the "Semantic Versioning" specification (http://semver.org/) (although it ignores
// leading and trailing whitespace, and allows the version to be preceded by "v"). For
// version strings that are not guaranteed to obey the Semantic Versioning syntax, use
// ParseGeneric.
func ParseSemantic(str string) (*Version, error) {
return parse(str, true)
}
// MustParseSemantic is like ParseSemantic except that it panics on error
func MustParseSemantic(str string) *Version {
v, err := ParseSemantic(str)
if err != nil {
panic(err)
}
return v
}
// Major returns the major release number
func (v *Version) Major() uint {
return v.components[0]
}
// Minor returns the minor release number
func (v *Version) Minor() uint {
return v.components[1]
}
// Patch returns the patch release number if v is a Semantic Version, or 0
func (v *Version) Patch() uint {
if len(v.components) < 3 {
return 0
}
return v.components[2]
}
// BuildMetadata returns the build metadata, if v is a Semantic Version, or ""
func (v *Version) BuildMetadata() string {
return v.buildMetadata
}
// PreRelease returns the prerelease metadata, if v is a Semantic Version, or ""
func (v *Version) PreRelease() string {
return v.preRelease
}
// Components returns the version number components
func (v *Version) Components() []uint {
return v.components
}
// String converts a Version back to a string; note that for versions parsed with
// ParseGeneric, this will not include the trailing uninterpreted portion of the version
// number.
func (v *Version) String() string {
var buffer bytes.Buffer
for i, comp := range v.components {
if i > 0 {
buffer.WriteString(".")
}
buffer.WriteString(fmt.Sprintf("%d", comp))
}
if v.preRelease != "" {
buffer.WriteString("-")
buffer.WriteString(v.preRelease)
}
if v.buildMetadata != "" {
buffer.WriteString("+")
buffer.WriteString(v.buildMetadata)
}
return buffer.String()
}
// compareInternal returns -1 if v is less than other, 1 if it is greater than other, or 0
// if they are equal
func (v *Version) compareInternal(other *Version) int {
vLen := len(v.components)
oLen := len(other.components)
for i := 0; i < vLen && i < oLen; i++ {
switch {
case other.components[i] < v.components[i]:
return 1
case other.components[i] > v.components[i]:
return -1
}
}
// If components are common but one has more items and they are not zeros, it is bigger
switch {
case oLen < vLen && !onlyZeros(v.components[oLen:]):
return 1
case oLen > vLen && !onlyZeros(other.components[vLen:]):
return -1
}
if !v.semver || !other.semver {
return 0
}
switch {
case v.preRelease == "" && other.preRelease != "":
return 1
case v.preRelease != "" && other.preRelease == "":
return -1
case v.preRelease == other.preRelease: // includes case where both are ""
return 0
}
vPR := strings.Split(v.preRelease, ".")
oPR := strings.Split(other.preRelease, ".")
for i := 0; i < len(vPR) && i < len(oPR); i++ {
vNum, err := strconv.ParseUint(vPR[i], 10, 0)
if err == nil {
oNum, err := strconv.ParseUint(oPR[i], 10, 0)
if err == nil {
switch {
case oNum < vNum:
return 1
case oNum > vNum:
return -1
default:
continue
}
}
}
if oPR[i] < vPR[i] {
return 1
} else if oPR[i] > vPR[i] {
return -1
}
}
switch {
case len(oPR) < len(vPR):
return 1
case len(oPR) > len(vPR):
return -1
}
return 0
}
// returns false if array contain any non-zero element
func onlyZeros(array []uint) bool {
for _, num := range array {
if num != 0 {
return false
}
}
return true
}
// AtLeast tests if a version is at least equal to a given minimum version. If both
// Versions are Semantic Versions, this will use the Semantic Version comparison
// algorithm. Otherwise, it will compare only the numeric components, with non-present
// components being considered "0" (ie, "1.4" is equal to "1.4.0").
func (v *Version) AtLeast(min *Version) bool {
return v.compareInternal(min) != -1
}
// LessThan tests if a version is less than a given version. (It is exactly the opposite
// of AtLeast, for situations where asking "is v too old?" makes more sense than asking
// "is v new enough?".)
func (v *Version) LessThan(other *Version) bool {
return v.compareInternal(other) == -1
}
// Compare compares v against a version string (which will be parsed as either Semantic
// or non-Semantic depending on v). On success it returns -1 if v is less than other, 1 if
// it is greater than other, or 0 if they are equal.
func (v *Version) Compare(other string) (int, error) {
ov, err := parse(other, v.semver)
if err != nil {
return 0, err
}
return v.compareInternal(ov), nil
}

View File

@ -1,348 +0,0 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package version
import (
"fmt"
"reflect"
"testing"
)
type testItem struct {
version string
unparsed string
equalsPrev bool
}
func testOne(v *Version, item, prev testItem) error {
str := v.String()
if item.unparsed == "" {
if str != item.version {
return fmt.Errorf("bad round-trip: %q -> %q", item.version, str)
}
} else {
if str != item.unparsed {
return fmt.Errorf("bad unparse: %q -> %q, expected %q", item.version, str, item.unparsed)
}
}
if prev.version != "" {
cmp, err := v.Compare(prev.version)
if err != nil {
return fmt.Errorf("unexpected parse error: %v", err)
}
rv, err := parse(prev.version, v.semver)
if err != nil {
return fmt.Errorf("unexpected parse error: %v", err)
}
rcmp, err := rv.Compare(item.version)
if err != nil {
return fmt.Errorf("unexpected parse error: %v", err)
}
switch {
case cmp == -1:
return fmt.Errorf("unexpected ordering %q < %q", item.version, prev.version)
case cmp == 0 && !item.equalsPrev:
return fmt.Errorf("unexpected comparison %q == %q", item.version, prev.version)
case cmp == 1 && item.equalsPrev:
return fmt.Errorf("unexpected comparison %q != %q", item.version, prev.version)
case cmp != -rcmp:
return fmt.Errorf("unexpected reverse comparison %q <=> %q %v %v %v %v", item.version, prev.version, cmp, rcmp, v.Components(), rv.Components())
}
}
return nil
}
func TestSemanticVersions(t *testing.T) {
tests := []testItem{
// This is every version string that appears in the 2.0 semver spec,
// sorted in strictly increasing order except as noted.
{version: "0.1.0"},
{version: "1.0.0-0.3.7"},
{version: "1.0.0-alpha"},
{version: "1.0.0-alpha+001", equalsPrev: true},
{version: "1.0.0-alpha.1"},
{version: "1.0.0-alpha.beta"},
{version: "1.0.0-beta"},
{version: "1.0.0-beta+exp.sha.5114f85", equalsPrev: true},
{version: "1.0.0-beta.2"},
{version: "1.0.0-beta.11"},
{version: "1.0.0-rc.1"},
{version: "1.0.0-x.7.z.92"},
{version: "1.0.0"},
{version: "1.0.0+20130313144700", equalsPrev: true},
{version: "1.8.0-alpha.3"},
{version: "1.8.0-alpha.3.673+73326ef01d2d7c"},
{version: "1.9.0"},
{version: "1.10.0"},
{version: "1.11.0"},
{version: "2.0.0"},
{version: "2.1.0"},
{version: "2.1.1"},
{version: "42.0.0"},
// We also allow whitespace and "v" prefix
{version: " 42.0.0", unparsed: "42.0.0", equalsPrev: true},
{version: "\t42.0.0 ", unparsed: "42.0.0", equalsPrev: true},
{version: "43.0.0-1", unparsed: "43.0.0-1"},
{version: "43.0.0-1 ", unparsed: "43.0.0-1", equalsPrev: true},
{version: "v43.0.0-1", unparsed: "43.0.0-1", equalsPrev: true},
{version: " v43.0.0", unparsed: "43.0.0"},
{version: " 43.0.0 ", unparsed: "43.0.0", equalsPrev: true},
}
var prev testItem
for _, item := range tests {
v, err := ParseSemantic(item.version)
if err != nil {
t.Errorf("unexpected parse error: %v", err)
continue
}
err = testOne(v, item, prev)
if err != nil {
t.Errorf("%v", err)
}
prev = item
}
}
func TestBadSemanticVersions(t *testing.T) {
tests := []string{
// "MUST take the form X.Y.Z"
"1",
"1.2",
"1.2.3.4",
".2.3",
"1..3",
"1.2.",
"",
"..",
// "where X, Y, and Z are non-negative integers"
"-1.2.3",
"1.-2.3",
"1.2.-3",
"1a.2.3",
"1.2a.3",
"1.2.3a",
"a1.2.3",
"a.b.c",
"1 .2.3",
"1. 2.3",
// "and MUST NOT contain leading zeroes."
"01.2.3",
"1.02.3",
"1.2.03",
// "[pre-release] identifiers MUST comprise only ASCII alphanumerics and hyphen"
"1.2.3-/",
// "[pre-release] identifiers MUST NOT be empty"
"1.2.3-",
"1.2.3-.",
"1.2.3-foo.",
"1.2.3-.foo",
// "Numeric [pre-release] identifiers MUST NOT include leading zeroes"
"1.2.3-01",
// "[build metadata] identifiers MUST comprise only ASCII alphanumerics and hyphen"
"1.2.3+/",
// "[build metadata] identifiers MUST NOT be empty"
"1.2.3+",
"1.2.3+.",
"1.2.3+foo.",
"1.2.3+.foo",
// whitespace/"v"-prefix checks
"v 1.2.3",
"vv1.2.3",
}
for i := range tests {
_, err := ParseSemantic(tests[i])
if err == nil {
t.Errorf("unexpected success parsing invalid semver %q", tests[i])
}
}
}
func TestGenericVersions(t *testing.T) {
tests := []testItem{
// This is all of the strings from TestSemanticVersions, plus some strings
// from TestBadSemanticVersions that should parse as generic versions,
// plus some additional strings.
{version: "0.1.0", unparsed: "0.1.0"},
{version: "1.0.0-0.3.7", unparsed: "1.0.0"},
{version: "1.0.0-alpha", unparsed: "1.0.0", equalsPrev: true},
{version: "1.0.0-alpha+001", unparsed: "1.0.0", equalsPrev: true},
{version: "1.0.0-alpha.1", unparsed: "1.0.0", equalsPrev: true},
{version: "1.0.0-alpha.beta", unparsed: "1.0.0", equalsPrev: true},
{version: "1.0.0.beta", unparsed: "1.0.0", equalsPrev: true},
{version: "1.0.0-beta+exp.sha.5114f85", unparsed: "1.0.0", equalsPrev: true},
{version: "1.0.0.beta.2", unparsed: "1.0.0", equalsPrev: true},
{version: "1.0.0.beta.11", unparsed: "1.0.0", equalsPrev: true},
{version: "1.0.0.rc.1", unparsed: "1.0.0", equalsPrev: true},
{version: "1.0.0-x.7.z.92", unparsed: "1.0.0", equalsPrev: true},
{version: "1.0.0", unparsed: "1.0.0", equalsPrev: true},
{version: "1.0.0+20130313144700", unparsed: "1.0.0", equalsPrev: true},
{version: "1.2", unparsed: "1.2"},
{version: "1.2a.3", unparsed: "1.2", equalsPrev: true},
{version: "1.2.3", unparsed: "1.2.3"},
{version: "1.2.3.0", unparsed: "1.2.3.0", equalsPrev: true},
{version: "1.2.3a", unparsed: "1.2.3", equalsPrev: true},
{version: "1.2.3-foo.", unparsed: "1.2.3", equalsPrev: true},
{version: "1.2.3-.foo", unparsed: "1.2.3", equalsPrev: true},
{version: "1.2.3-01", unparsed: "1.2.3", equalsPrev: true},
{version: "1.2.3+", unparsed: "1.2.3", equalsPrev: true},
{version: "1.2.3+foo.", unparsed: "1.2.3", equalsPrev: true},
{version: "1.2.3+.foo", unparsed: "1.2.3", equalsPrev: true},
{version: "1.02.3", unparsed: "1.2.3", equalsPrev: true},
{version: "1.2.03", unparsed: "1.2.3", equalsPrev: true},
{version: "1.2.003", unparsed: "1.2.3", equalsPrev: true},
{version: "1.2.3.4", unparsed: "1.2.3.4"},
{version: "1.2.3.4b3", unparsed: "1.2.3.4", equalsPrev: true},
{version: "1.2.3.4.5", unparsed: "1.2.3.4.5"},
{version: "1.9.0", unparsed: "1.9.0"},
{version: "1.9.0.0.0.0.0.0", unparsed: "1.9.0.0.0.0.0.0", equalsPrev: true},
{version: "1.10.0", unparsed: "1.10.0"},
{version: "1.11.0", unparsed: "1.11.0"},
{version: "1.11.0.0.5", unparsed: "1.11.0.0.5"},
{version: "2.0.0", unparsed: "2.0.0"},
{version: "2.1.0", unparsed: "2.1.0"},
{version: "2.1.1", unparsed: "2.1.1"},
{version: "42.0.0", unparsed: "42.0.0"},
{version: " 42.0.0", unparsed: "42.0.0", equalsPrev: true},
{version: "\t42.0.0 ", unparsed: "42.0.0", equalsPrev: true},
{version: "42.0.0-1", unparsed: "42.0.0", equalsPrev: true},
{version: "42.0.0-1 ", unparsed: "42.0.0", equalsPrev: true},
{version: "v42.0.0-1", unparsed: "42.0.0", equalsPrev: true},
{version: " v43.0.0", unparsed: "43.0.0"},
{version: " 43.0.0 ", unparsed: "43.0.0", equalsPrev: true},
}
var prev testItem
for _, item := range tests {
v, err := ParseGeneric(item.version)
if err != nil {
t.Errorf("unexpected parse error: %v", err)
continue
}
err = testOne(v, item, prev)
if err != nil {
t.Errorf("%v", err)
}
prev = item
}
}
func TestBadGenericVersions(t *testing.T) {
tests := []string{
"1",
"01.2.3",
"-1.2.3",
"1.-2.3",
".2.3",
"1..3",
"1a.2.3",
"a1.2.3",
"1 .2.3",
"1. 2.3",
"1.bob",
"bob",
"v 1.2.3",
"vv1.2.3",
"",
".",
}
for i := range tests {
_, err := ParseGeneric(tests[i])
if err == nil {
t.Errorf("unexpected success parsing invalid version %q", tests[i])
}
}
}
func TestComponents(t *testing.T) {
var tests = []struct {
version string
semver bool
expectedComponents []uint
expectedMajor uint
expectedMinor uint
expectedPatch uint
expectedPreRelease string
expectedBuildMetadata string
}{
{
version: "1.0.2",
semver: true,
expectedComponents: []uint{1, 0, 2},
expectedMajor: 1,
expectedMinor: 0,
expectedPatch: 2,
},
{
version: "1.0.2-alpha+001",
semver: true,
expectedComponents: []uint{1, 0, 2},
expectedMajor: 1,
expectedMinor: 0,
expectedPatch: 2,
expectedPreRelease: "alpha",
expectedBuildMetadata: "001",
},
{
version: "1.2",
semver: false,
expectedComponents: []uint{1, 2},
expectedMajor: 1,
expectedMinor: 2,
},
{
version: "1.0.2-beta+exp.sha.5114f85",
semver: true,
expectedComponents: []uint{1, 0, 2},
expectedMajor: 1,
expectedMinor: 0,
expectedPatch: 2,
expectedPreRelease: "beta",
expectedBuildMetadata: "exp.sha.5114f85",
},
}
for _, test := range tests {
version, _ := parse(test.version, test.semver)
if !reflect.DeepEqual(test.expectedComponents, version.Components()) {
t.Error("parse returned un'expected components")
}
if test.expectedMajor != version.Major() {
t.Errorf("parse returned version.Major %d, expected %d", test.expectedMajor, version.Major())
}
if test.expectedMinor != version.Minor() {
t.Errorf("parse returned version.Minor %d, expected %d", test.expectedMinor, version.Minor())
}
if test.expectedPatch != version.Patch() {
t.Errorf("parse returned version.Patch %d, expected %d", test.expectedPatch, version.Patch())
}
if test.expectedPreRelease != version.PreRelease() {
t.Errorf("parse returned version.PreRelease %s, expected %s", test.expectedPreRelease, version.PreRelease())
}
if test.expectedBuildMetadata != version.BuildMetadata() {
t.Errorf("parse returned version.BuildMetadata %s, expected %s", test.expectedBuildMetadata, version.BuildMetadata())
}
}
}

View File

@ -1,501 +0,0 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package wait
import (
"errors"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
"k8s.io/apimachinery/pkg/util/runtime"
)
func TestUntil(t *testing.T) {
ch := make(chan struct{})
close(ch)
Until(func() {
t.Fatal("should not have been invoked")
}, 0, ch)
ch = make(chan struct{})
called := make(chan struct{})
go func() {
Until(func() {
called <- struct{}{}
}, 0, ch)
close(called)
}()
<-called
close(ch)
<-called
}
func TestNonSlidingUntil(t *testing.T) {
ch := make(chan struct{})
close(ch)
NonSlidingUntil(func() {
t.Fatal("should not have been invoked")
}, 0, ch)
ch = make(chan struct{})
called := make(chan struct{})
go func() {
NonSlidingUntil(func() {
called <- struct{}{}
}, 0, ch)
close(called)
}()
<-called
close(ch)
<-called
}
func TestUntilReturnsImmediately(t *testing.T) {
now := time.Now()
ch := make(chan struct{})
Until(func() {
close(ch)
}, 30*time.Second, ch)
if now.Add(25 * time.Second).Before(time.Now()) {
t.Errorf("Until did not return immediately when the stop chan was closed inside the func")
}
}
func TestJitterUntil(t *testing.T) {
ch := make(chan struct{})
// if a channel is closed JitterUntil never calls function f
// and returns immediately
close(ch)
JitterUntil(func() {
t.Fatal("should not have been invoked")
}, 0, 1.0, true, ch)
ch = make(chan struct{})
called := make(chan struct{})
go func() {
JitterUntil(func() {
called <- struct{}{}
}, 0, 1.0, true, ch)
close(called)
}()
<-called
close(ch)
<-called
}
func TestJitterUntilReturnsImmediately(t *testing.T) {
now := time.Now()
ch := make(chan struct{})
JitterUntil(func() {
close(ch)
}, 30*time.Second, 1.0, true, ch)
if now.Add(25 * time.Second).Before(time.Now()) {
t.Errorf("JitterUntil did not return immediately when the stop chan was closed inside the func")
}
}
func TestJitterUntilRecoversPanic(t *testing.T) {
// Save and restore crash handlers
originalReallyCrash := runtime.ReallyCrash
originalHandlers := runtime.PanicHandlers
defer func() {
runtime.ReallyCrash = originalReallyCrash
runtime.PanicHandlers = originalHandlers
}()
called := 0
handled := 0
// Hook up a custom crash handler to ensure it is called when a jitter function panics
runtime.ReallyCrash = false
runtime.PanicHandlers = []func(interface{}){
func(p interface{}) {
handled++
},
}
ch := make(chan struct{})
JitterUntil(func() {
called++
if called > 2 {
close(ch)
return
}
panic("TestJitterUntilRecoversPanic")
}, time.Millisecond, 1.0, true, ch)
if called != 3 {
t.Errorf("Expected panic recovers")
}
}
func TestJitterUntilNegativeFactor(t *testing.T) {
now := time.Now()
ch := make(chan struct{})
called := make(chan struct{})
received := make(chan struct{})
go func() {
JitterUntil(func() {
called <- struct{}{}
<-received
}, time.Second, -30.0, true, ch)
}()
// first loop
<-called
received <- struct{}{}
// second loop
<-called
close(ch)
received <- struct{}{}
// it should take at most 2 seconds + some overhead, not 3
if now.Add(3 * time.Second).Before(time.Now()) {
t.Errorf("JitterUntil did not returned after predefined period with negative jitter factor when the stop chan was closed inside the func")
}
}
func TestExponentialBackoff(t *testing.T) {
opts := Backoff{Factor: 1.0, Steps: 3}
// waits up to steps
i := 0
err := ExponentialBackoff(opts, func() (bool, error) {
i++
return false, nil
})
if err != ErrWaitTimeout || i != opts.Steps {
t.Errorf("unexpected error: %v", err)
}
// returns immediately
i = 0
err = ExponentialBackoff(opts, func() (bool, error) {
i++
return true, nil
})
if err != nil || i != 1 {
t.Errorf("unexpected error: %v", err)
}
// returns immediately on error
testErr := fmt.Errorf("some other error")
err = ExponentialBackoff(opts, func() (bool, error) {
return false, testErr
})
if err != testErr {
t.Errorf("unexpected error: %v", err)
}
// invoked multiple times
i = 1
err = ExponentialBackoff(opts, func() (bool, error) {
if i < opts.Steps {
i++
return false, nil
}
return true, nil
})
if err != nil || i != opts.Steps {
t.Errorf("unexpected error: %v", err)
}
}
func TestPoller(t *testing.T) {
done := make(chan struct{})
defer close(done)
w := poller(time.Millisecond, 2*time.Millisecond)
ch := w(done)
count := 0
DRAIN:
for {
select {
case _, open := <-ch:
if !open {
break DRAIN
}
count++
case <-time.After(ForeverTestTimeout):
t.Errorf("unexpected timeout after poll")
}
}
if count > 3 {
t.Errorf("expected up to three values, got %d", count)
}
}
type fakePoller struct {
max int
used int32 // accessed with atomics
wg sync.WaitGroup
}
func fakeTicker(max int, used *int32, doneFunc func()) WaitFunc {
return func(done <-chan struct{}) <-chan struct{} {
ch := make(chan struct{})
go func() {
defer doneFunc()
defer close(ch)
for i := 0; i < max; i++ {
select {
case ch <- struct{}{}:
case <-done:
return
}
if used != nil {
atomic.AddInt32(used, 1)
}
}
}()
return ch
}
}
func (fp *fakePoller) GetWaitFunc() WaitFunc {
fp.wg.Add(1)
return fakeTicker(fp.max, &fp.used, fp.wg.Done)
}
func TestPoll(t *testing.T) {
invocations := 0
f := ConditionFunc(func() (bool, error) {
invocations++
return true, nil
})
fp := fakePoller{max: 1}
if err := pollInternal(fp.GetWaitFunc(), f); err != nil {
t.Fatalf("unexpected error %v", err)
}
fp.wg.Wait()
if invocations != 1 {
t.Errorf("Expected exactly one invocation, got %d", invocations)
}
used := atomic.LoadInt32(&fp.used)
if used != 1 {
t.Errorf("Expected exactly one tick, got %d", used)
}
}
func TestPollError(t *testing.T) {
expectedError := errors.New("Expected error")
f := ConditionFunc(func() (bool, error) {
return false, expectedError
})
fp := fakePoller{max: 1}
if err := pollInternal(fp.GetWaitFunc(), f); err == nil || err != expectedError {
t.Fatalf("Expected error %v, got none %v", expectedError, err)
}
fp.wg.Wait()
used := atomic.LoadInt32(&fp.used)
if used != 1 {
t.Errorf("Expected exactly one tick, got %d", used)
}
}
func TestPollImmediate(t *testing.T) {
invocations := 0
f := ConditionFunc(func() (bool, error) {
invocations++
return true, nil
})
fp := fakePoller{max: 0}
if err := pollImmediateInternal(fp.GetWaitFunc(), f); err != nil {
t.Fatalf("unexpected error %v", err)
}
// We don't need to wait for fp.wg, as pollImmediate shouldn't call WaitFunc at all.
if invocations != 1 {
t.Errorf("Expected exactly one invocation, got %d", invocations)
}
used := atomic.LoadInt32(&fp.used)
if used != 0 {
t.Errorf("Expected exactly zero ticks, got %d", used)
}
}
func TestPollImmediateError(t *testing.T) {
expectedError := errors.New("Expected error")
f := ConditionFunc(func() (bool, error) {
return false, expectedError
})
fp := fakePoller{max: 0}
if err := pollImmediateInternal(fp.GetWaitFunc(), f); err == nil || err != expectedError {
t.Fatalf("Expected error %v, got none %v", expectedError, err)
}
// We don't need to wait for fp.wg, as pollImmediate shouldn't call WaitFunc at all.
used := atomic.LoadInt32(&fp.used)
if used != 0 {
t.Errorf("Expected exactly zero ticks, got %d", used)
}
}
func TestPollForever(t *testing.T) {
ch := make(chan struct{})
done := make(chan struct{}, 1)
complete := make(chan struct{})
go func() {
f := ConditionFunc(func() (bool, error) {
ch <- struct{}{}
select {
case <-done:
return true, nil
default:
}
return false, nil
})
if err := PollInfinite(time.Microsecond, f); err != nil {
t.Fatalf("unexpected error %v", err)
}
close(ch)
complete <- struct{}{}
}()
// ensure the condition is opened
<-ch
// ensure channel sends events
for i := 0; i < 10; i++ {
select {
case _, open := <-ch:
if !open {
t.Fatalf("did not expect channel to be closed")
}
case <-time.After(ForeverTestTimeout):
t.Fatalf("channel did not return at least once within the poll interval")
}
}
// at most one poll notification should be sent once we return from the condition
done <- struct{}{}
go func() {
for i := 0; i < 2; i++ {
_, open := <-ch
if !open {
return
}
}
t.Fatalf("expected closed channel after two iterations")
}()
<-complete
}
func TestWaitFor(t *testing.T) {
var invocations int
testCases := map[string]struct {
F ConditionFunc
Ticks int
Invoked int
Err bool
}{
"invoked once": {
ConditionFunc(func() (bool, error) {
invocations++
return true, nil
}),
2,
1,
false,
},
"invoked and returns a timeout": {
ConditionFunc(func() (bool, error) {
invocations++
return false, nil
}),
2,
3, // the contract of WaitFor() says the func is called once more at the end of the wait
true,
},
"returns immediately on error": {
ConditionFunc(func() (bool, error) {
invocations++
return false, errors.New("test")
}),
2,
1,
true,
},
}
for k, c := range testCases {
invocations = 0
ticker := fakeTicker(c.Ticks, nil, func() {})
err := func() error {
done := make(chan struct{})
defer close(done)
return WaitFor(ticker, c.F, done)
}()
switch {
case c.Err && err == nil:
t.Errorf("%s: Expected error, got nil", k)
continue
case !c.Err && err != nil:
t.Errorf("%s: Expected no error, got: %#v", k, err)
continue
}
if invocations != c.Invoked {
t.Errorf("%s: Expected %d invocations, got %d", k, c.Invoked, invocations)
}
}
}
func TestWaitForWithDelay(t *testing.T) {
done := make(chan struct{})
defer close(done)
WaitFor(poller(time.Millisecond, ForeverTestTimeout), func() (bool, error) {
time.Sleep(10 * time.Millisecond)
return true, nil
}, done)
// If polling goroutine doesn't see the done signal it will leak timers.
select {
case done <- struct{}{}:
case <-time.After(ForeverTestTimeout):
t.Errorf("expected an ack of the done signal.")
}
}
func TestPollUntil(t *testing.T) {
stopCh := make(chan struct{})
called := make(chan bool)
pollDone := make(chan struct{})
go func() {
PollUntil(time.Microsecond, ConditionFunc(func() (bool, error) {
called <- true
return false, nil
}), stopCh)
close(pollDone)
}()
// make sure we're called once
<-called
// this should trigger a "done"
close(stopCh)
go func() {
// release the condition func if needed
for {
<-called
}
}()
// make sure we finished the poll
<-pollDone
}

View File

@ -1,19 +0,0 @@
/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package waitgroup implements SafeWaitGroup wrap of sync.WaitGroup.
// Add with positive delta when waiting will fail, to prevent sync.WaitGroup race issue.
package waitgroup // import "k8s.io/apimachinery/pkg/util/waitgroup"

View File

@ -1,57 +0,0 @@
/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package waitgroup
import (
"fmt"
"sync"
)
// SafeWaitGroup must not be copied after first use.
type SafeWaitGroup struct {
wg sync.WaitGroup
mu sync.RWMutex
// wait indicate whether Wait is called, if true,
// then any Add with positive delta will return error.
wait bool
}
// Add adds delta, which may be negative, similar to sync.WaitGroup.
// If Add with a positive delta happens after Wait, it will return error,
// which prevent unsafe Add.
func (wg *SafeWaitGroup) Add(delta int) error {
wg.mu.RLock()
defer wg.mu.RUnlock()
if wg.wait && delta > 0 {
return fmt.Errorf("add with positive delta after Wait is forbidden")
}
wg.wg.Add(delta)
return nil
}
// Done decrements the WaitGroup counter.
func (wg *SafeWaitGroup) Done() {
wg.wg.Done()
}
// Wait blocks until the WaitGroup counter is zero.
func (wg *SafeWaitGroup) Wait() {
wg.mu.Lock()
wg.wait = true
wg.mu.Unlock()
wg.wg.Wait()
}

View File

@ -1,60 +0,0 @@
/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package waitgroup test cases reference golang sync.WaitGroup https://golang.org/src/sync/waitgroup_test.go.
package waitgroup
import (
"testing"
)
func TestWaitGroup(t *testing.T) {
wg1 := &SafeWaitGroup{}
wg2 := &SafeWaitGroup{}
n := 16
wg1.Add(n)
wg2.Add(n)
exited := make(chan bool, n)
for i := 0; i != n; i++ {
go func(i int) {
wg1.Done()
wg2.Wait()
exited <- true
}(i)
}
wg1.Wait()
for i := 0; i != n; i++ {
select {
case <-exited:
t.Fatal("SafeWaitGroup released group too soon")
default:
}
wg2.Done()
}
for i := 0; i != n; i++ {
<-exited // Will block if barrier fails to unlock someone.
}
}
func TestWaitGroupAddFail(t *testing.T) {
wg := &SafeWaitGroup{}
wg.Add(1)
wg.Done()
wg.Wait()
if err := wg.Add(1); err == nil {
t.Errorf("Should return error when add positive after Wait")
}
}

View File

@ -1,405 +0,0 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package yaml
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"math/rand"
"reflect"
"strings"
"testing"
)
func TestYAMLDecoderReadBytesLength(t *testing.T) {
d := `---
stuff: 1
test-foo: 1
`
testCases := []struct {
bufLen int
expectLen int
expectErr error
}{
{len(d), len(d), nil},
{len(d) + 10, len(d), nil},
{len(d) - 10, len(d) - 10, io.ErrShortBuffer},
}
for i, testCase := range testCases {
r := NewDocumentDecoder(ioutil.NopCloser(bytes.NewReader([]byte(d))))
b := make([]byte, testCase.bufLen)
n, err := r.Read(b)
if err != testCase.expectErr || n != testCase.expectLen {
t.Fatalf("%d: unexpected body: %d / %v", i, n, err)
}
}
}
func TestYAMLDecoderCallsAfterErrShortBufferRestOfFrame(t *testing.T) {
d := `---
stuff: 1
test-foo: 1`
r := NewDocumentDecoder(ioutil.NopCloser(bytes.NewReader([]byte(d))))
b := make([]byte, 12)
n, err := r.Read(b)
if err != io.ErrShortBuffer || n != 12 {
t.Fatalf("expected ErrShortBuffer: %d / %v", n, err)
}
expected := "---\nstuff: 1"
if string(b) != expected {
t.Fatalf("expected bytes read to be: %s got: %s", expected, string(b))
}
b = make([]byte, 13)
n, err = r.Read(b)
if err != nil || n != 13 {
t.Fatalf("expected nil: %d / %v", n, err)
}
expected = "\n\ttest-foo: 1"
if string(b) != expected {
t.Fatalf("expected bytes read to be: '%s' got: '%s'", expected, string(b))
}
b = make([]byte, 15)
n, err = r.Read(b)
if err != io.EOF || n != 0 {
t.Fatalf("expected EOF: %d / %v", n, err)
}
}
func TestSplitYAMLDocument(t *testing.T) {
testCases := []struct {
input string
atEOF bool
expect string
adv int
}{
{"foo", true, "foo", 3},
{"fo", false, "", 0},
{"---", true, "---", 3},
{"---\n", true, "---\n", 4},
{"---\n", false, "", 0},
{"\n---\n", false, "", 5},
{"\n---\n", true, "", 5},
{"abc\n---\ndef", true, "abc", 8},
{"def", true, "def", 3},
{"", true, "", 0},
}
for i, testCase := range testCases {
adv, token, err := splitYAMLDocument([]byte(testCase.input), testCase.atEOF)
if err != nil {
t.Errorf("%d: unexpected error: %v", i, err)
continue
}
if adv != testCase.adv {
t.Errorf("%d: advance did not match: %d %d", i, testCase.adv, adv)
}
if testCase.expect != string(token) {
t.Errorf("%d: token did not match: %q %q", i, testCase.expect, string(token))
}
}
}
func TestGuessJSON(t *testing.T) {
if r, _, isJSON := GuessJSONStream(bytes.NewReader([]byte(" \n{}")), 100); !isJSON {
t.Fatalf("expected stream to be JSON")
} else {
b := make([]byte, 30)
n, err := r.Read(b)
if err != nil || n != 4 {
t.Fatalf("unexpected body: %d / %v", n, err)
}
if string(b[:n]) != " \n{}" {
t.Fatalf("unexpected body: %q", string(b[:n]))
}
}
}
func TestScanYAML(t *testing.T) {
s := bufio.NewScanner(bytes.NewReader([]byte(`---
stuff: 1
---
`)))
s.Split(splitYAMLDocument)
if !s.Scan() {
t.Fatalf("should have been able to scan")
}
t.Logf("scan: %s", s.Text())
if !s.Scan() {
t.Fatalf("should have been able to scan")
}
t.Logf("scan: %s", s.Text())
if s.Scan() {
t.Fatalf("scan should have been done")
}
if s.Err() != nil {
t.Fatalf("err should have been nil: %v", s.Err())
}
}
func TestDecodeYAML(t *testing.T) {
s := NewYAMLToJSONDecoder(bytes.NewReader([]byte(`---
stuff: 1
---
`)))
obj := generic{}
if err := s.Decode(&obj); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if fmt.Sprintf("%#v", obj) != `yaml.generic{"stuff":1}` {
t.Errorf("unexpected object: %#v", obj)
}
obj = generic{}
if err := s.Decode(&obj); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(obj) != 0 {
t.Fatalf("unexpected object: %#v", obj)
}
obj = generic{}
if err := s.Decode(&obj); err != io.EOF {
t.Fatalf("unexpected error: %v", err)
}
}
func TestDecodeBrokenYAML(t *testing.T) {
s := NewYAMLOrJSONDecoder(bytes.NewReader([]byte(`---
stuff: 1
test-foo: 1
---
`)), 100)
obj := generic{}
err := s.Decode(&obj)
if err == nil {
t.Fatal("expected error with yaml: violate, got no error")
}
fmt.Printf("err: %s\n", err.Error())
if !strings.Contains(err.Error(), "yaml: line 3:") {
t.Fatalf("expected %q to have 'yaml: line 3:' found a tab character", err.Error())
}
}
func TestDecodeBrokenJSON(t *testing.T) {
s := NewYAMLOrJSONDecoder(bytes.NewReader([]byte(`{
"foo": {
"stuff": 1
"otherStuff": 2
}
}
`)), 100)
obj := generic{}
err := s.Decode(&obj)
if err == nil {
t.Fatal("expected error with json: prefix, got no error")
}
if !strings.HasPrefix(err.Error(), "json: line 3:") {
t.Fatalf("expected %q to have 'json: line 3:' prefix", err.Error())
}
}
type generic map[string]interface{}
func TestYAMLOrJSONDecoder(t *testing.T) {
testCases := []struct {
input string
buffer int
isJSON bool
err bool
out []generic
}{
{` {"1":2}{"3":4}`, 2, true, false, []generic{
{"1": 2},
{"3": 4},
}},
{" \n{}", 3, true, false, []generic{
{},
}},
{" \na: b", 2, false, false, []generic{
{"a": "b"},
}},
{" \n{\"a\": \"b\"}", 2, false, true, []generic{
{"a": "b"},
}},
{" \n{\"a\": \"b\"}", 3, true, false, []generic{
{"a": "b"},
}},
{` {"a":"b"}`, 100, true, false, []generic{
{"a": "b"},
}},
{"", 1, false, false, []generic{}},
{"foo: bar\n---\nbaz: biz", 100, false, false, []generic{
{"foo": "bar"},
{"baz": "biz"},
}},
{"foo: bar\n---\n", 100, false, false, []generic{
{"foo": "bar"},
}},
{"foo: bar\n---", 100, false, false, []generic{
{"foo": "bar"},
}},
{"foo: bar\n--", 100, false, true, []generic{
{"foo": "bar"},
}},
{"foo: bar\n-", 100, false, true, []generic{
{"foo": "bar"},
}},
{"foo: bar\n", 100, false, false, []generic{
{"foo": "bar"},
}},
}
for i, testCase := range testCases {
decoder := NewYAMLOrJSONDecoder(bytes.NewReader([]byte(testCase.input)), testCase.buffer)
objs := []generic{}
var err error
for {
out := make(generic)
err = decoder.Decode(&out)
if err != nil {
break
}
objs = append(objs, out)
}
if err != io.EOF {
switch {
case testCase.err && err == nil:
t.Errorf("%d: unexpected non-error", i)
continue
case !testCase.err && err != nil:
t.Errorf("%d: unexpected error: %v", i, err)
continue
case err != nil:
continue
}
}
switch decoder.decoder.(type) {
case *YAMLToJSONDecoder:
if testCase.isJSON {
t.Errorf("%d: expected JSON decoder, got YAML", i)
}
case *json.Decoder:
if !testCase.isJSON {
t.Errorf("%d: expected YAML decoder, got JSON", i)
}
}
if fmt.Sprintf("%#v", testCase.out) != fmt.Sprintf("%#v", objs) {
t.Errorf("%d: objects were not equal: \n%#v\n%#v", i, testCase.out, objs)
}
}
}
func TestReadSingleLongLine(t *testing.T) {
testReadLines(t, []int{128 * 1024})
}
func TestReadRandomLineLengths(t *testing.T) {
minLength := 100
maxLength := 96 * 1024
maxLines := 100
lineLengths := make([]int, maxLines)
for i := 0; i < maxLines; i++ {
lineLengths[i] = rand.Intn(maxLength-minLength) + minLength
}
testReadLines(t, lineLengths)
}
func testReadLines(t *testing.T, lineLengths []int) {
var (
lines [][]byte
inputStream []byte
)
for _, lineLength := range lineLengths {
inputLine := make([]byte, lineLength+1)
for i := 0; i < lineLength; i++ {
char := rand.Intn('z'-'A') + 'A'
inputLine[i] = byte(char)
}
inputLine[len(inputLine)-1] = '\n'
lines = append(lines, inputLine)
}
for _, line := range lines {
inputStream = append(inputStream, line...)
}
// init Reader
reader := bufio.NewReader(bytes.NewReader(inputStream))
lineReader := &LineReader{reader: reader}
// read lines
var readLines [][]byte
for range lines {
bytes, err := lineReader.Read()
if err != nil && err != io.EOF {
t.Fatalf("failed to read lines: %v", err)
}
readLines = append(readLines, bytes)
}
// validate
for i := range lines {
if len(lines[i]) != len(readLines[i]) {
t.Fatalf("expected line length: %d, but got %d", len(lines[i]), len(readLines[i]))
}
if !reflect.DeepEqual(lines[i], readLines[i]) {
t.Fatalf("expected line: %v, but got %v", lines[i], readLines[i])
}
}
}
func TestTypedJSONOrYamlErrors(t *testing.T) {
s := NewYAMLOrJSONDecoder(bytes.NewReader([]byte(`{
"foo": {
"stuff": 1
"otherStuff": 2
}
}
`)), 100)
obj := generic{}
err := s.Decode(&obj)
if err == nil {
t.Fatal("expected error with json: prefix, got no error")
}
if _, ok := err.(JSONSyntaxError); !ok {
t.Fatalf("expected %q to be of type JSONSyntaxError", err.Error())
}
s = NewYAMLOrJSONDecoder(bytes.NewReader([]byte(`---
stuff: 1
test-foo: 1
---
`)), 100)
obj = generic{}
err = s.Decode(&obj)
if err == nil {
t.Fatal("expected error with yaml: prefix, got no error")
}
if _, ok := err.(YAMLSyntaxError); !ok {
t.Fatalf("expected %q to be of type YAMLSyntaxError", err.Error())
}
}