vendor files

This commit is contained in:
Serguei Bezverkhi
2018-01-09 13:57:14 -05:00
parent 558bc6c02a
commit 7b24313bd6
16547 changed files with 4527373 additions and 0 deletions

View File

@ -0,0 +1,63 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
"go_test",
)
go_library(
name = "go_default_library",
srcs = [
"errors.go",
"request_cache.go",
"server.go",
],
importpath = "k8s.io/kubernetes/pkg/kubelet/server/streaming",
deps = [
"//pkg/kubelet/apis/cri/v1alpha1/runtime:go_default_library",
"//pkg/kubelet/server/portforward:go_default_library",
"//pkg/kubelet/server/remotecommand:go_default_library",
"//vendor/github.com/emicklei/go-restful:go_default_library",
"//vendor/google.golang.org/grpc:go_default_library",
"//vendor/google.golang.org/grpc/codes:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/types:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/util/clock:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/util/remotecommand:go_default_library",
"//vendor/k8s.io/client-go/tools/remotecommand:go_default_library",
],
)
go_test(
name = "go_default_test",
srcs = [
"request_cache_test.go",
"server_test.go",
],
importpath = "k8s.io/kubernetes/pkg/kubelet/server/streaming",
library = ":go_default_library",
deps = [
"//pkg/apis/core:go_default_library",
"//pkg/kubelet/apis/cri/v1alpha1/runtime:go_default_library",
"//pkg/kubelet/server/portforward:go_default_library",
"//vendor/github.com/stretchr/testify/assert:go_default_library",
"//vendor/github.com/stretchr/testify/require:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/util/clock:go_default_library",
"//vendor/k8s.io/client-go/rest:go_default_library",
"//vendor/k8s.io/client-go/tools/remotecommand:go_default_library",
"//vendor/k8s.io/client-go/transport/spdy:go_default_library",
],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
)

View File

@ -0,0 +1,55 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package streaming
import (
"fmt"
"net/http"
"strconv"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
)
func ErrorStreamingDisabled(method string) error {
return grpc.Errorf(codes.NotFound, fmt.Sprintf("streaming method %s disabled", method))
}
// The error returned when the maximum number of in-flight requests is exceeded.
func ErrorTooManyInFlight() error {
return grpc.Errorf(codes.ResourceExhausted, "maximum number of in-flight requests exceeded")
}
// Translates a CRI streaming error into an appropriate HTTP response.
func WriteError(err error, w http.ResponseWriter) error {
var status int
switch grpc.Code(err) {
case codes.NotFound:
status = http.StatusNotFound
case codes.ResourceExhausted:
// We only expect to hit this if there is a DoS, so we just wait the full TTL.
// If this is ever hit in steady-state operations, consider increasing the MaxInFlight requests,
// or plumbing through the time to next expiration.
w.Header().Set("Retry-After", strconv.Itoa(int(CacheTTL.Seconds())))
status = http.StatusTooManyRequests
default:
status = http.StatusInternalServerError
}
w.WriteHeader(status)
_, writeErr := w.Write([]byte(err.Error()))
return writeErr
}

View File

@ -0,0 +1,146 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package streaming
import (
"container/list"
"crypto/rand"
"encoding/base64"
"fmt"
"math"
"sync"
"time"
"k8s.io/apimachinery/pkg/util/clock"
)
var (
// Timeout after which tokens become invalid.
CacheTTL = 1 * time.Minute
// The maximum number of in-flight requests to allow.
MaxInFlight = 1000
// Length of the random base64 encoded token identifying the request.
TokenLen = 8
)
// requestCache caches streaming (exec/attach/port-forward) requests and generates a single-use
// random token for their retrieval. The requestCache is used for building streaming URLs without
// the need to encode every request parameter in the URL.
type requestCache struct {
// clock is used to obtain the current time
clock clock.Clock
// tokens maps the generate token to the request for fast retrieval.
tokens map[string]*list.Element
// ll maintains an age-ordered request list for faster garbage collection of expired requests.
ll *list.List
lock sync.Mutex
}
// Type representing an *ExecRequest, *AttachRequest, or *PortForwardRequest.
type request interface{}
type cacheEntry struct {
token string
req request
expireTime time.Time
}
func newRequestCache() *requestCache {
return &requestCache{
clock: clock.RealClock{},
ll: list.New(),
tokens: make(map[string]*list.Element),
}
}
// Insert the given request into the cache and returns the token used for fetching it out.
func (c *requestCache) Insert(req request) (token string, err error) {
c.lock.Lock()
defer c.lock.Unlock()
// Remove expired entries.
c.gc()
// If the cache is full, reject the request.
if c.ll.Len() == MaxInFlight {
return "", ErrorTooManyInFlight()
}
token, err = c.uniqueToken()
if err != nil {
return "", err
}
ele := c.ll.PushFront(&cacheEntry{token, req, c.clock.Now().Add(CacheTTL)})
c.tokens[token] = ele
return token, nil
}
// Consume the token (remove it from the cache) and return the cached request, if found.
func (c *requestCache) Consume(token string) (req request, found bool) {
c.lock.Lock()
defer c.lock.Unlock()
ele, ok := c.tokens[token]
if !ok {
return nil, false
}
c.ll.Remove(ele)
delete(c.tokens, token)
entry := ele.Value.(*cacheEntry)
if c.clock.Now().After(entry.expireTime) {
// Entry already expired.
return nil, false
}
return entry.req, true
}
// uniqueToken generates a random URL-safe token and ensures uniqueness.
func (c *requestCache) uniqueToken() (string, error) {
const maxTries = 10
// Number of bytes to be TokenLen when base64 encoded.
tokenSize := math.Ceil(float64(TokenLen) * 6 / 8)
rawToken := make([]byte, int(tokenSize))
for i := 0; i < maxTries; i++ {
if _, err := rand.Read(rawToken); err != nil {
return "", err
}
encoded := base64.RawURLEncoding.EncodeToString(rawToken)
token := encoded[:TokenLen]
// If it's unique, return it. Otherwise retry.
if _, exists := c.tokens[encoded]; !exists {
return token, nil
}
}
return "", fmt.Errorf("failed to generate unique token")
}
// Must be write-locked prior to calling.
func (c *requestCache) gc() {
now := c.clock.Now()
for c.ll.Len() > 0 {
oldest := c.ll.Back()
entry := oldest.Value.(*cacheEntry)
if !now.After(entry.expireTime) {
return
}
// Oldest value is expired; remove it.
c.ll.Remove(oldest)
delete(c.tokens, entry.token)
}
}

View File

@ -0,0 +1,221 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package streaming
import (
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"k8s.io/apimachinery/pkg/util/clock"
)
func TestInsert(t *testing.T) {
c, _ := newTestCache()
// Insert normal
oldestTok, err := c.Insert(nextRequest())
require.NoError(t, err)
assert.Len(t, oldestTok, TokenLen)
assertCacheSize(t, c, 1)
// Insert until full
for i := 0; i < MaxInFlight-2; i++ {
tok, err := c.Insert(nextRequest())
require.NoError(t, err)
assert.Len(t, tok, TokenLen)
}
assertCacheSize(t, c, MaxInFlight-1)
newestReq := nextRequest()
newestTok, err := c.Insert(newestReq)
require.NoError(t, err)
assert.Len(t, newestTok, TokenLen)
assertCacheSize(t, c, MaxInFlight)
require.Contains(t, c.tokens, oldestTok, "oldest request should still be cached")
// Consume newest token.
req, ok := c.Consume(newestTok)
assert.True(t, ok, "newest request should still be cached")
assert.Equal(t, newestReq, req)
require.Contains(t, c.tokens, oldestTok, "oldest request should still be cached")
// Insert again (still full)
tok, err := c.Insert(nextRequest())
require.NoError(t, err)
assert.Len(t, tok, TokenLen)
assertCacheSize(t, c, MaxInFlight)
// Insert again (should evict)
_, err = c.Insert(nextRequest())
assert.Error(t, err, "should reject further requests")
errResponse := httptest.NewRecorder()
require.NoError(t, WriteError(err, errResponse))
assert.Equal(t, errResponse.Code, http.StatusTooManyRequests)
assert.Equal(t, strconv.Itoa(int(CacheTTL.Seconds())), errResponse.HeaderMap.Get("Retry-After"))
assertCacheSize(t, c, MaxInFlight)
_, ok = c.Consume(oldestTok)
assert.True(t, ok, "oldest request should be valid")
}
func TestConsume(t *testing.T) {
c, clock := newTestCache()
{ // Insert & consume.
req := nextRequest()
tok, err := c.Insert(req)
require.NoError(t, err)
assertCacheSize(t, c, 1)
cachedReq, ok := c.Consume(tok)
assert.True(t, ok)
assert.Equal(t, req, cachedReq)
assertCacheSize(t, c, 0)
}
{ // Insert & consume out of order
req1 := nextRequest()
tok1, err := c.Insert(req1)
require.NoError(t, err)
assertCacheSize(t, c, 1)
req2 := nextRequest()
tok2, err := c.Insert(req2)
require.NoError(t, err)
assertCacheSize(t, c, 2)
cachedReq2, ok := c.Consume(tok2)
assert.True(t, ok)
assert.Equal(t, req2, cachedReq2)
assertCacheSize(t, c, 1)
cachedReq1, ok := c.Consume(tok1)
assert.True(t, ok)
assert.Equal(t, req1, cachedReq1)
assertCacheSize(t, c, 0)
}
{ // Consume a second time
req := nextRequest()
tok, err := c.Insert(req)
require.NoError(t, err)
assertCacheSize(t, c, 1)
cachedReq, ok := c.Consume(tok)
assert.True(t, ok)
assert.Equal(t, req, cachedReq)
assertCacheSize(t, c, 0)
_, ok = c.Consume(tok)
assert.False(t, ok)
assertCacheSize(t, c, 0)
}
{ // Consume without insert
_, ok := c.Consume("fooBAR")
assert.False(t, ok)
assertCacheSize(t, c, 0)
}
{ // Consume expired
tok, err := c.Insert(nextRequest())
require.NoError(t, err)
assertCacheSize(t, c, 1)
clock.Step(2 * CacheTTL)
_, ok := c.Consume(tok)
assert.False(t, ok)
assertCacheSize(t, c, 0)
}
}
func TestGC(t *testing.T) {
c, clock := newTestCache()
// When empty
c.gc()
assertCacheSize(t, c, 0)
tok1, err := c.Insert(nextRequest())
require.NoError(t, err)
assertCacheSize(t, c, 1)
clock.Step(10 * time.Second)
tok2, err := c.Insert(nextRequest())
require.NoError(t, err)
assertCacheSize(t, c, 2)
// expired: tok1, tok2
// non-expired: tok3, tok4
clock.Step(2 * CacheTTL)
tok3, err := c.Insert(nextRequest())
require.NoError(t, err)
assertCacheSize(t, c, 1)
clock.Step(10 * time.Second)
tok4, err := c.Insert(nextRequest())
require.NoError(t, err)
assertCacheSize(t, c, 2)
_, ok := c.Consume(tok1)
assert.False(t, ok)
_, ok = c.Consume(tok2)
assert.False(t, ok)
_, ok = c.Consume(tok3)
assert.True(t, ok)
_, ok = c.Consume(tok4)
assert.True(t, ok)
// When full, nothing is expired.
for i := 0; i < MaxInFlight; i++ {
_, err := c.Insert(nextRequest())
require.NoError(t, err)
}
assertCacheSize(t, c, MaxInFlight)
// When everything is expired
clock.Step(2 * CacheTTL)
_, err = c.Insert(nextRequest())
require.NoError(t, err)
assertCacheSize(t, c, 1)
}
func newTestCache() (*requestCache, *clock.FakeClock) {
c := newRequestCache()
fakeClock := clock.NewFakeClock(time.Now())
c.clock = fakeClock
return c, fakeClock
}
func assertCacheSize(t *testing.T, cache *requestCache, expectedSize int) {
tokenLen := len(cache.tokens)
llLen := cache.ll.Len()
assert.Equal(t, tokenLen, llLen, "inconsistent cache size! len(tokens)=%d; len(ll)=%d", tokenLen, llLen)
assert.Equal(t, expectedSize, tokenLen, "unexpected cache size!")
}
var requestUID = 0
func nextRequest() interface{} {
requestUID++
return requestUID
}

View File

@ -0,0 +1,374 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package streaming
import (
"crypto/tls"
"errors"
"io"
"net/http"
"net/url"
"path"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
restful "github.com/emicklei/go-restful"
"k8s.io/apimachinery/pkg/types"
remotecommandconsts "k8s.io/apimachinery/pkg/util/remotecommand"
"k8s.io/client-go/tools/remotecommand"
runtimeapi "k8s.io/kubernetes/pkg/kubelet/apis/cri/v1alpha1/runtime"
"k8s.io/kubernetes/pkg/kubelet/server/portforward"
remotecommandserver "k8s.io/kubernetes/pkg/kubelet/server/remotecommand"
)
// The library interface to serve the stream requests.
type Server interface {
http.Handler
// Get the serving URL for the requests.
// Requests must not be nil. Responses may be nil iff an error is returned.
GetExec(*runtimeapi.ExecRequest) (*runtimeapi.ExecResponse, error)
GetAttach(req *runtimeapi.AttachRequest) (*runtimeapi.AttachResponse, error)
GetPortForward(*runtimeapi.PortForwardRequest) (*runtimeapi.PortForwardResponse, error)
// Start the server.
// addr is the address to serve on (address:port) stayUp indicates whether the server should
// listen until Stop() is called, or automatically stop after all expected connections are
// closed. Calling Get{Exec,Attach,PortForward} increments the expected connection count.
// Function does not return until the server is stopped.
Start(stayUp bool) error
// Stop the server, and terminate any open connections.
Stop() error
}
// The interface to execute the commands and provide the streams.
type Runtime interface {
Exec(containerID string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error
Attach(containerID string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error
PortForward(podSandboxID string, port int32, stream io.ReadWriteCloser) error
}
// Config defines the options used for running the stream server.
type Config struct {
// The host:port address the server will listen on.
Addr string
// The optional base URL for constructing streaming URLs. If empty, the baseURL will be
// constructed from the serve address.
BaseURL *url.URL
// How long to leave idle connections open for.
StreamIdleTimeout time.Duration
// How long to wait for clients to create streams. Only used for SPDY streaming.
StreamCreationTimeout time.Duration
// The streaming protocols the server supports (understands and permits). See
// k8s.io/kubernetes/pkg/kubelet/server/remotecommand/constants.go for available protocols.
// Only used for SPDY streaming.
SupportedRemoteCommandProtocols []string
// The streaming protocols the server supports (understands and permits). See
// k8s.io/kubernetes/pkg/kubelet/server/portforward/constants.go for available protocols.
// Only used for SPDY streaming.
SupportedPortForwardProtocols []string
// The config for serving over TLS. If nil, TLS will not be used.
TLSConfig *tls.Config
}
// DefaultConfig provides default values for server Config. The DefaultConfig is partial, so
// some fields like Addr must still be provided.
var DefaultConfig = Config{
StreamIdleTimeout: 4 * time.Hour,
StreamCreationTimeout: remotecommandconsts.DefaultStreamCreationTimeout,
SupportedRemoteCommandProtocols: remotecommandconsts.SupportedStreamingProtocols,
SupportedPortForwardProtocols: portforward.SupportedProtocols,
}
// TODO(tallclair): Add auth(n/z) interface & handling.
func NewServer(config Config, runtime Runtime) (Server, error) {
s := &server{
config: config,
runtime: &criAdapter{runtime},
cache: newRequestCache(),
}
if s.config.BaseURL == nil {
s.config.BaseURL = &url.URL{
Scheme: "http",
Host: s.config.Addr,
}
if s.config.TLSConfig != nil {
s.config.BaseURL.Scheme = "https"
}
}
ws := &restful.WebService{}
endpoints := []struct {
path string
handler restful.RouteFunction
}{
{"/exec/{token}", s.serveExec},
{"/attach/{token}", s.serveAttach},
{"/portforward/{token}", s.servePortForward},
}
// If serving relative to a base path, set that here.
pathPrefix := path.Dir(s.config.BaseURL.Path)
for _, e := range endpoints {
for _, method := range []string{"GET", "POST"} {
ws.Route(ws.
Method(method).
Path(path.Join(pathPrefix, e.path)).
To(e.handler))
}
}
handler := restful.NewContainer()
handler.Add(ws)
s.handler = handler
s.server = &http.Server{
Addr: s.config.Addr,
Handler: s.handler,
TLSConfig: s.config.TLSConfig,
}
return s, nil
}
type server struct {
config Config
runtime *criAdapter
handler http.Handler
cache *requestCache
server *http.Server
}
func validateExecRequest(req *runtimeapi.ExecRequest) error {
if req.ContainerId == "" {
return grpc.Errorf(codes.InvalidArgument, "missing required container_id")
}
if req.Tty && req.Stderr {
// If TTY is set, stderr cannot be true because multiplexing is not
// supported.
return grpc.Errorf(codes.InvalidArgument, "tty and stderr cannot both be true")
}
if !req.Stdin && !req.Stdout && !req.Stderr {
return grpc.Errorf(codes.InvalidArgument, "one of stdin, stdout, or stderr must be set")
}
return nil
}
func (s *server) GetExec(req *runtimeapi.ExecRequest) (*runtimeapi.ExecResponse, error) {
if err := validateExecRequest(req); err != nil {
return nil, err
}
token, err := s.cache.Insert(req)
if err != nil {
return nil, err
}
return &runtimeapi.ExecResponse{
Url: s.buildURL("exec", token),
}, nil
}
func validateAttachRequest(req *runtimeapi.AttachRequest) error {
if req.ContainerId == "" {
return grpc.Errorf(codes.InvalidArgument, "missing required container_id")
}
if req.Tty && req.Stderr {
// If TTY is set, stderr cannot be true because multiplexing is not
// supported.
return grpc.Errorf(codes.InvalidArgument, "tty and stderr cannot both be true")
}
if !req.Stdin && !req.Stdout && !req.Stderr {
return grpc.Errorf(codes.InvalidArgument, "one of stdin, stdout, and stderr must be set")
}
return nil
}
func (s *server) GetAttach(req *runtimeapi.AttachRequest) (*runtimeapi.AttachResponse, error) {
if err := validateAttachRequest(req); err != nil {
return nil, err
}
token, err := s.cache.Insert(req)
if err != nil {
return nil, err
}
return &runtimeapi.AttachResponse{
Url: s.buildURL("attach", token),
}, nil
}
func (s *server) GetPortForward(req *runtimeapi.PortForwardRequest) (*runtimeapi.PortForwardResponse, error) {
if req.PodSandboxId == "" {
return nil, grpc.Errorf(codes.InvalidArgument, "missing required pod_sandbox_id")
}
token, err := s.cache.Insert(req)
if err != nil {
return nil, err
}
return &runtimeapi.PortForwardResponse{
Url: s.buildURL("portforward", token),
}, nil
}
func (s *server) Start(stayUp bool) error {
if !stayUp {
// TODO(tallclair): Implement this.
return errors.New("stayUp=false is not yet implemented")
}
if s.config.TLSConfig != nil {
return s.server.ListenAndServeTLS("", "") // Use certs from TLSConfig.
} else {
return s.server.ListenAndServe()
}
}
func (s *server) Stop() error {
return s.server.Close()
}
func (s *server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.handler.ServeHTTP(w, r)
}
func (s *server) buildURL(method, token string) string {
return s.config.BaseURL.ResolveReference(&url.URL{
Path: path.Join(method, token),
}).String()
}
func (s *server) serveExec(req *restful.Request, resp *restful.Response) {
token := req.PathParameter("token")
cachedRequest, ok := s.cache.Consume(token)
if !ok {
http.NotFound(resp.ResponseWriter, req.Request)
return
}
exec, ok := cachedRequest.(*runtimeapi.ExecRequest)
if !ok {
http.NotFound(resp.ResponseWriter, req.Request)
return
}
streamOpts := &remotecommandserver.Options{
Stdin: exec.Stdin,
Stdout: exec.Stdout,
Stderr: exec.Stderr,
TTY: exec.Tty,
}
remotecommandserver.ServeExec(
resp.ResponseWriter,
req.Request,
s.runtime,
"", // unused: podName
"", // unusued: podUID
exec.ContainerId,
exec.Cmd,
streamOpts,
s.config.StreamIdleTimeout,
s.config.StreamCreationTimeout,
s.config.SupportedRemoteCommandProtocols)
}
func (s *server) serveAttach(req *restful.Request, resp *restful.Response) {
token := req.PathParameter("token")
cachedRequest, ok := s.cache.Consume(token)
if !ok {
http.NotFound(resp.ResponseWriter, req.Request)
return
}
attach, ok := cachedRequest.(*runtimeapi.AttachRequest)
if !ok {
http.NotFound(resp.ResponseWriter, req.Request)
return
}
streamOpts := &remotecommandserver.Options{
Stdin: attach.Stdin,
Stdout: attach.Stdout,
Stderr: attach.Stderr,
TTY: attach.Tty,
}
remotecommandserver.ServeAttach(
resp.ResponseWriter,
req.Request,
s.runtime,
"", // unused: podName
"", // unusued: podUID
attach.ContainerId,
streamOpts,
s.config.StreamIdleTimeout,
s.config.StreamCreationTimeout,
s.config.SupportedRemoteCommandProtocols)
}
func (s *server) servePortForward(req *restful.Request, resp *restful.Response) {
token := req.PathParameter("token")
cachedRequest, ok := s.cache.Consume(token)
if !ok {
http.NotFound(resp.ResponseWriter, req.Request)
return
}
pf, ok := cachedRequest.(*runtimeapi.PortForwardRequest)
if !ok {
http.NotFound(resp.ResponseWriter, req.Request)
return
}
portForwardOptions, err := portforward.BuildV4Options(pf.Port)
if err != nil {
resp.WriteError(http.StatusBadRequest, err)
return
}
portforward.ServePortForward(
resp.ResponseWriter,
req.Request,
s.runtime,
pf.PodSandboxId,
"", // unused: podUID
portForwardOptions,
s.config.StreamIdleTimeout,
s.config.StreamCreationTimeout,
s.config.SupportedPortForwardProtocols)
}
// criAdapter wraps the Runtime functions to conform to the remotecommand interfaces.
// The adapter binds the container ID to the container name argument, and the pod sandbox ID to the pod name.
type criAdapter struct {
Runtime
}
var _ remotecommandserver.Executor = &criAdapter{}
var _ remotecommandserver.Attacher = &criAdapter{}
var _ portforward.PortForwarder = &criAdapter{}
func (a *criAdapter) ExecInContainer(podName string, podUID types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize, timeout time.Duration) error {
return a.Runtime.Exec(container, cmd, in, out, err, tty, resize)
}
func (a *criAdapter) AttachContainer(podName string, podUID types.UID, container string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error {
return a.Runtime.Attach(container, in, out, err, tty, resize)
}
func (a *criAdapter) PortForward(podName string, podUID types.UID, port int32, stream io.ReadWriteCloser) error {
return a.Runtime.PortForward(podName, port, stream)
}

View File

@ -0,0 +1,469 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package streaming
import (
"crypto/tls"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
restclient "k8s.io/client-go/rest"
"k8s.io/client-go/tools/remotecommand"
"k8s.io/client-go/transport/spdy"
api "k8s.io/kubernetes/pkg/apis/core"
runtimeapi "k8s.io/kubernetes/pkg/kubelet/apis/cri/v1alpha1/runtime"
kubeletportforward "k8s.io/kubernetes/pkg/kubelet/server/portforward"
)
const (
testAddr = "localhost:12345"
testContainerID = "container789"
testPodSandboxID = "pod0987"
)
func TestGetExec(t *testing.T) {
serv, err := NewServer(Config{
Addr: testAddr,
}, nil)
assert.NoError(t, err)
tlsServer, err := NewServer(Config{
Addr: testAddr,
TLSConfig: &tls.Config{},
}, nil)
assert.NoError(t, err)
const pathPrefix = "cri/shim"
prefixServer, err := NewServer(Config{
Addr: testAddr,
BaseURL: &url.URL{
Scheme: "http",
Host: testAddr,
Path: "/" + pathPrefix + "/",
},
}, nil)
assert.NoError(t, err)
assertRequestToken := func(expectedReq *runtimeapi.ExecRequest, cache *requestCache, token string) {
req, ok := cache.Consume(token)
require.True(t, ok, "token %s not found!", token)
assert.Equal(t, expectedReq, req)
}
request := &runtimeapi.ExecRequest{
ContainerId: testContainerID,
Cmd: []string{"echo", "foo"},
Tty: true,
Stdin: true,
}
{ // Non-TLS
resp, err := serv.GetExec(request)
assert.NoError(t, err)
expectedURL := "http://" + testAddr + "/exec/"
assert.Contains(t, resp.Url, expectedURL)
token := strings.TrimPrefix(resp.Url, expectedURL)
assertRequestToken(request, serv.(*server).cache, token)
}
{ // TLS
resp, err := tlsServer.GetExec(request)
assert.NoError(t, err)
expectedURL := "https://" + testAddr + "/exec/"
assert.Contains(t, resp.Url, expectedURL)
token := strings.TrimPrefix(resp.Url, expectedURL)
assertRequestToken(request, tlsServer.(*server).cache, token)
}
{ // Path prefix
resp, err := prefixServer.GetExec(request)
assert.NoError(t, err)
expectedURL := "http://" + testAddr + "/" + pathPrefix + "/exec/"
assert.Contains(t, resp.Url, expectedURL)
token := strings.TrimPrefix(resp.Url, expectedURL)
assertRequestToken(request, prefixServer.(*server).cache, token)
}
}
func TestValidateExecAttachRequest(t *testing.T) {
type config struct {
tty bool
stdin bool
stdout bool
stderr bool
}
for _, tc := range []struct {
desc string
configs []config
expectErr bool
}{
{
desc: "at least one stream must be true",
expectErr: true,
configs: []config{
{false, false, false, false},
{true, false, false, false}},
},
{
desc: "tty and stderr cannot both be true",
expectErr: true,
configs: []config{
{true, false, false, true},
{true, false, true, true},
{true, true, false, true},
{true, true, true, true},
},
},
{
desc: "a valid config should pass",
expectErr: false,
configs: []config{
{false, false, false, true},
{false, false, true, false},
{false, false, true, true},
{false, true, false, false},
{false, true, false, true},
{false, true, true, false},
{false, true, true, true},
{true, false, true, false},
{true, true, false, false},
{true, true, true, false},
},
},
} {
t.Run(tc.desc, func(t *testing.T) {
for _, c := range tc.configs {
// validate the exec request.
execReq := &runtimeapi.ExecRequest{
ContainerId: testContainerID,
Cmd: []string{"date"},
Tty: c.tty,
Stdin: c.stdin,
Stdout: c.stdout,
Stderr: c.stderr,
}
err := validateExecRequest(execReq)
assert.Equal(t, tc.expectErr, err != nil, "config: %v, err: %v", c, err)
// validate the attach request.
attachReq := &runtimeapi.AttachRequest{
ContainerId: testContainerID,
Tty: c.tty,
Stdin: c.stdin,
Stdout: c.stdout,
Stderr: c.stderr,
}
err = validateAttachRequest(attachReq)
assert.Equal(t, tc.expectErr, err != nil, "config: %v, err: %v", c, err)
}
})
}
}
func TestGetAttach(t *testing.T) {
serv, err := NewServer(Config{
Addr: testAddr,
}, nil)
require.NoError(t, err)
tlsServer, err := NewServer(Config{
Addr: testAddr,
TLSConfig: &tls.Config{},
}, nil)
require.NoError(t, err)
assertRequestToken := func(expectedReq *runtimeapi.AttachRequest, cache *requestCache, token string) {
req, ok := cache.Consume(token)
require.True(t, ok, "token %s not found!", token)
assert.Equal(t, expectedReq, req)
}
request := &runtimeapi.AttachRequest{
ContainerId: testContainerID,
Stdin: true,
Tty: true,
}
{ // Non-TLS
resp, err := serv.GetAttach(request)
assert.NoError(t, err)
expectedURL := "http://" + testAddr + "/attach/"
assert.Contains(t, resp.Url, expectedURL)
token := strings.TrimPrefix(resp.Url, expectedURL)
assertRequestToken(request, serv.(*server).cache, token)
}
{ // TLS
resp, err := tlsServer.GetAttach(request)
assert.NoError(t, err)
expectedURL := "https://" + testAddr + "/attach/"
assert.Contains(t, resp.Url, expectedURL)
token := strings.TrimPrefix(resp.Url, expectedURL)
assertRequestToken(request, tlsServer.(*server).cache, token)
}
}
func TestGetPortForward(t *testing.T) {
podSandboxID := testPodSandboxID
request := &runtimeapi.PortForwardRequest{
PodSandboxId: podSandboxID,
Port: []int32{1, 2, 3, 4},
}
{ // Non-TLS
serv, err := NewServer(Config{
Addr: testAddr,
}, nil)
assert.NoError(t, err)
resp, err := serv.GetPortForward(request)
assert.NoError(t, err)
expectedURL := "http://" + testAddr + "/portforward/"
assert.True(t, strings.HasPrefix(resp.Url, expectedURL))
token := strings.TrimPrefix(resp.Url, expectedURL)
req, ok := serv.(*server).cache.Consume(token)
require.True(t, ok, "token %s not found!", token)
assert.Equal(t, testPodSandboxID, req.(*runtimeapi.PortForwardRequest).PodSandboxId)
}
{ // TLS
tlsServer, err := NewServer(Config{
Addr: testAddr,
TLSConfig: &tls.Config{},
}, nil)
assert.NoError(t, err)
resp, err := tlsServer.GetPortForward(request)
assert.NoError(t, err)
expectedURL := "https://" + testAddr + "/portforward/"
assert.True(t, strings.HasPrefix(resp.Url, expectedURL))
token := strings.TrimPrefix(resp.Url, expectedURL)
req, ok := tlsServer.(*server).cache.Consume(token)
require.True(t, ok, "token %s not found!", token)
assert.Equal(t, testPodSandboxID, req.(*runtimeapi.PortForwardRequest).PodSandboxId)
}
}
func TestServeExec(t *testing.T) {
runRemoteCommandTest(t, "exec")
}
func TestServeAttach(t *testing.T) {
runRemoteCommandTest(t, "attach")
}
func TestServePortForward(t *testing.T) {
s, testServer := startTestServer(t)
defer testServer.Close()
resp, err := s.GetPortForward(&runtimeapi.PortForwardRequest{
PodSandboxId: testPodSandboxID,
})
require.NoError(t, err)
reqURL, err := url.Parse(resp.Url)
require.NoError(t, err)
transport, upgrader, err := spdy.RoundTripperFor(&restclient.Config{})
require.NoError(t, err)
dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", reqURL)
streamConn, _, err := dialer.Dial(kubeletportforward.ProtocolV1Name)
require.NoError(t, err)
defer streamConn.Close()
// Create the streams.
headers := http.Header{}
// Error stream is required, but unused in this test.
headers.Set(api.StreamType, api.StreamTypeError)
headers.Set(api.PortHeader, strconv.Itoa(testPort))
_, err = streamConn.CreateStream(headers)
require.NoError(t, err)
// Setup the data stream.
headers.Set(api.StreamType, api.StreamTypeData)
headers.Set(api.PortHeader, strconv.Itoa(testPort))
stream, err := streamConn.CreateStream(headers)
require.NoError(t, err)
doClientStreams(t, "portforward", stream, stream, nil)
}
//
// Run the remote command test.
// commandType is either "exec" or "attach".
func runRemoteCommandTest(t *testing.T, commandType string) {
s, testServer := startTestServer(t)
defer testServer.Close()
var reqURL *url.URL
stdin, stdout, stderr := true, true, true
containerID := testContainerID
switch commandType {
case "exec":
resp, err := s.GetExec(&runtimeapi.ExecRequest{
ContainerId: containerID,
Cmd: []string{"echo"},
Stdin: stdin,
Stdout: stdout,
Stderr: stderr,
})
require.NoError(t, err)
reqURL, err = url.Parse(resp.Url)
require.NoError(t, err)
case "attach":
resp, err := s.GetAttach(&runtimeapi.AttachRequest{
ContainerId: containerID,
Stdin: stdin,
Stdout: stdout,
Stderr: stderr,
})
require.NoError(t, err)
reqURL, err = url.Parse(resp.Url)
require.NoError(t, err)
}
wg := sync.WaitGroup{}
wg.Add(2)
stdinR, stdinW := io.Pipe()
stdoutR, stdoutW := io.Pipe()
stderrR, stderrW := io.Pipe()
go func() {
defer wg.Done()
exec, err := remotecommand.NewSPDYExecutor(&restclient.Config{}, "POST", reqURL)
require.NoError(t, err)
opts := remotecommand.StreamOptions{
Stdin: stdinR,
Stdout: stdoutW,
Stderr: stderrW,
Tty: false,
}
require.NoError(t, exec.Stream(opts))
}()
go func() {
defer wg.Done()
doClientStreams(t, commandType, stdinW, stdoutR, stderrR)
}()
wg.Wait()
// Repeat request with the same URL should be a 404.
resp, err := http.Get(reqURL.String())
require.NoError(t, err)
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
}
func startTestServer(t *testing.T) (Server, *httptest.Server) {
var s Server
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.ServeHTTP(w, r)
}))
cleanup := true
defer func() {
if cleanup {
testServer.Close()
}
}()
testURL, err := url.Parse(testServer.URL)
require.NoError(t, err)
rt := newFakeRuntime(t)
config := DefaultConfig
config.BaseURL = testURL
s, err = NewServer(config, rt)
require.NoError(t, err)
cleanup = false // Caller must close the test server.
return s, testServer
}
const (
testInput = "abcdefg"
testOutput = "fooBARbaz"
testErr = "ERROR!!!"
testPort = 12345
)
func newFakeRuntime(t *testing.T) *fakeRuntime {
return &fakeRuntime{
t: t,
}
}
type fakeRuntime struct {
t *testing.T
}
func (f *fakeRuntime) Exec(containerID string, cmd []string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error {
assert.Equal(f.t, testContainerID, containerID)
doServerStreams(f.t, "exec", stdin, stdout, stderr)
return nil
}
func (f *fakeRuntime) Attach(containerID string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error {
assert.Equal(f.t, testContainerID, containerID)
doServerStreams(f.t, "attach", stdin, stdout, stderr)
return nil
}
func (f *fakeRuntime) PortForward(podSandboxID string, port int32, stream io.ReadWriteCloser) error {
assert.Equal(f.t, testPodSandboxID, podSandboxID)
assert.EqualValues(f.t, testPort, port)
doServerStreams(f.t, "portforward", stream, stream, nil)
return nil
}
// Send & receive expected input/output. Must be the inverse of doClientStreams.
// Function will block until the expected i/o is finished.
func doServerStreams(t *testing.T, prefix string, stdin io.Reader, stdout, stderr io.Writer) {
if stderr != nil {
writeExpected(t, "server stderr", stderr, prefix+testErr)
}
readExpected(t, "server stdin", stdin, prefix+testInput)
writeExpected(t, "server stdout", stdout, prefix+testOutput)
}
// Send & receive expected input/output. Must be the inverse of doServerStreams.
// Function will block until the expected i/o is finished.
func doClientStreams(t *testing.T, prefix string, stdin io.Writer, stdout, stderr io.Reader) {
if stderr != nil {
readExpected(t, "client stderr", stderr, prefix+testErr)
}
writeExpected(t, "client stdin", stdin, prefix+testInput)
readExpected(t, "client stdout", stdout, prefix+testOutput)
}
// Read and verify the expected string from the stream.
func readExpected(t *testing.T, streamName string, r io.Reader, expected string) {
result := make([]byte, len(expected))
_, err := io.ReadAtLeast(r, result, len(expected))
assert.NoError(t, err, "stream %s", streamName)
assert.Equal(t, expected, string(result), "stream %s", streamName)
}
// Write and verify success of the data over the stream.
func writeExpected(t *testing.T, streamName string, w io.Writer, data string) {
n, err := io.WriteString(w, data)
assert.NoError(t, err, "stream %s", streamName)
assert.Equal(t, len(data), n, "stream %s", streamName)
}