ceph-csi/vendor/github.com/gemalto/kmip-go/requests.go

903 lines
27 KiB
Go
Raw Normal View History

package kmip
// This is a WIP implementation of a KMIP server. The code is mostly based on the http server in
// the golang standard library. It is functional, but not all of the features of the http server
// have been ported over yet, and some of the stuff in here still refers to http stuff.
//
// The responsibility of handling a request is broken up into 3 layers of handlers: ProtocolHandler, MessageHandler,
// and ItemHandler. Each of these handlers delegates details to the next layer. Using the http
// package as an analogy, ProtocolHandler is similar to the wire-level HTTP protocol handling in
// http.Server and http.Transport. MessageHandler parses KMIP TTLV bytes into golang request and response structs.
// ItemHandler is a bit like http.ServeMux, routing particular KMIP operations to registered handlers.
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"runtime"
"sync"
"sync/atomic"
"time"
"github.com/ansel1/merry"
"github.com/gemalto/flume"
"github.com/gemalto/kmip-go/kmip14"
"github.com/gemalto/kmip-go/ttlv"
"github.com/google/uuid"
)
// serverLog is the package-level flume logger, registered under the name "kmip_server".
// conn.serve attaches it to the context of every connection it serves.
var serverLog = flume.New("kmip_server")
// Server serves KMIP protocol connections from a net.Listener. Because KMIP is a connection-oriented
// protocol, unlike HTTP, each connection ends up being serviced by a dedicated goroutine (rather than
// each request). For each KMIP connection, requests are processed serially. The handling
// of the request is delegated to the ProtocolHandler.
//
// Limitations:
//
// This implementation is functional (it can respond to KMIP requests), but incomplete. Some of the
// connection management features of the http package haven't been ported over, and there is
// currently no connection context in which to store things like authentication or session state.
// Since HTTP is an intrinsically stateless model, it makes sense for the http package to delegate session
// management to third party packages, but for KMIP, it would make sense for there to be some first
// class support for a connection context.
//
// This package also only handles the binary TTLV encoding for now. It may make sense for this
// server to detect or support the XML and JSON encodings as well. It may also make sense to support
// KMIP requests over HTTP, perhaps by adapting ProtocolHandler to an http.Handler or something.
type Server struct {
// Handler handles every request read off a connection. If nil, DefaultProtocolHandler is used.
Handler ProtocolHandler
// mu guards listeners.
mu sync.Mutex
// listeners is the set of listeners currently being served, keyed by pointer (see trackListener).
// It is created lazily by trackListener.
listeners map[*net.Listener]struct{}
inShutdown int32 // accessed atomically (non-zero means we're in Shutdown)
}
// ErrServerClosed is returned by the Server's Serve method after a call to
// Shutdown or Close. Compare against it with errors.Is (or ==) rather than by
// message text.
var ErrServerClosed = errors.New("kmip: Server closed")
// Serve accepts incoming connections on the Listener l and starts a new
// service goroutine for each one. Each service goroutine reads requests from
// its connection and passes them to srv.Handler (or DefaultProtocolHandler
// when srv.Handler is nil).
//
// Temporary accept errors are retried with an exponential backoff that starts
// at 5ms and is capped at 1s; any other accept error ends the loop.
//
// Serve always returns a non-nil error and closes l.
// After Shutdown or Close, the returned error is ErrServerClosed.
func (srv *Server) Serve(l net.Listener) error {
	// Wrap the listener so the deferred Close below and a concurrent
	// Shutdown/Close only close the underlying listener once.
	l = &onceCloseListener{Listener: l}
	defer l.Close()

	if !srv.trackListener(&l, true) {
		return ErrServerClosed
	}
	defer srv.trackListener(&l, false)

	const maxBackoff = 1 * time.Second
	var backoff time.Duration // how long to sleep after a temporary accept failure

	// The base context is always the background context; nothing
	// server-specific is attached to it yet.
	ctx := context.Background()

	for {
		rw, err := l.Accept()
		if err == nil {
			backoff = 0
			c := &conn{server: srv, rwc: rw}
			go c.serve(ctx)
			continue
		}

		if srv.shuttingDown() {
			return ErrServerClosed
		}

		ne, isNetErr := err.(net.Error)
		if !isNetErr || !ne.Temporary() {
			return err
		}

		switch {
		case backoff == 0:
			backoff = 5 * time.Millisecond
		default:
			backoff *= 2
		}
		if backoff > maxBackoff {
			backoff = maxBackoff
		}
		time.Sleep(backoff)
	}
}
// Close marks the server as shut down and immediately closes all tracked
// net.Listeners, which makes Serve return ErrServerClosed.
//
// NOTE: unlike http.Server.Close, this server does not track its connections
// yet, so connections that are already being served are not closed here.
//
// Close returns the first error returned from closing the Server's
// underlying Listener(s).
func (srv *Server) Close() error {
	atomic.StoreInt32(&srv.inShutdown, 1)

	srv.mu.Lock()
	defer srv.mu.Unlock()

	// TODO: close tracked connections once connection tracking is ported over.
	return srv.closeListenersLocked()
}
// shutdownPollInterval is how often Server.Shutdown is meant to poll for
// quiescence. It is a variable rather than a constant so that tests can lower
// it to speed up.
//
// NOTE: the polling loop in Shutdown has not been ported over yet (it is
// commented out), so this value does not currently affect shutdown behavior.
//
// Ideally we could find a solution that doesn't involve polling,
// but which also doesn't have a high runtime cost (and doesn't
// involve any contentious mutexes), but that is left as an
// exercise for the reader.
var shutdownPollInterval = 500 * time.Millisecond
// Shutdown stops the server from accepting new connections: it marks the
// server as shut down and closes all tracked listeners, which causes Serve to
// return ErrServerClosed. It returns the first error returned from closing
// the Server's underlying Listener(s).
//
// Limitations: connections are not tracked yet, so Shutdown neither closes
// idle connections nor waits for active connections to finish, and ctx is
// currently unused. Goroutines serving existing connections keep running
// until the client disconnects or a read/write on the connection fails.
//
// Once Shutdown has been called on a server, it may not be reused;
// future calls to Serve will return ErrServerClosed.
func (srv *Server) Shutdown(ctx context.Context) error {
	atomic.StoreInt32(&srv.inShutdown, 1)

	srv.mu.Lock()
	defer srv.mu.Unlock()

	// TODO: once connection tracking exists, poll every shutdownPollInterval
	// until all connections are idle and closed, returning ctx.Err() if ctx is
	// done first (as http.Server.Shutdown does). Until then there is nothing to
	// wait for, so no ticker is created.
	return srv.closeListenersLocked()
}
// closeListenersLocked closes every tracked listener and removes it from the
// set. All listeners are closed even if some fail; the first close error
// encountered is returned. The caller must hold srv.mu.
func (srv *Server) closeListenersLocked() error {
	var firstErr error
	for ln := range srv.listeners {
		closeErr := (*ln).Close()
		if firstErr == nil && closeErr != nil {
			firstErr = closeErr
		}
		delete(srv.listeners, ln)
	}
	return firstErr
}
// trackListener adds a net.Listener to, or removes it from, the set of
// tracked listeners.
//
// The set is keyed by pointer-to-interface in case the net.Listener itself is
// not comparable. That is safe because trackListener is only called from
// Serve, which tracks and (via defer) untracks the same pointer to its local
// variable; a Listener from another caller never needs to be compared.
//
// It returns false only when add is true and the server is shutting down, in
// which case the listener is not tracked. In every other case it returns true.
func (srv *Server) trackListener(ln *net.Listener, add bool) bool {
	srv.mu.Lock()
	defer srv.mu.Unlock()

	if srv.listeners == nil {
		srv.listeners = make(map[*net.Listener]struct{})
	}

	if !add {
		delete(srv.listeners, ln)
		return true
	}

	if srv.shuttingDown() {
		return false
	}
	srv.listeners[ln] = struct{}{}
	return true
}
// shuttingDown reports whether Close or Shutdown has been called on the server.
func (srv *Server) shuttingDown() bool {
	flag := atomic.LoadInt32(&srv.inShutdown)
	return flag != 0
}
// conn represents the server side of a single KMIP connection. One conn is
// created by Server.Serve per accepted connection and is driven by conn.serve.
type conn struct {
// rwc is the underlying network connection, as returned by the listener's Accept.
rwc net.Conn
// remoteAddr and localAddr are the string forms of rwc's addresses, captured
// at the start of serve and copied into every Request.
remoteAddr string
localAddr string
// tlsState is the TLS connection state captured after the handshake. It is
// nil unless rwc is a *tls.Conn.
tlsState *tls.ConnectionState
// cancelCtx cancels the connection-level context.
cancelCtx context.CancelFunc
// bufr reads from rwc.
// NOTE(review): created in serve but not read from in the visible code; requests are read via dec.
bufr *bufio.Reader
// dec decodes TTLV values from rwc; readRequest uses it to read each request.
dec *ttlv.Decoder
// server is the Server this connection was accepted on.
server *Server
}
// close closes the underlying network connection. The error from Close is
// deliberately discarded: this is called from serve's deferred cleanup, where
// nothing further can be done with it.
func (c *conn) close() {
// TODO: http package has a buffered writer on the conn too, which is flushed here
_ = c.rwc.Close()
}
// serve handles a single connection until the client disconnects or an error
// occurs. It runs in its own goroutine, started by Server.Serve.
//
// For TLS connections the handshake is performed first and the resulting
// connection state is recorded. Requests are then read and handled serially:
// each request is passed to the server's ProtocolHandler
// (DefaultProtocolHandler when none is configured), and the buffered response
// is flushed to the connection before the next request is read.
//
// On return the connection-level context is cancelled and the connection is
// closed. Panics raised while serving (including panics from handlers and the
// deliberate panics below on read/flush errors) are recovered, printed to
// stdout with a stack trace, and end the connection.
func (c *conn) serve(ctx context.Context) {
	ctx = flume.WithLogger(ctx, serverLog)
	ctx, cancelCtx := context.WithCancel(ctx)
	c.cancelCtx = cancelCtx
	c.remoteAddr = c.rwc.RemoteAddr().String()
	c.localAddr = c.rwc.LocalAddr().String()

	defer func() {
		if err := recover(); err != nil {
			// TODO: logging support
			const size = 64 << 10
			buf := make([]byte, size)
			buf = buf[:runtime.Stack(buf, false)]
			if e, ok := err.(error); ok {
				fmt.Printf("kmip: panic serving %v: %v\n%s", c.remoteAddr, Details(e), buf)
			} else {
				fmt.Printf("kmip: panic serving %v: %v\n%s", c.remoteAddr, err, buf)
			}
		}
		cancelCtx()
		// TODO: track connection state (StateClosed) once connection tracking is ported
		c.close()
	}()

	if tlsConn, ok := c.rwc.(*tls.Conn); ok {
		// TODO: apply server read/write timeouts as deadlines before the handshake,
		// as http.Server does:
		//if d := c.server.ReadTimeout; d != 0 {
		//	c.rwc.SetReadDeadline(time.Now().Add(d))
		//}
		//if d := c.server.WriteTimeout; d != 0 {
		//	c.rwc.SetWriteDeadline(time.Now().Add(d))
		//}
		if err := tlsConn.Handshake(); err != nil {
			// TODO: logging support
			fmt.Printf("kmip: TLS handshake error from %s: %v\n", c.rwc.RemoteAddr(), err)
			return
		}
		c.tlsState = new(tls.ConnectionState)
		*c.tlsState = tlsConn.ConnectionState()
	}

	// TODO: do we really need instance pooling here? We expect KMIP connections to be long lasting
	c.dec = ttlv.NewDecoder(c.rwc)
	c.bufr = bufio.NewReader(c.rwc)

	// One buffered writer is reused for every response on this connection. It
	// is flushed after each request, so it is always empty when the next
	// request starts; any flush failure ends the connection.
	writer := bufio.NewWriter(c.rwc)

	for {
		req, err := c.readRequest(ctx)
		if err != nil {
			if merry.Is(err, io.EOF) {
				fmt.Println("client closed connection")
				return
			}
			// TODO: do something with this error. For now it is turned into a
			// panic, which the deferred function above recovers and prints
			// before closing the connection. No error response is sent.
			panic(err)
		}

		// KMIP requests on a connection are processed serially: until the
		// server replies to this request it can't read another, so the handler
		// runs in this goroutine.
		h := c.server.Handler
		if h == nil {
			h = DefaultProtocolHandler
		}

		// TODO: ctx is cancelled at the connection level, not the request level. Need to
		// figure out how to handle connection vs request timeouts and cancels.
		h.ServeKMIP(ctx, req, writer)
		if err := writer.Flush(); err != nil {
			// TODO: handle error. The panic is recovered by the deferred
			// function above, which closes the connection.
			panic(err)
		}

		// TODO: idle timeouts and connection state tracking (StateActive/StateIdle),
		// as in http.Server:
		//if d := c.server.idleTimeout(); d != 0 {
		//	c.rwc.SetReadDeadline(time.Now().Add(d))
		//	if _, err := c.bufr.Peek(4); err != nil {
		//		return
		//	}
		//}
		//c.rwc.SetReadDeadline(time.Time{})
	}
}
// readRequest reads the next TTLV value off the connection and wraps it in a
// new Request. Only the raw TTLV and the connection details (addresses and
// TLS state) are populated; decoding the message is left to the handler.
//
// Any error from the decoder is returned unchanged with a nil Request. ctx is
// currently unused.
func (c *conn) readRequest(ctx context.Context) (w *Request, err error) {
	// TODO: read deadlines and request size limits have not been ported over
	// from the http server yet.
	raw, decErr := c.dec.NextTTLV()
	if decErr != nil {
		return nil, decErr
	}

	// TODO: use pooling to recycle requests?
	return &Request{
		TTLV:       raw,
		RemoteAddr: c.remoteAddr,
		LocalAddr:  c.localAddr,
		TLS:        c.tlsState,
	}, nil
}
// Request represents a KMIP request. The server populates TTLV, TLS,
// RemoteAddr and LocalAddr when the request is read off the wire; the
// remaining fields are filled in by the handler layers as they process it.
type Request struct {
// TTLV will hold the entire body of the request.
TTLV ttlv.TTLV
// Message is the decoded form of TTLV. StandardProtocolHandler sets it after
// successfully parsing the request; it is nil before that.
Message *RequestMessage
// CurrentItem is the batch item currently being handled. OperationMux sets it
// before invoking an ItemHandler.
CurrentItem *RequestBatchItem
// DisallowExtraValues hints to handlers that unknown fields in payloads should
// be rejected. StandardProtocolHandler sets it when the client's and server's
// protocol minor versions match.
DisallowExtraValues bool
// TLS holds the TLS state of the connection this request was received on.
// It is nil for non-TLS connections.
TLS *tls.ConnectionState
// RemoteAddr and LocalAddr are the string forms of the connection's addresses.
RemoteAddr string
LocalAddr string
// IDPlaceholder is not referenced in this file.
// NOTE(review): presumably carries the KMIP ID Placeholder value between batch items — confirm.
IDPlaceholder string
// decoder is the reusable decoder behind Unmarshal. StandardProtocolHandler
// creates it and configures it from DisallowExtraValues.
decoder *ttlv.Decoder
}
// coerceToTTLV converts an arbitrary value to TTLV.
//
// A nil value yields a nil TTLV, and a value which already is a ttlv.TTLV is
// returned as is; this is the expected case in production. Any other value is
// marshaled with ttlv.Marshal, which is slow and intended for tests only.
func coerceToTTLV(v interface{}) (ttlv.TTLV, error) {
	if v == nil {
		return nil, nil
	}
	if already, ok := v.(ttlv.TTLV); ok {
		return already, nil
	}
	return ttlv.Marshal(v)
}
// Unmarshal unmarshals ttlv into the structure pointed to by into. Handlers
// should prefer this method over their own Decoders or ttlv.Unmarshal(): it
// enforces the rules about whether extra fields are allowed, and reuses a
// single decoder per request for efficiency.
//
// An empty ttlv value is a no-op and returns nil. Otherwise the decoder's
// error, if any, is returned.
func (r *Request) Unmarshal(ttlv ttlv.TTLV, into interface{}) error {
	if len(ttlv) == 0 {
		return nil
	}

	dec := r.payloadDecoder()
	dec.Reset(bytes.NewReader(ttlv))

	return dec.Decode(into)
}

// payloadDecoder returns the request's reusable decoder, creating it on first
// use. StandardProtocolHandler normally installs the decoder, but a Request
// built by a custom ProtocolHandler or by a test has none; creating it lazily
// (honoring DisallowExtraValues) avoids a nil pointer dereference in that case.
func (r *Request) payloadDecoder() *ttlv.Decoder {
	if r.decoder == nil {
		r.decoder = ttlv.NewDecoder(nil)
		r.decoder.DisallowExtraValues = r.DisallowExtraValues
	}

	return r.decoder
}
// DecodePayload unmarshals the request payload of the current batch item into
// v, using the same rules as Unmarshal. It does nothing and returns nil when
// no current item is set.
func (r *Request) DecodePayload(v interface{}) error {
	item := r.CurrentItem
	if item == nil {
		return nil
	}

	payload, err := coerceToTTLV(item.RequestPayload)
	if err == nil {
		err = r.Unmarshal(payload, v)
	}
	return err
}
// onceCloseListener wraps a net.Listener so that the underlying listener is
// closed at most once, no matter how many times Close is called.
type onceCloseListener struct {
	net.Listener
	once     sync.Once // guards the single call to Listener.Close
	closeErr error     // result of that single call
}

// Close closes the wrapped listener on the first call only. Every call,
// including later ones, returns the result of that first close.
func (oc *onceCloseListener) Close() error {
	oc.once.Do(func() {
		oc.closeErr = oc.Listener.Close()
	})
	return oc.closeErr
}
// ResponseWriter is what a ProtocolHandler writes the encoded response to.
// The server passes a buffered writer wrapping the connection and flushes it
// after the handler returns.
type ResponseWriter interface {
io.Writer
}
// ProtocolHandler is responsible for handling raw requests read off the wire. The
// *Request object will only have the TTLV field (and the connection details)
// populated. The response should be written directly to the ResponseWriter.
//
// The default implementation of ProtocolHandler is StandardProtocolHandler.
type ProtocolHandler interface {
ServeKMIP(ctx context.Context, req *Request, resp ResponseWriter)
}
// MessageHandler handles KMIP requests which have already been decoded. The *Request
// object's Message field will be populated from the decoded TTLV. The *Response
// object will always be non-nil, and its ResponseHeader will be populated. The
// MessageHandler usually shouldn't modify the ResponseHeader: the ProtocolHandler
// is responsible for the header. The MessageHandler just needs to populate
// the response batch items.
//
// The default implementation of MessageHandler is OperationMux.
type MessageHandler interface {
HandleMessage(ctx context.Context, req *Request, resp *Response)
}
// ItemHandler handles a single batch item in a KMIP request. The *Request
// object's CurrentItem field will be populated with the item to be handled.
//
// HandleItem returns either the response batch item or an error. When used
// with OperationMux, a returned error is converted into a failed response
// batch item by the mux's ErrorHandler.
type ItemHandler interface {
HandleItem(ctx context.Context, req *Request) (item *ResponseBatchItem, err error)
}
// ProtocolHandlerFunc is an adapter which allows an ordinary function to be
// used as a ProtocolHandler.
type ProtocolHandlerFunc func(context.Context, *Request, ResponseWriter)

// ServeKMIP calls fn(ctx, req, resp).
func (fn ProtocolHandlerFunc) ServeKMIP(ctx context.Context, req *Request, resp ResponseWriter) {
	fn(ctx, req, resp)
}
// MessageHandlerFunc is an adapter which allows an ordinary function to be
// used as a MessageHandler.
type MessageHandlerFunc func(context.Context, *Request, *Response)

// HandleMessage calls fn(ctx, r, w).
func (fn MessageHandlerFunc) HandleMessage(ctx context.Context, r *Request, w *Response) {
	fn(ctx, r, w)
}
// ItemHandlerFunc is an adapter which allows an ordinary function to be used
// as an ItemHandler.
type ItemHandlerFunc func(context.Context, *Request) (*ResponseBatchItem, error)

// HandleItem calls fn(ctx, r) and returns its results unchanged.
func (fn ItemHandlerFunc) HandleItem(ctx context.Context, r *Request) (item *ResponseBatchItem, err error) {
	item, err = fn(ctx, r)
	return item, err
}
// DefaultProtocolHandler is the ProtocolHandler a Server uses when its Handler
// field is nil. It speaks protocol version 1.4 and routes messages to
// DefaultOperationMux.
var DefaultProtocolHandler = &StandardProtocolHandler{
MessageHandler: DefaultOperationMux,
ProtocolVersion: ProtocolVersion{
ProtocolVersionMajor: 1,
ProtocolVersionMinor: 4,
},
}
// DefaultOperationMux is the MessageHandler used by DefaultProtocolHandler.
// It starts with no operations registered; register them with its Handle method.
var DefaultOperationMux = &OperationMux{}
// StandardProtocolHandler is the default ProtocolHandler implementation. It
// handles decoding the request and encoding the response, as well as protocol
// level tasks like version negotiation and correlation values.
//
// It delegates handling of the request to a MessageHandler.
type StandardProtocolHandler struct {
// ProtocolVersion is the version reported in every response header. Requests
// whose major version differs are rejected; the minor version is compared with
// the client's to decide whether extra values in payloads are disallowed.
ProtocolVersion ProtocolVersion
// MessageHandler handles successfully parsed request messages.
MessageHandler MessageHandler
// LogTraffic, when true, makes ServeKMIP log each request and response (in
// their TTLV string form) at debug level.
LogTraffic bool
}
// parseMessage validates the raw TTLV of req and decodes it into
// req.Message. It returns an error if the TTLV is invalid, if its tag is not
// RequestMessage, or if it cannot be unmarshaled; req.Message is only set on
// success. ctx is currently unused.
func (h *StandardProtocolHandler) parseMessage(ctx context.Context, req *Request) error {
	raw := req.TTLV

	if err := raw.Valid(); err != nil {
		return merry.Prepend(err, "invalid ttlv")
	}

	if tag := raw.Tag(); tag != kmip14.TagRequestMessage {
		return merry.Errorf("invalid tag: expected RequestMessage, was %s", tag.String())
	}

	msg := new(RequestMessage)
	if err := ttlv.Unmarshal(raw, msg); err != nil {
		return merry.Prepend(err, "failed to parse message")
	}

	req.Message = msg
	return nil
}
// responsePool recycles *Response values (and their encode buffers) between
// requests. It has no New function: newResponse allocates when the pool is empty.
var responsePool = sync.Pool{}
// Response is the response under construction for a single request. It embeds
// the ResponseMessage that handlers populate, plus the buffer and encoder used
// to serialize it.
type Response struct {
ResponseMessage
// buf holds the most recent encoding of ResponseMessage (see Bytes).
buf bytes.Buffer
// enc encodes into buf. It is set up by newResponse.
enc *ttlv.Encoder
}
// newResponse returns an empty Response ready for use. A pooled Response is
// reset and reused when one is available; otherwise a new one is allocated
// with an encoder bound to its buffer.
func newResponse() *Response {
	if pooled, ok := responsePool.Get().(*Response); ok {
		pooled.reset()
		return pooled
	}

	fresh := &Response{}
	fresh.enc = ttlv.NewEncoder(&fresh.buf)
	return fresh
}
// releaseResponse returns r to responsePool so that a later newResponse call
// can reuse it. The caller must not use r afterwards.
func releaseResponse(r *Response) {
responsePool.Put(r)
}
// reset clears the response so it can be reused for another request: the
// encode buffer is emptied (keeping its capacity) and the message, including
// its batch items, is zeroed. The encoder is kept.
func (r *Response) reset() {
	r.buf.Reset()
	// Zeroing the embedded message also clears its BatchItem slice.
	r.ResponseMessage = ResponseMessage{}
}
// Bytes encodes the current ResponseMessage and returns the encoded bytes.
// Each call discards the previous encoding and re-encodes from scratch. The
// returned slice is backed by the response's internal buffer, so it is only
// valid until the next call to Bytes or reset.
//
// Bytes panics if the message cannot be encoded.
func (r *Response) Bytes() []byte {
	r.buf.Reset()
	if err := r.enc.Encode(&r.ResponseMessage); err != nil {
		panic(err)
	}
	return r.buf.Bytes()
}
// errorResponse replaces all batch items in the response with a single failed
// item carrying the given reason and message. It does not touch the response
// header (including BatchCount) and does not re-encode the response.
func (r *Response) errorResponse(reason kmip14.ResultReason, msg string) {
	failed := ResponseBatchItem{
		ResultStatus:  kmip14.ResultStatusOperationFailed,
		ResultReason:  reason,
		ResultMessage: msg,
	}
	r.BatchItem = []ResponseBatchItem{failed}
}
// handleRequest processes one request and leaves the encoded response in
// resp's buffer. It:
//
//   - generates a server correlation value and a logger seeded with it,
//   - fills in the response header,
//   - parses the request message and checks the protocol major version,
//   - delegates the message to h.MessageHandler, and
//   - enforces the client's MaximumResponseSize, if any.
//
// Every failure (unparseable message, mismatched major version, oversized
// response) replaces the batch items with a single failed item. On every
// return path the header's BatchCount matches the batch items and the
// response has been encoded, so the caller can write resp's buffer directly.
//
// The returned logger carries the server correlation value and, once the
// message has been parsed, the client correlation value.
func (h *StandardProtocolHandler) handleRequest(ctx context.Context, req *Request, resp *Response) (logger flume.Logger) {
	// create a server correlation value, which is like a unique transaction ID
	scv := uuid.New().String()

	// create a logger for the transaction, seeded with the scv
	logger = flume.FromContext(ctx).With("scv", scv)
	// attach the logger to the context, so it is available to the handling chain
	ctx = flume.WithLogger(ctx, logger)

	// TODO: it's unclear how the full protocol negotiation is supposed to work.
	// Should the server be pinned to a particular version? Or should we try to negotiate a common version?
	resp.ResponseHeader.ProtocolVersion = h.ProtocolVersion
	resp.ResponseHeader.TimeStamp = time.Now()
	resp.ResponseHeader.BatchCount = len(resp.BatchItem)
	resp.ResponseHeader.ServerCorrelationValue = scv

	// fail turns the response into a single-item error response, keeps
	// BatchCount consistent with it, and encodes it so that it is actually
	// sent to the client.
	fail := func(reason kmip14.ResultReason, msg string) {
		resp.errorResponse(reason, msg)
		resp.ResponseHeader.BatchCount = len(resp.BatchItem)
		_ = resp.Bytes()
	}

	if err := h.parseMessage(ctx, req); err != nil {
		fail(kmip14.ResultReasonInvalidMessage, err.Error())
		return logger
	}

	ccv := req.Message.RequestHeader.ClientCorrelationValue
	// add the client correlation value to the logging context. This value uniquely
	// identifies the client, and is supposed to be included in server logs
	logger = logger.With("ccv", ccv)
	ctx = flume.WithLogger(ctx, logger)
	resp.ResponseHeader.ClientCorrelationValue = ccv

	clientMajorVersion := req.Message.RequestHeader.ProtocolVersion.ProtocolVersionMajor
	if clientMajorVersion != h.ProtocolVersion.ProtocolVersionMajor {
		fail(kmip14.ResultReasonInvalidMessage,
			fmt.Sprintf("mismatched protocol versions, client: %d, server: %d", clientMajorVersion, h.ProtocolVersion.ProtocolVersionMajor))
		return logger
	}

	// set a flag hinting to handlers that extra fields should not be tolerated when
	// unmarshaling payloads. According to spec, if server and client protocol version
	// minor versions match, then extra fields should cause an error. Not sure how to enforce
	// this in this higher level handler, since we (the protocol/message handlers) don't unmarshal the payload.
	// That's done by a particular item handler.
	req.DisallowExtraValues = req.Message.RequestHeader.ProtocolVersion.ProtocolVersionMinor == h.ProtocolVersion.ProtocolVersionMinor
	req.decoder = ttlv.NewDecoder(nil)
	req.decoder.DisallowExtraValues = req.DisallowExtraValues

	h.MessageHandler.HandleMessage(ctx, req, resp)
	resp.ResponseHeader.BatchCount = len(resp.BatchItem)

	// Encode the response, and replace it with an error response if it exceeds
	// the maximum size the client is willing to accept.
	maxSize := req.Message.RequestHeader.MaximumResponseSize
	if encoded := resp.Bytes(); maxSize > 0 && len(encoded) > maxSize {
		fail(kmip14.ResultReasonResponseTooLarge, "")
	}

	return logger
}
// ServeKMIP implements ProtocolHandler. It handles req via handleRequest and
// writes the encoded response message to writer. When LogTraffic is set, the
// request and response are also logged at debug level.
//
// The response object is created up front and passed down to the handlers
// because, due to the guidance in the spec on the Maximum Response Size, it
// may be necessary for handlers to recalculate the response size after each
// batch item, which requires re-encoding the entire response.
//
// ServeKMIP panics if writing to writer fails or if the response cannot be
// encoded; the server recovers such panics and closes the connection.
func (h *StandardProtocolHandler) ServeKMIP(ctx context.Context, req *Request, writer ResponseWriter) {
	resp := newResponse()

	logger := h.handleRequest(ctx, req, resp)

	var err error

	switch {
	case h.LogTraffic:
		// Re-encode so the logged bytes are exactly the bytes written.
		resp.ResponseHeader.BatchCount = len(resp.BatchItem)
		ttlvV := resp.Bytes()
		logger.Debug("traffic log", "request", req.TTLV.String(), "response", ttlv.TTLV(ttlvV).String())
		_, err = writer.Write(ttlvV)
	case resp.buf.Len() == 0:
		// handleRequest returned without encoding anything (e.g. the request
		// could not be parsed). Encode now so the client still receives the
		// error response rather than zero bytes.
		resp.ResponseHeader.BatchCount = len(resp.BatchItem)
		_, err = writer.Write(resp.Bytes())
	default:
		// The response was already encoded by handleRequest; write it as is.
		_, err = resp.buf.WriteTo(writer)
	}

	if err != nil {
		panic(err)
	}

	releaseResponse(resp)
}
// addFailure appends a failed batch item with the given reason to the
// response. When msg is empty, the reason's string form is used as the
// result message.
func (r *ResponseMessage) addFailure(reason kmip14.ResultReason, msg string) {
	failed := ResponseBatchItem{
		ResultStatus:  kmip14.ResultStatusOperationFailed,
		ResultReason:  reason,
		ResultMessage: msg,
	}
	if failed.ResultMessage == "" {
		failed.ResultMessage = reason.String()
	}
	r.BatchItem = append(r.BatchItem, failed)
}
// OperationMux is an implementation of MessageHandler which handles each batch item in the request
// by routing the operation to an ItemHandler. The ItemHandler performs the operation, and returns
// either a *ResponseBatchItem, or an error. If it returns an error, the error is passed to
// ErrorHandler, which converts it into an error *ResponseBatchItem. OperationMux handles correlating
// items in the request to items in the response.
//
// Operations without a registered handler are answered with an
// "operation not supported" failure. The zero value is ready to use.
type OperationMux struct {
// mu guards handlers.
mu sync.RWMutex
// handlers maps each operation to its registered handler. It is created
// lazily by Handle.
handlers map[kmip14.Operation]ItemHandler
// ErrorHandler defaults to the DefaultErrorHandler.
ErrorHandler ErrorHandler
}
// ErrorHandler converts a golang error into a *ResponseBatchItem (which should hold information
// about the error to convey back to the client).
//
// HandleError may return nil to signal that it could not convert the error;
// OperationMux panics with the original error in that case.
type ErrorHandler interface {
HandleError(err error) *ResponseBatchItem
}
// ErrorHandlerFunc is an adapter which allows an ordinary function to be used
// as an ErrorHandler.
type ErrorHandlerFunc func(err error) *ResponseBatchItem

// HandleError calls fn(err) and returns its result unchanged.
func (fn ErrorHandlerFunc) HandleError(err error) *ResponseBatchItem {
	item := fn(err)
	return item
}
// DefaultErrorHandler tries to map errors to ResultReasons. It returns nil
// when the error has no associated ResultReason, meaning the error was not
// handled. Otherwise it returns a failed batch item whose message is the
// error's user message or, if that is empty, the error's message.
var DefaultErrorHandler = ErrorHandlerFunc(func(err error) *ResponseBatchItem {
	reason := GetResultReason(err)
	if reason == kmip14.ResultReason(0) {
		// error not handled
		return nil
	}

	// prefer the user message, but fall back on the plain message
	if userMsg := merry.UserMessage(err); userMsg != "" {
		return newFailedResponseBatchItem(reason, userMsg)
	}
	return newFailedResponseBatchItem(reason, merry.Message(err))
})
// newFailedResponseBatchItem returns a new batch item with status
// OperationFailed and the given reason and message.
func newFailedResponseBatchItem(reason kmip14.ResultReason, msg string) *ResponseBatchItem {
	item := new(ResponseBatchItem)
	item.ResultStatus = kmip14.ResultStatusOperationFailed
	item.ResultReason = reason
	item.ResultMessage = msg
	return item
}
// bi handles a single request batch item and returns the matching response
// item. It records reqItem as the request's current item and dispatches to the
// handler registered for the item's operation.
//
// If no handler is registered, an "operation not supported" failure is
// returned. If the handler returns an error, the error is converted by the
// mux's ErrorHandler (DefaultErrorHandler when unset); errors which the error
// handler cannot convert cause a panic.
func (m *OperationMux) bi(ctx context.Context, req *Request, reqItem *RequestBatchItem) *ResponseBatchItem {
	req.CurrentItem = reqItem

	handler := m.handlerForOp(reqItem.Operation)
	if handler == nil {
		return newFailedResponseBatchItem(kmip14.ResultReasonOperationNotSupported, "")
	}

	item, err := handler.HandleItem(ctx, req)
	if err == nil {
		return item
	}

	errHandler := m.ErrorHandler
	if errHandler == nil {
		errHandler = DefaultErrorHandler
	}

	if item = errHandler.HandleError(err); item == nil {
		// errors which don't convert just panic
		panic(err)
	}
	return item
}
// HandleMessage implements MessageHandler. It handles the request's batch
// items one at a time, in order, and appends one response item per request
// item to resp. Each response item is stamped with the operation and unique
// batch item ID of the request item it answers.
func (m *OperationMux) HandleMessage(ctx context.Context, req *Request, resp *Response) {
	for idx := range req.Message.BatchItem {
		in := &req.Message.BatchItem[idx]

		out := m.bi(ctx, req, in)
		// correlate the response item with its request item
		out.Operation = in.Operation
		out.UniqueBatchItemID = in.UniqueBatchItemID

		resp.BatchItem = append(resp.BatchItem, *out)
	}
}
// Handle registers handler for the given operation, replacing any handler
// previously registered for it. It takes the mux's write lock, so it can be
// called concurrently with request handling.
func (m *OperationMux) Handle(op kmip14.Operation, handler ItemHandler) {
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.handlers == nil {
		m.handlers = make(map[kmip14.Operation]ItemHandler)
	}
	m.handlers[op] = handler
}
// handlerForOp returns the handler registered for op, or nil if there is
// none. The lookup is done under the mux's read lock.
func (m *OperationMux) handlerForOp(op kmip14.Operation) ItemHandler {
	m.mu.RLock()
	registered := m.handlers[op]
	m.mu.RUnlock()

	return registered
}
// missingHandler appends an "operation not supported" failure to resp and
// always returns nil. ctx and req are unused.
//
// NOTE(review): nothing in this file calls it; bi builds the equivalent
// failure itself.
func (m *OperationMux) missingHandler(ctx context.Context, req *Request, resp *ResponseMessage) error {
	// An empty message makes addFailure fall back to the reason's string form.
	const noMessage = ""
	resp.addFailure(kmip14.ResultReasonOperationNotSupported, noMessage)
	return nil
}