mirror of
https://github.com/ceph/ceph-csi.git
synced 2024-12-04 20:20:19 +00:00
903 lines
27 KiB
Go
903 lines
27 KiB
Go
|
package kmip
|
||
|
|
||
|
// This is a WIP implementation of a KMIP server. The code is mostly based on the http server in
|
||
|
// the golang standard library. It is functional, but not all of the features of the http server
|
||
|
// have been ported over yet, and some of the stuff in here still refers to http stuff.
|
||
|
//
|
||
|
// The responsibility of handling a request is broken up into 3 layers of handlers: ProtocolHandler, MessageHandler,
|
||
|
// and ItemHandler. Each of these handlers delegates details to the next layer. Using the http
|
||
|
// package as an analogy, ProtocolHandler is similar to the wire-level HTTP protocol handling in
|
||
|
// http.Server and http.Transport. MessageHandler parses KMIP TTLV bytes into golang request and response structs.
|
||
|
// ItemHandler is a bit like http.ServeMux, routing particular KMIP operations to registered handlers.
|
||
|
|
||
|
import (
|
||
|
"bufio"
|
||
|
"bytes"
|
||
|
"context"
|
||
|
"crypto/tls"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"net"
|
||
|
"runtime"
|
||
|
"sync"
|
||
|
"sync/atomic"
|
||
|
"time"
|
||
|
|
||
|
"github.com/ansel1/merry"
|
||
|
"github.com/gemalto/flume"
|
||
|
"github.com/gemalto/kmip-go/kmip14"
|
||
|
"github.com/gemalto/kmip-go/ttlv"
|
||
|
"github.com/google/uuid"
|
||
|
)
|
||
|
|
||
|
var serverLog = flume.New("kmip_server")
|
||
|
|
||
|
// Server serves KMIP protocol connections from a net.Listener. Because KMIP is a connection-oriented
|
||
|
// protocol, unlike HTTP, each connection ends up being serviced by a dedicated goroutine (rather than
|
||
|
// each request). For each KMIP connection, requests are processed serially. The handling
|
||
|
// of the request is delegated to the ProtocolHandler.
|
||
|
//
|
||
|
// Limitations:
|
||
|
//
|
||
|
// This implementation is functional (it can respond to KMIP requests), but incomplete. Some of the
|
||
|
// connection management features of the http package haven't been ported over, and also, there is
|
||
|
// currently no connection-context in which to store things like an authentication or session management.
|
||
|
// Since HTTP is an intrinsically stateless model, it makes sense for the http package to delegate session
|
||
|
// management to third party packages, but for KMIP, it would makes sense for there to be some first
|
||
|
// class support for a connection context.
|
||
|
//
|
||
|
// This package also only handles the binary TTLV encoding for now. It may make sense for this
|
||
|
// server to detect or support the XML and JSON encodings as well. It may also makes sense to support
|
||
|
// KMIP requests over HTTP, perhaps by adapting ProtocolHandler to an http.Handler or something.
|
||
|
type Server struct {
|
||
|
Handler ProtocolHandler
|
||
|
|
||
|
mu sync.Mutex
|
||
|
listeners map[*net.Listener]struct{}
|
||
|
inShutdown int32 // accessed atomically (non-zero means we're in Shutdown)
|
||
|
}
|
||
|
|
||
|
// ErrServerClosed is returned by the Server's Serve method after a call to
// Shutdown or Close. Callers should compare against the sentinel itself
// (== or errors.Is), not against its message.
var ErrServerClosed = errors.New("kmip: Server closed")
|
||
|
|
||
|
// Serve accepts incoming connections on the Listener l, creating a
|
||
|
// new service goroutine for each. The service goroutines read requests and
|
||
|
// then call srv.MessageHandler to reply to them.
|
||
|
//
|
||
|
// Serve always returns a non-nil error and closes l.
|
||
|
// After Shutdown or Close, the returned error is ErrServerClosed.
|
||
|
func (srv *Server) Serve(l net.Listener) error {
|
||
|
//if fn := testHookServerServe; fn != nil {
|
||
|
// fn(srv, l) // call hook with unwrapped listener
|
||
|
//}
|
||
|
|
||
|
l = &onceCloseListener{Listener: l}
|
||
|
defer l.Close()
|
||
|
|
||
|
if !srv.trackListener(&l, true) {
|
||
|
return ErrServerClosed
|
||
|
}
|
||
|
defer srv.trackListener(&l, false)
|
||
|
|
||
|
var tempDelay time.Duration // how long to sleep on accept failure
|
||
|
baseCtx := context.Background() // base is always background, per Issue 16220
|
||
|
ctx := baseCtx
|
||
|
// ctx := context.WithValue(baseCtx, ServerContextKey, srv)
|
||
|
for {
|
||
|
rw, e := l.Accept()
|
||
|
if e != nil {
|
||
|
if srv.shuttingDown() {
|
||
|
return ErrServerClosed
|
||
|
}
|
||
|
if ne, ok := e.(net.Error); ok && ne.Temporary() {
|
||
|
if tempDelay == 0 {
|
||
|
tempDelay = 5 * time.Millisecond
|
||
|
} else {
|
||
|
tempDelay *= 2
|
||
|
}
|
||
|
if max := 1 * time.Second; tempDelay > max {
|
||
|
tempDelay = max
|
||
|
}
|
||
|
// srv.logf("http: Accept error: %v; retrying in %v", e, tempDelay)
|
||
|
time.Sleep(tempDelay)
|
||
|
continue
|
||
|
}
|
||
|
return e
|
||
|
}
|
||
|
tempDelay = 0
|
||
|
c := &conn{server: srv, rwc: rw}
|
||
|
// c.setState(c.rwc, StateNew) // before Serve can return
|
||
|
go c.serve(ctx)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Close immediately closes all active net.Listeners and any
|
||
|
// connections in state StateNew, StateActive, or StateIdle. For a
|
||
|
// graceful shutdown, use Shutdown.
|
||
|
//
|
||
|
// Close does not attempt to close (and does not even know about)
|
||
|
// any hijacked connections, such as WebSockets.
|
||
|
//
|
||
|
// Close returns any error returned from closing the Server's
|
||
|
// underlying Listener(s).
|
||
|
func (srv *Server) Close() error {
|
||
|
atomic.StoreInt32(&srv.inShutdown, 1)
|
||
|
srv.mu.Lock()
|
||
|
defer srv.mu.Unlock()
|
||
|
// srv.closeDoneChanLocked()
|
||
|
err := srv.closeListenersLocked()
|
||
|
//for c := range srv.activeConn {
|
||
|
// c.rwc.Close()
|
||
|
// delete(srv.activeConn, c)
|
||
|
//}
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// shutdownPollInterval is the interval at which Server.Shutdown is meant to
// poll for quiescence. It is a variable rather than a constant so tests can
// lower it. Ideally we could find a solution that doesn't involve polling,
// but which also doesn't have a high runtime cost (and doesn't involve any
// contentious mutexes), but that is left as an exercise for the reader.
var shutdownPollInterval = time.Second / 2
|
||
|
|
||
|
// Shutdown gracefully shuts down the server without interrupting any
|
||
|
// active connections. Shutdown works by first closing all open
|
||
|
// listeners, then closing all idle connections, and then waiting
|
||
|
// indefinitely for connections to return to idle and then shut down.
|
||
|
// If the provided context expires before the shutdown is complete,
|
||
|
// Shutdown returns the context's error, otherwise it returns any
|
||
|
// error returned from closing the Server's underlying Listener(s).
|
||
|
//
|
||
|
// When Shutdown is called, Serve, ListenAndServe, and
|
||
|
// ListenAndServeTLS immediately return ErrServerClosed. Make sure the
|
||
|
// program doesn't exit and waits instead for Shutdown to return.
|
||
|
//
|
||
|
// Shutdown does not attempt to close nor wait for hijacked
|
||
|
// connections such as WebSockets. The caller of Shutdown should
|
||
|
// separately notify such long-lived connections of shutdown and wait
|
||
|
// for them to close, if desired. See RegisterOnShutdown for a way to
|
||
|
// register shutdown notification functions.
|
||
|
//
|
||
|
// Once Shutdown has been called on a server, it may not be reused;
|
||
|
// future calls to methods such as Serve will return ErrServerClosed.
|
||
|
func (srv *Server) Shutdown(ctx context.Context) error {
|
||
|
atomic.StoreInt32(&srv.inShutdown, 1)
|
||
|
|
||
|
srv.mu.Lock()
|
||
|
lnerr := srv.closeListenersLocked()
|
||
|
//srv.closeDoneChanLocked()
|
||
|
//for _, f := range srv.onShutdown {
|
||
|
// go f()
|
||
|
//}
|
||
|
srv.mu.Unlock()
|
||
|
|
||
|
ticker := time.NewTicker(shutdownPollInterval)
|
||
|
defer ticker.Stop()
|
||
|
return lnerr
|
||
|
//for {
|
||
|
// if srv.closeIdleConns() {
|
||
|
// return lnerr
|
||
|
// }
|
||
|
// select {
|
||
|
// case <-ctx.Done():
|
||
|
// return ctx.Err()
|
||
|
// case <-ticker.C:
|
||
|
// }
|
||
|
//}
|
||
|
}
|
||
|
|
||
|
func (srv *Server) closeListenersLocked() error {
|
||
|
var err error
|
||
|
for ln := range srv.listeners {
|
||
|
if cerr := (*ln).Close(); cerr != nil && err == nil {
|
||
|
err = cerr
|
||
|
}
|
||
|
delete(srv.listeners, ln)
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// trackListener adds or removes a net.Listener to the set of tracked
|
||
|
// listeners.
|
||
|
//
|
||
|
// We store a pointer to interface in the map set, in case the
|
||
|
// net.Listener is not comparable. This is safe because we only call
|
||
|
// trackListener via Serve and can track+defer untrack the same
|
||
|
// pointer to local variable there. We never need to compare a
|
||
|
// Listener from another caller.
|
||
|
//
|
||
|
// It reports whether the server is still up (not Shutdown or Closed).
|
||
|
func (srv *Server) trackListener(ln *net.Listener, add bool) bool {
|
||
|
srv.mu.Lock()
|
||
|
defer srv.mu.Unlock()
|
||
|
if srv.listeners == nil {
|
||
|
srv.listeners = make(map[*net.Listener]struct{})
|
||
|
}
|
||
|
if add {
|
||
|
if srv.shuttingDown() {
|
||
|
return false
|
||
|
}
|
||
|
srv.listeners[ln] = struct{}{}
|
||
|
} else {
|
||
|
delete(srv.listeners, ln)
|
||
|
}
|
||
|
return true
|
||
|
}
|
||
|
|
||
|
func (srv *Server) shuttingDown() bool {
|
||
|
return atomic.LoadInt32(&srv.inShutdown) != 0
|
||
|
}
|
||
|
|
||
|
type conn struct {
|
||
|
rwc net.Conn
|
||
|
remoteAddr string
|
||
|
localAddr string
|
||
|
tlsState *tls.ConnectionState
|
||
|
// cancelCtx cancels the connection-level context.
|
||
|
cancelCtx context.CancelFunc
|
||
|
|
||
|
// bufr reads from rwc.
|
||
|
bufr *bufio.Reader
|
||
|
dec *ttlv.Decoder
|
||
|
|
||
|
server *Server
|
||
|
}
|
||
|
|
||
|
func (c *conn) close() {
|
||
|
// TODO: http package has a buffered writer on the conn too, which is flushed here
|
||
|
_ = c.rwc.Close()
|
||
|
}
|
||
|
|
||
|
// Serve a new connection.
|
||
|
func (c *conn) serve(ctx context.Context) {
|
||
|
ctx = flume.WithLogger(ctx, serverLog)
|
||
|
ctx, cancelCtx := context.WithCancel(ctx)
|
||
|
c.cancelCtx = cancelCtx
|
||
|
c.remoteAddr = c.rwc.RemoteAddr().String()
|
||
|
c.localAddr = c.rwc.LocalAddr().String()
|
||
|
// ctx = context.WithValue(ctx, LocalAddrContextKey, c.rwc.LocalAddr())
|
||
|
defer func() {
|
||
|
if err := recover(); err != nil {
|
||
|
// TODO: logging support
|
||
|
// if err := recover(); err != nil && err != ErrAbortHandler {
|
||
|
const size = 64 << 10
|
||
|
buf := make([]byte, size)
|
||
|
buf = buf[:runtime.Stack(buf, false)]
|
||
|
if e, ok := err.(error); ok {
|
||
|
fmt.Printf("kmip: panic serving %v: %v\n%s", c.remoteAddr, Details(e), buf)
|
||
|
} else {
|
||
|
fmt.Printf("kmip: panic serving %v: %v\n%s", c.remoteAddr, err, buf)
|
||
|
}
|
||
|
|
||
|
// c.server.logf("http: panic serving %v: %v\n%s", c.remoteAddr, err, buf)
|
||
|
}
|
||
|
cancelCtx()
|
||
|
// if !c.hijacked() {
|
||
|
c.close()
|
||
|
// c.setState(c.rwc, StateClosed)
|
||
|
//}
|
||
|
}()
|
||
|
|
||
|
if tlsConn, ok := c.rwc.(*tls.Conn); ok {
|
||
|
//if d := c.server.ReadTimeout; d != 0 {
|
||
|
// c.rwc.SetReadDeadline(time.Now().Add(d))
|
||
|
//}
|
||
|
//if d := c.server.WriteTimeout; d != 0 {
|
||
|
// c.rwc.SetWriteDeadline(time.Now().Add(d))
|
||
|
//}
|
||
|
if err := tlsConn.Handshake(); err != nil {
|
||
|
// TODO: logging support
|
||
|
fmt.Printf("kmip: TLS handshake error from %s: %v", c.rwc.RemoteAddr(), err)
|
||
|
// c.server.logf("http: TLS handshake error from %s: %v", c.rwc.RemoteAddr(), err)
|
||
|
return
|
||
|
}
|
||
|
c.tlsState = new(tls.ConnectionState)
|
||
|
*c.tlsState = tlsConn.ConnectionState()
|
||
|
//if proto := c.tlsState.NegotiatedProtocol; validNPN(proto) {
|
||
|
// if fn := c.server.TLSNextProto[proto]; fn != nil {
|
||
|
// h := initNPNRequest{tlsConn, serverHandler{c.server}}
|
||
|
// fn(c.server, tlsConn, h)
|
||
|
// }
|
||
|
// return
|
||
|
//}
|
||
|
}
|
||
|
|
||
|
// TODO: do we really need instance pooling here? We expect KMIP connections to be long lasting
|
||
|
c.dec = ttlv.NewDecoder(c.rwc)
|
||
|
c.bufr = bufio.NewReader(c.rwc)
|
||
|
// c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10)
|
||
|
|
||
|
for {
|
||
|
w, err := c.readRequest(ctx)
|
||
|
//if c.r.remain != c.server.initialReadLimitSize() {
|
||
|
// If we read any bytes off the wire, we're active.
|
||
|
//c.setState(c.rwc, StateActive)
|
||
|
//}
|
||
|
if err != nil {
|
||
|
if merry.Is(err, io.EOF) {
|
||
|
fmt.Println("client closed connection")
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// TODO: do something with this error
|
||
|
panic(err)
|
||
|
//const errorHeaders= "\r\nContent-Type: text/plain; charset=utf-8\r\nConnection: close\r\n\r\n"
|
||
|
//
|
||
|
//if err == errTooLarge {
|
||
|
// // Their HTTP client may or may not be
|
||
|
// // able to read this if we're
|
||
|
// // responding to them and hanging up
|
||
|
// // while they're still writing their
|
||
|
// // request. Undefined behavior.
|
||
|
// const publicErr= "431 Request Header Fields Too Large"
|
||
|
// fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr)
|
||
|
// c.closeWriteAndWait()
|
||
|
// return
|
||
|
//}
|
||
|
//if isCommonNetReadError(err) {
|
||
|
// return // don't reply
|
||
|
//}
|
||
|
//
|
||
|
//publicErr := "400 Bad Request"
|
||
|
//if v, ok := err.(badRequestError); ok {
|
||
|
// publicErr = publicErr + ": " + string(v)
|
||
|
//}
|
||
|
//
|
||
|
//fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr)
|
||
|
//return
|
||
|
}
|
||
|
|
||
|
// Expect 100 Continue support
|
||
|
//req := w.req
|
||
|
//if req.expectsContinue() {
|
||
|
// if req.ProtoAtLeast(1, 1) && req.ContentLength != 0 {
|
||
|
// // Wrap the Body reader with one that replies on the connection
|
||
|
// req.Body = &expectContinueReader{readCloser: req.Body, resp: w}
|
||
|
// }
|
||
|
//} else if req.Header.get("Expect") != "" {
|
||
|
// w.sendExpectationFailed()
|
||
|
// return
|
||
|
//}
|
||
|
|
||
|
// c.curReq.Store(w)
|
||
|
|
||
|
//if requestBodyRemains(req.Body) {
|
||
|
// registerOnHitEOF(req.Body, w.conn.r.startBackgroundRead)
|
||
|
//} else {
|
||
|
// w.conn.r.startBackgroundRead()
|
||
|
//}
|
||
|
|
||
|
// HTTP cannot have multiple simultaneous active requests.[*]
|
||
|
// Until the server replies to this request, it can't read another,
|
||
|
// so we might as well run the handler in this goroutine.
|
||
|
// [*] Not strictly true: HTTP pipelining. We could let them all process
|
||
|
// in parallel even if their responses need to be serialized.
|
||
|
// But we're not going to implement HTTP pipelining because it
|
||
|
// was never deployed in the wild and the answer is HTTP/2.
|
||
|
|
||
|
h := c.server.Handler
|
||
|
if h == nil {
|
||
|
h = DefaultProtocolHandler
|
||
|
}
|
||
|
|
||
|
// var resp ResponseMessage
|
||
|
// err = c.server.MessageHandler.Handle(ctx, w, &resp)
|
||
|
// TODO: this cancelCtx() was created at the connection level, not the request level. Need to
|
||
|
// figure out how to handle connection vs request timeouts and cancels.
|
||
|
// cancelCtx()
|
||
|
|
||
|
// TODO: use recycled buffered writer
|
||
|
writer := bufio.NewWriter(c.rwc)
|
||
|
h.ServeKMIP(ctx, w, writer)
|
||
|
err = writer.Flush()
|
||
|
if err != nil {
|
||
|
// TODO: handle error
|
||
|
panic(err)
|
||
|
}
|
||
|
|
||
|
//serverHandler{c.server}.ServeHTTP(w, w.req)
|
||
|
//w.cancelCtx()
|
||
|
//if c.hijacked() {
|
||
|
// return
|
||
|
//}
|
||
|
//w.finishRequest()
|
||
|
//if !w.shouldReuseConnection() {
|
||
|
// if w.requestBodyLimitHit || w.closedRequestBodyEarly() {
|
||
|
// c.closeWriteAndWait()
|
||
|
// }
|
||
|
// return
|
||
|
//}
|
||
|
//c.setState(c.rwc, StateIdle)
|
||
|
//c.curReq.Store((*response)(nil))
|
||
|
|
||
|
//if !w.conn.server.doKeepAlives() {
|
||
|
// // We're in shutdown mode. We might've replied
|
||
|
// // to the user without "Connection: close" and
|
||
|
// // they might think they can send another
|
||
|
// // request, but such is life with HTTP/1.1.
|
||
|
// return
|
||
|
//}
|
||
|
//
|
||
|
//if d := c.server.idleTimeout(); d != 0 {
|
||
|
// c.rwc.SetReadDeadline(time.Now().Add(d))
|
||
|
// if _, err := c.bufr.Peek(4); err != nil {
|
||
|
// return
|
||
|
// }
|
||
|
//}
|
||
|
//c.rwc.SetReadDeadline(time.Time{})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Read next request from connection.
|
||
|
func (c *conn) readRequest(ctx context.Context) (w *Request, err error) {
|
||
|
//if c.hijacked() {
|
||
|
// return nil, ErrHijacked
|
||
|
//}
|
||
|
|
||
|
//var (
|
||
|
// wholeReqDeadline time.Time // or zero if none
|
||
|
// hdrDeadline time.Time // or zero if none
|
||
|
//)
|
||
|
//t0 := time.Now()
|
||
|
//if d := c.server.readHeaderTimeout(); d != 0 {
|
||
|
// hdrDeadline = t0.Add(d)
|
||
|
//}
|
||
|
//if d := c.server.ReadTimeout; d != 0 {
|
||
|
// wholeReqDeadline = t0.Add(d)
|
||
|
//}
|
||
|
//c.rwc.SetReadDeadline(hdrDeadline)
|
||
|
//if d := c.server.WriteTimeout; d != 0 {
|
||
|
// defer func() {
|
||
|
// c.rwc.SetWriteDeadline(time.Now().Add(d))
|
||
|
// }()
|
||
|
//}
|
||
|
|
||
|
//c.r.setReadLimit(c.server.initialReadLimitSize())
|
||
|
//if c.lastMethod == "POST" {
|
||
|
// RFC 7230 section 3 tolerance for old buggy clients.
|
||
|
//peek, _ := c.bufr.Peek(4) // ReadRequest will get err below
|
||
|
//c.bufr.Discard(numLeadingCRorLF(peek))
|
||
|
//}
|
||
|
ttlvVal, err := c.dec.NextTTLV()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
//if err != nil {
|
||
|
//if c.r.hitReadLimit() {
|
||
|
// return nil, errTooLarge
|
||
|
//}
|
||
|
//}
|
||
|
|
||
|
// TODO: use pooling to recycle requests?
|
||
|
req := &Request{
|
||
|
TTLV: ttlvVal,
|
||
|
RemoteAddr: c.remoteAddr,
|
||
|
LocalAddr: c.localAddr,
|
||
|
TLS: c.tlsState,
|
||
|
}
|
||
|
|
||
|
// c.r.setInfiniteReadLimit()
|
||
|
|
||
|
// Adjust the read deadline if necessary.
|
||
|
//if !hdrDeadline.Equal(wholeReqDeadline) {
|
||
|
// c.rwc.SetReadDeadline(wholeReqDeadline)
|
||
|
//}
|
||
|
|
||
|
return req, nil
|
||
|
}
|
||
|
|
||
|
// Request represents a KMIP request.
|
||
|
type Request struct {
|
||
|
// TTLV will hold the entire body of the request.
|
||
|
TTLV ttlv.TTLV
|
||
|
Message *RequestMessage
|
||
|
CurrentItem *RequestBatchItem
|
||
|
DisallowExtraValues bool
|
||
|
|
||
|
// TLS holds the TLS state of the connection this request was received on.
|
||
|
TLS *tls.ConnectionState
|
||
|
RemoteAddr string
|
||
|
LocalAddr string
|
||
|
|
||
|
IDPlaceholder string
|
||
|
|
||
|
decoder *ttlv.Decoder
|
||
|
}
|
||
|
|
||
|
// coerceToTTLV attempts to coerce an interface value to TTLV.
|
||
|
// In most production scenarios, this is intended to be used in
|
||
|
// places where the value is already a TTLV, and just needs to be
|
||
|
// type cast. If v is not TTLV, it will be marshaled. This latter
|
||
|
// behavior is slow, so it should be used only in tests.
|
||
|
func coerceToTTLV(v interface{}) (ttlv.TTLV, error) {
|
||
|
switch t := v.(type) {
|
||
|
case nil:
|
||
|
return nil, nil
|
||
|
case ttlv.TTLV:
|
||
|
return t, nil
|
||
|
default:
|
||
|
return ttlv.Marshal(v)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Unmarshal unmarshals ttlv into structures. Handlers should prefer this
|
||
|
// method over than their own Decoders or Unmarshal(). This method
|
||
|
// enforces rules about whether extra fields are allowed, and reuses
|
||
|
// buffers for efficiency.
|
||
|
func (r *Request) Unmarshal(ttlv ttlv.TTLV, into interface{}) error {
|
||
|
if len(ttlv) == 0 {
|
||
|
return nil
|
||
|
}
|
||
|
r.decoder.Reset(bytes.NewReader(ttlv))
|
||
|
return r.decoder.Decode(into)
|
||
|
}
|
||
|
|
||
|
func (r *Request) DecodePayload(v interface{}) error {
|
||
|
if r.CurrentItem == nil {
|
||
|
return nil
|
||
|
}
|
||
|
ttlvVal, err := coerceToTTLV(r.CurrentItem.RequestPayload)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return r.Unmarshal(ttlvVal, v)
|
||
|
}
|
||
|
|
||
|
// onceCloseListener wraps a net.Listener so that the wrapped listener is
// closed at most once, no matter how many times Close is called. Every call
// to Close returns the error produced by that single underlying close.
type onceCloseListener struct {
	net.Listener
	once     sync.Once
	closeErr error
}

// Close closes the wrapped listener on the first call and returns the
// recorded result of that close on every call.
func (oc *onceCloseListener) Close() error {
	oc.once.Do(func() {
		oc.close()
	})
	return oc.closeErr
}

// close performs the real close and records its error. It must only be
// invoked through oc.once.
func (oc *onceCloseListener) close() {
	oc.closeErr = oc.Listener.Close()
}
|
||
|
|
||
|
type ResponseWriter interface {
|
||
|
io.Writer
|
||
|
}
|
||
|
|
||
|
// ProtocolHandler is responsible for handling raw requests read off the wire. The
|
||
|
// *Request object will only have TTLV field populated. The response should
|
||
|
// be written directly to the ResponseWriter.
|
||
|
//
|
||
|
// The default implemention of ProtocolHandler is StandardProtocolHandler.
|
||
|
type ProtocolHandler interface {
|
||
|
ServeKMIP(ctx context.Context, req *Request, resp ResponseWriter)
|
||
|
}
|
||
|
|
||
|
// MessageHandler handles KMIP requests which have already be decoded. The *Request
|
||
|
// object's Message field will be populated from the decoded TTLV. The *Response
|
||
|
// object will always be non-nil, and its ResponseHeader will be populated. The
|
||
|
// MessageHandler usually shouldn't modify the ResponseHeader: the ProtocolHandler
|
||
|
// is responsible for the header. The MessageHandler just needs to populate
|
||
|
// the response batch items.
|
||
|
//
|
||
|
// The default implementation of MessageHandler is OperationMux.
|
||
|
type MessageHandler interface {
|
||
|
HandleMessage(ctx context.Context, req *Request, resp *Response)
|
||
|
}
|
||
|
|
||
|
// ItemHandler handles a single batch item in a KMIP request. The *Request
|
||
|
// object's CurrentItem field will be populated with item to be handled.
|
||
|
type ItemHandler interface {
|
||
|
HandleItem(ctx context.Context, req *Request) (item *ResponseBatchItem, err error)
|
||
|
}
|
||
|
|
||
|
type ProtocolHandlerFunc func(context.Context, *Request, ResponseWriter)
|
||
|
|
||
|
func (f ProtocolHandlerFunc) ServeKMIP(ctx context.Context, r *Request, w ResponseWriter) {
|
||
|
f(ctx, r, w)
|
||
|
}
|
||
|
|
||
|
type MessageHandlerFunc func(context.Context, *Request, *Response)
|
||
|
|
||
|
func (f MessageHandlerFunc) HandleMessage(ctx context.Context, req *Request, resp *Response) {
|
||
|
f(ctx, req, resp)
|
||
|
}
|
||
|
|
||
|
type ItemHandlerFunc func(context.Context, *Request) (*ResponseBatchItem, error)
|
||
|
|
||
|
func (f ItemHandlerFunc) HandleItem(ctx context.Context, req *Request) (item *ResponseBatchItem, err error) {
|
||
|
return f(ctx, req)
|
||
|
}
|
||
|
|
||
|
var DefaultProtocolHandler = &StandardProtocolHandler{
|
||
|
MessageHandler: DefaultOperationMux,
|
||
|
ProtocolVersion: ProtocolVersion{
|
||
|
ProtocolVersionMajor: 1,
|
||
|
ProtocolVersionMinor: 4,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
var DefaultOperationMux = &OperationMux{}
|
||
|
|
||
|
// StandardProtocolHandler is the default ProtocolHandler implementation. It
|
||
|
// handles decoding the request and encoding the response, as well as protocol
|
||
|
// level tasks like version negotiation and correlation values.
|
||
|
//
|
||
|
// It delegates handling of the request to a MessageHandler.
|
||
|
type StandardProtocolHandler struct {
|
||
|
ProtocolVersion ProtocolVersion
|
||
|
MessageHandler MessageHandler
|
||
|
|
||
|
LogTraffic bool
|
||
|
}
|
||
|
|
||
|
func (h *StandardProtocolHandler) parseMessage(ctx context.Context, req *Request) error {
|
||
|
ttlvV := req.TTLV
|
||
|
if err := ttlvV.Valid(); err != nil {
|
||
|
return merry.Prepend(err, "invalid ttlv")
|
||
|
}
|
||
|
|
||
|
if ttlvV.Tag() != kmip14.TagRequestMessage {
|
||
|
return merry.Errorf("invalid tag: expected RequestMessage, was %s", ttlvV.Tag().String())
|
||
|
}
|
||
|
|
||
|
var message RequestMessage
|
||
|
err := ttlv.Unmarshal(ttlvV, &message)
|
||
|
if err != nil {
|
||
|
return merry.Prepend(err, "failed to parse message")
|
||
|
}
|
||
|
|
||
|
req.Message = &message
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
var responsePool = sync.Pool{}
|
||
|
|
||
|
type Response struct {
|
||
|
ResponseMessage
|
||
|
buf bytes.Buffer
|
||
|
enc *ttlv.Encoder
|
||
|
}
|
||
|
|
||
|
func newResponse() *Response {
|
||
|
v := responsePool.Get()
|
||
|
if v != nil {
|
||
|
r := v.(*Response)
|
||
|
r.reset()
|
||
|
return r
|
||
|
}
|
||
|
r := Response{}
|
||
|
r.enc = ttlv.NewEncoder(&r.buf)
|
||
|
return &r
|
||
|
}
|
||
|
|
||
|
func releaseResponse(r *Response) {
|
||
|
responsePool.Put(r)
|
||
|
}
|
||
|
|
||
|
func (r *Response) reset() {
|
||
|
r.BatchItem = nil
|
||
|
r.ResponseMessage = ResponseMessage{}
|
||
|
r.buf.Reset()
|
||
|
}
|
||
|
|
||
|
func (r *Response) Bytes() []byte {
|
||
|
r.buf.Reset()
|
||
|
err := r.enc.Encode(&r.ResponseMessage)
|
||
|
if err != nil {
|
||
|
panic(err)
|
||
|
}
|
||
|
|
||
|
return r.buf.Bytes()
|
||
|
}
|
||
|
|
||
|
func (r *Response) errorResponse(reason kmip14.ResultReason, msg string) {
|
||
|
r.BatchItem = []ResponseBatchItem{
|
||
|
{
|
||
|
ResultStatus: kmip14.ResultStatusOperationFailed,
|
||
|
ResultReason: reason,
|
||
|
ResultMessage: msg,
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (h *StandardProtocolHandler) handleRequest(ctx context.Context, req *Request, resp *Response) (logger flume.Logger) {
|
||
|
// create a server correlation value, which is like a unique transaction ID
|
||
|
scv := uuid.New().String()
|
||
|
|
||
|
// create a logger for the transaction, seeded with the scv
|
||
|
logger = flume.FromContext(ctx).With("scv", scv)
|
||
|
// attach the logger to the context, so it is available to the handling chain
|
||
|
ctx = flume.WithLogger(ctx, logger)
|
||
|
|
||
|
// TODO: it's unclear how the full protocol negogiation is supposed to work
|
||
|
// should server be pinned to a particular version? Or should we try and negogiate a common version?
|
||
|
resp.ResponseHeader.ProtocolVersion = h.ProtocolVersion
|
||
|
resp.ResponseHeader.TimeStamp = time.Now()
|
||
|
resp.ResponseHeader.BatchCount = len(resp.BatchItem)
|
||
|
resp.ResponseHeader.ServerCorrelationValue = scv
|
||
|
|
||
|
if err := h.parseMessage(ctx, req); err != nil {
|
||
|
resp.errorResponse(kmip14.ResultReasonInvalidMessage, err.Error())
|
||
|
return
|
||
|
}
|
||
|
|
||
|
ccv := req.Message.RequestHeader.ClientCorrelationValue
|
||
|
// add the client correlation value to the logging context. This value uniquely
|
||
|
// identifies the client, and is supposed to be included in server logs
|
||
|
logger = logger.With("ccv", ccv)
|
||
|
ctx = flume.WithLogger(ctx, logger)
|
||
|
resp.ResponseHeader.ClientCorrelationValue = req.Message.RequestHeader.ClientCorrelationValue
|
||
|
|
||
|
clientMajorVersion := req.Message.RequestHeader.ProtocolVersion.ProtocolVersionMajor
|
||
|
if clientMajorVersion != h.ProtocolVersion.ProtocolVersionMajor {
|
||
|
resp.errorResponse(kmip14.ResultReasonInvalidMessage,
|
||
|
fmt.Sprintf("mismatched protocol versions, client: %d, server: %d", clientMajorVersion, h.ProtocolVersion.ProtocolVersionMajor))
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// set a flag hinting to handlers that extra fields should not be tolerated when
|
||
|
// unmarshaling payloads. According to spec, if server and client protocol version
|
||
|
// minor versions match, then extra fields should cause an error. Not sure how to enforce
|
||
|
// this in this higher level handler, since we (the protocol/message handlers) don't unmarshal the payload.
|
||
|
// That's done by a particular item handler.
|
||
|
req.DisallowExtraValues = req.Message.RequestHeader.ProtocolVersion.ProtocolVersionMinor == h.ProtocolVersion.ProtocolVersionMinor
|
||
|
req.decoder = ttlv.NewDecoder(nil)
|
||
|
req.decoder.DisallowExtraValues = req.DisallowExtraValues
|
||
|
|
||
|
h.MessageHandler.HandleMessage(ctx, req, resp)
|
||
|
resp.ResponseHeader.BatchCount = len(resp.BatchItem)
|
||
|
|
||
|
respTTLV := resp.Bytes()
|
||
|
|
||
|
if req.Message.RequestHeader.MaximumResponseSize > 0 && len(respTTLV) > req.Message.RequestHeader.MaximumResponseSize {
|
||
|
// new error resp
|
||
|
resp.errorResponse(kmip14.ResultReasonResponseTooLarge, "")
|
||
|
respTTLV = resp.Bytes()
|
||
|
}
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (h *StandardProtocolHandler) ServeKMIP(ctx context.Context, req *Request, writer ResponseWriter) {
|
||
|
// we precreate the response object and pass it down to handlers, because due
|
||
|
// the guidance in the spec on the Maximum Response Size, it will be necessary
|
||
|
// for handlers to recalculate the response size after each batch item, which
|
||
|
// requires re-encoding the entire response. Seems inefficient.
|
||
|
resp := newResponse()
|
||
|
logger := h.handleRequest(ctx, req, resp)
|
||
|
|
||
|
var err error
|
||
|
if h.LogTraffic {
|
||
|
ttlvV := resp.Bytes()
|
||
|
|
||
|
logger.Debug("traffic log", "request", req.TTLV.String(), "response", ttlv.TTLV(ttlvV).String())
|
||
|
_, err = writer.Write(ttlvV)
|
||
|
} else {
|
||
|
_, err = resp.buf.WriteTo(writer)
|
||
|
}
|
||
|
if err != nil {
|
||
|
panic(err)
|
||
|
}
|
||
|
|
||
|
releaseResponse(resp)
|
||
|
}
|
||
|
|
||
|
func (r *ResponseMessage) addFailure(reason kmip14.ResultReason, msg string) {
|
||
|
if msg == "" {
|
||
|
msg = reason.String()
|
||
|
}
|
||
|
r.BatchItem = append(r.BatchItem, ResponseBatchItem{
|
||
|
ResultStatus: kmip14.ResultStatusOperationFailed,
|
||
|
ResultReason: reason,
|
||
|
ResultMessage: msg,
|
||
|
})
|
||
|
}
|
||
|
|
||
|
// OperationMux is an implementation of MessageHandler which handles each batch item in the request
|
||
|
// by routing the operation to an ItemHandler. The ItemHandler performs the operation, and returns
|
||
|
// either a *ResponseBatchItem, or an error. If it returns an error, the error is passed to
|
||
|
// ErrorHandler, which converts it into a error *ResponseBatchItem. OperationMux handles correlating
|
||
|
// items in the request to items in the response.
|
||
|
type OperationMux struct {
|
||
|
mu sync.RWMutex
|
||
|
handlers map[kmip14.Operation]ItemHandler
|
||
|
// ErrorHandler defaults to the DefaultErrorHandler.
|
||
|
ErrorHandler ErrorHandler
|
||
|
}
|
||
|
|
||
|
// ErrorHandler converts a golang error into a *ResponseBatchItem (which should hold information
|
||
|
// about the error to convey back to the client).
|
||
|
type ErrorHandler interface {
|
||
|
HandleError(err error) *ResponseBatchItem
|
||
|
}
|
||
|
|
||
|
type ErrorHandlerFunc func(err error) *ResponseBatchItem
|
||
|
|
||
|
func (f ErrorHandlerFunc) HandleError(err error) *ResponseBatchItem {
|
||
|
return f(err)
|
||
|
}
|
||
|
|
||
|
// DefaultErrorHandler tries to map errors to ResultReasons.
|
||
|
var DefaultErrorHandler = ErrorHandlerFunc(func(err error) *ResponseBatchItem {
|
||
|
reason := GetResultReason(err)
|
||
|
if reason == kmip14.ResultReason(0) {
|
||
|
// error not handled
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// prefer user message, but fall back on message
|
||
|
msg := merry.UserMessage(err)
|
||
|
if msg == "" {
|
||
|
msg = merry.Message(err)
|
||
|
}
|
||
|
return newFailedResponseBatchItem(reason, msg)
|
||
|
})
|
||
|
|
||
|
func newFailedResponseBatchItem(reason kmip14.ResultReason, msg string) *ResponseBatchItem {
|
||
|
return &ResponseBatchItem{
|
||
|
ResultStatus: kmip14.ResultStatusOperationFailed,
|
||
|
ResultReason: reason,
|
||
|
ResultMessage: msg,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (m *OperationMux) bi(ctx context.Context, req *Request, reqItem *RequestBatchItem) *ResponseBatchItem {
|
||
|
req.CurrentItem = reqItem
|
||
|
h := m.handlerForOp(reqItem.Operation)
|
||
|
if h == nil {
|
||
|
return newFailedResponseBatchItem(kmip14.ResultReasonOperationNotSupported, "")
|
||
|
}
|
||
|
|
||
|
resp, err := h.HandleItem(ctx, req)
|
||
|
if err != nil {
|
||
|
eh := m.ErrorHandler
|
||
|
if eh == nil {
|
||
|
eh = DefaultErrorHandler
|
||
|
}
|
||
|
resp = eh.HandleError(err)
|
||
|
if resp == nil {
|
||
|
// errors which don't convert just panic
|
||
|
panic(err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return resp
|
||
|
}
|
||
|
|
||
|
func (m *OperationMux) HandleMessage(ctx context.Context, req *Request, resp *Response) {
|
||
|
for i := range req.Message.BatchItem {
|
||
|
reqItem := &req.Message.BatchItem[i]
|
||
|
respItem := m.bi(ctx, req, reqItem)
|
||
|
respItem.Operation = reqItem.Operation
|
||
|
respItem.UniqueBatchItemID = reqItem.UniqueBatchItemID
|
||
|
resp.BatchItem = append(resp.BatchItem, *respItem)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (m *OperationMux) Handle(op kmip14.Operation, handler ItemHandler) {
|
||
|
m.mu.Lock()
|
||
|
defer m.mu.Unlock()
|
||
|
|
||
|
if m.handlers == nil {
|
||
|
m.handlers = map[kmip14.Operation]ItemHandler{}
|
||
|
}
|
||
|
|
||
|
m.handlers[op] = handler
|
||
|
}
|
||
|
|
||
|
func (m *OperationMux) handlerForOp(op kmip14.Operation) ItemHandler {
|
||
|
m.mu.RLock()
|
||
|
defer m.mu.RUnlock()
|
||
|
|
||
|
return m.handlers[op]
|
||
|
}
|
||
|
|
||
|
func (m *OperationMux) missingHandler(ctx context.Context, req *Request, resp *ResponseMessage) error {
|
||
|
resp.addFailure(kmip14.ResultReasonOperationNotSupported, "")
|
||
|
return nil
|
||
|
}
|