Fresh dep ensure

This commit is contained in:
Mike Cronce
2018-11-26 13:23:56 -05:00
parent 93cb8a04d7
commit 407478ab9a
9016 changed files with 551394 additions and 279685 deletions

View File

@ -0,0 +1,140 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package transport
import (
"sync"
"time"
)
const (
// bdpLimit is the maximum value the flow control windows
// will be increased to.
bdpLimit = (1 << 20) * 4
// alpha is a constant factor used to keep a moving average
// of RTTs.
alpha = 0.9
// If the current bdp sample is greater than or equal to
// our beta * our estimated bdp and the current bandwidth
// sample is the maximum bandwidth observed so far, we
// increase our bbp estimate by a factor of gamma.
beta = 0.66
// To put our bdp to be smaller than or equal to twice the real BDP,
// we should multiply our current sample with 4/3, however to round things out
// we use 2 as the multiplication factor.
gamma = 2
)
// Adding arbitrary data to ping so that its ack can be identified.
// Easter-egg: what does the ping message say?
var bdpPing = &ping{data: [8]byte{2, 4, 16, 16, 9, 14, 7, 7}}
type bdpEstimator struct {
// sentAt is the time when the ping was sent.
sentAt time.Time
mu sync.Mutex
// bdp is the current bdp estimate.
bdp uint32
// sample is the number of bytes received in one measurement cycle.
sample uint32
// bwMax is the maximum bandwidth noted so far (bytes/sec).
bwMax float64
// bool to keep track of the beginning of a new measurement cycle.
isSent bool
// Callback to update the window sizes.
updateFlowControl func(n uint32)
// sampleCount is the number of samples taken so far.
sampleCount uint64
// round trip time (seconds)
rtt float64
}
// timesnap registers the time bdp ping was sent out so that
// network rtt can be calculated when its ack is received.
// It is called (by controller) when the bdpPing is
// being written on the wire.
func (b *bdpEstimator) timesnap(d [8]byte) {
if bdpPing.data != d {
return
}
b.sentAt = time.Now()
}
// add adds bytes to the current sample for calculating bdp.
// It returns true only if a ping must be sent. This can be used
// by the caller (handleData) to make decision about batching
// a window update with it.
func (b *bdpEstimator) add(n uint32) bool {
b.mu.Lock()
defer b.mu.Unlock()
if b.bdp == bdpLimit {
return false
}
if !b.isSent {
b.isSent = true
b.sample = n
b.sentAt = time.Time{}
b.sampleCount++
return true
}
b.sample += n
return false
}
// calculate is called when an ack for a bdp ping is received.
// Here we calculate the current bdp and bandwidth sample and
// decide if the flow control windows should go up.
func (b *bdpEstimator) calculate(d [8]byte) {
// Check if the ping acked for was the bdp ping.
if bdpPing.data != d {
return
}
b.mu.Lock()
rttSample := time.Since(b.sentAt).Seconds()
if b.sampleCount < 10 {
// Bootstrap rtt with an average of first 10 rtt samples.
b.rtt += (rttSample - b.rtt) / float64(b.sampleCount)
} else {
// Heed to the recent past more.
b.rtt += (rttSample - b.rtt) * float64(alpha)
}
b.isSent = false
// The number of bytes accumulated so far in the sample is smaller
// than or equal to 1.5 times the real BDP on a saturated connection.
bwCurrent := float64(b.sample) / (b.rtt * float64(1.5))
if bwCurrent > b.bwMax {
b.bwMax = bwCurrent
}
// If the current sample (which is smaller than or equal to the 1.5 times the real BDP) is
// greater than or equal to 2/3rd our perceived bdp AND this is the maximum bandwidth seen so far, we
// should update our perception of the network BDP.
if float64(b.sample) >= beta*float64(b.bdp) && bwCurrent == b.bwMax && b.bdp != bdpLimit {
sampleFloat := float64(b.sample)
b.bdp = uint32(gamma * sampleFloat)
if b.bdp > bdpLimit {
b.bdp = bdpLimit
}
bdp := b.bdp
b.mu.Unlock()
b.updateFlowControl(bdp)
return
}
b.mu.Unlock()
}

View File

@ -0,0 +1,852 @@
/*
*
* Copyright 2014 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package transport
import (
"bytes"
"fmt"
"runtime"
"sync"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
)
var updateHeaderTblSize = func(e *hpack.Encoder, v uint32) {
e.SetMaxDynamicTableSizeLimit(v)
}
type itemNode struct {
it interface{}
next *itemNode
}
type itemList struct {
head *itemNode
tail *itemNode
}
func (il *itemList) enqueue(i interface{}) {
n := &itemNode{it: i}
if il.tail == nil {
il.head, il.tail = n, n
return
}
il.tail.next = n
il.tail = n
}
// peek returns the first item in the list without removing it from the
// list.
func (il *itemList) peek() interface{} {
return il.head.it
}
func (il *itemList) dequeue() interface{} {
if il.head == nil {
return nil
}
i := il.head.it
il.head = il.head.next
if il.head == nil {
il.tail = nil
}
return i
}
func (il *itemList) dequeueAll() *itemNode {
h := il.head
il.head, il.tail = nil, nil
return h
}
func (il *itemList) isEmpty() bool {
return il.head == nil
}
// The following defines various control items which could flow through
// the control buffer of transport. They represent different aspects of
// control tasks, e.g., flow control, settings, streaming resetting, etc.
// registerStream is used to register an incoming stream with loopy writer.
type registerStream struct {
streamID uint32
wq *writeQuota
}
// headerFrame is also used to register stream on the client-side.
type headerFrame struct {
streamID uint32
hf []hpack.HeaderField
endStream bool // Valid on server side.
initStream func(uint32) (bool, error) // Used only on the client side.
onWrite func()
wq *writeQuota // write quota for the stream created.
cleanup *cleanupStream // Valid on the server side.
onOrphaned func(error) // Valid on client-side
}
type cleanupStream struct {
streamID uint32
rst bool
rstCode http2.ErrCode
onWrite func()
}
type dataFrame struct {
streamID uint32
endStream bool
h []byte
d []byte
// onEachWrite is called every time
// a part of d is written out.
onEachWrite func()
}
type incomingWindowUpdate struct {
streamID uint32
increment uint32
}
type outgoingWindowUpdate struct {
streamID uint32
increment uint32
}
type incomingSettings struct {
ss []http2.Setting
}
type outgoingSettings struct {
ss []http2.Setting
}
type incomingGoAway struct {
}
type goAway struct {
code http2.ErrCode
debugData []byte
headsUp bool
closeConn bool
}
type ping struct {
ack bool
data [8]byte
}
type outFlowControlSizeRequest struct {
resp chan uint32
}
type outStreamState int
const (
active outStreamState = iota
empty
waitingOnStreamQuota
)
type outStream struct {
id uint32
state outStreamState
itl *itemList
bytesOutStanding int
wq *writeQuota
next *outStream
prev *outStream
}
func (s *outStream) deleteSelf() {
if s.prev != nil {
s.prev.next = s.next
}
if s.next != nil {
s.next.prev = s.prev
}
s.next, s.prev = nil, nil
}
type outStreamList struct {
// Following are sentinel objects that mark the
// beginning and end of the list. They do not
// contain any item lists. All valid objects are
// inserted in between them.
// This is needed so that an outStream object can
// deleteSelf() in O(1) time without knowing which
// list it belongs to.
head *outStream
tail *outStream
}
func newOutStreamList() *outStreamList {
head, tail := new(outStream), new(outStream)
head.next = tail
tail.prev = head
return &outStreamList{
head: head,
tail: tail,
}
}
func (l *outStreamList) enqueue(s *outStream) {
e := l.tail.prev
e.next = s
s.prev = e
s.next = l.tail
l.tail.prev = s
}
// remove from the beginning of the list.
func (l *outStreamList) dequeue() *outStream {
b := l.head.next
if b == l.tail {
return nil
}
b.deleteSelf()
return b
}
// controlBuffer is a way to pass information to loopy.
// Information is passed as specific struct types called control frames.
// A control frame not only represents data, messages or headers to be sent out
// but can also be used to instruct loopy to update its internal state.
// It shouldn't be confused with an HTTP2 frame, although some of the control frames
// like dataFrame and headerFrame do go out on wire as HTTP2 frames.
type controlBuffer struct {
ch chan struct{}
done <-chan struct{}
mu sync.Mutex
consumerWaiting bool
list *itemList
err error
}
func newControlBuffer(done <-chan struct{}) *controlBuffer {
return &controlBuffer{
ch: make(chan struct{}, 1),
list: &itemList{},
done: done,
}
}
func (c *controlBuffer) put(it interface{}) error {
_, err := c.executeAndPut(nil, it)
return err
}
func (c *controlBuffer) executeAndPut(f func(it interface{}) bool, it interface{}) (bool, error) {
var wakeUp bool
c.mu.Lock()
if c.err != nil {
c.mu.Unlock()
return false, c.err
}
if f != nil {
if !f(it) { // f wasn't successful
c.mu.Unlock()
return false, nil
}
}
if c.consumerWaiting {
wakeUp = true
c.consumerWaiting = false
}
c.list.enqueue(it)
c.mu.Unlock()
if wakeUp {
select {
case c.ch <- struct{}{}:
default:
}
}
return true, nil
}
// Note argument f should never be nil.
func (c *controlBuffer) execute(f func(it interface{}) bool, it interface{}) (bool, error) {
c.mu.Lock()
if c.err != nil {
c.mu.Unlock()
return false, c.err
}
if !f(it) { // f wasn't successful
c.mu.Unlock()
return false, nil
}
c.mu.Unlock()
return true, nil
}
func (c *controlBuffer) get(block bool) (interface{}, error) {
for {
c.mu.Lock()
if c.err != nil {
c.mu.Unlock()
return nil, c.err
}
if !c.list.isEmpty() {
h := c.list.dequeue()
c.mu.Unlock()
return h, nil
}
if !block {
c.mu.Unlock()
return nil, nil
}
c.consumerWaiting = true
c.mu.Unlock()
select {
case <-c.ch:
case <-c.done:
c.finish()
return nil, ErrConnClosing
}
}
}
func (c *controlBuffer) finish() {
c.mu.Lock()
if c.err != nil {
c.mu.Unlock()
return
}
c.err = ErrConnClosing
// There may be headers for streams in the control buffer.
// These streams need to be cleaned out since the transport
// is still not aware of these yet.
for head := c.list.dequeueAll(); head != nil; head = head.next {
hdr, ok := head.it.(*headerFrame)
if !ok {
continue
}
if hdr.onOrphaned != nil { // It will be nil on the server-side.
hdr.onOrphaned(ErrConnClosing)
}
}
c.mu.Unlock()
}
type side int
const (
clientSide side = iota
serverSide
)
// Loopy receives frames from the control buffer.
// Each frame is handled individually; most of the work done by loopy goes
// into handling data frames. Loopy maintains a queue of active streams, and each
// stream maintains a queue of data frames; as loopy receives data frames
// it gets added to the queue of the relevant stream.
// Loopy goes over this list of active streams by processing one node every iteration,
// thereby closely resemebling to a round-robin scheduling over all streams. While
// processing a stream, loopy writes out data bytes from this stream capped by the min
// of http2MaxFrameLen, connection-level flow control and stream-level flow control.
type loopyWriter struct {
side side
cbuf *controlBuffer
sendQuota uint32
oiws uint32 // outbound initial window size.
// estdStreams is map of all established streams that are not cleaned-up yet.
// On client-side, this is all streams whose headers were sent out.
// On server-side, this is all streams whose headers were received.
estdStreams map[uint32]*outStream // Established streams.
// activeStreams is a linked-list of all streams that have data to send and some
// stream-level flow control quota.
// Each of these streams internally have a list of data items(and perhaps trailers
// on the server-side) to be sent out.
activeStreams *outStreamList
framer *framer
hBuf *bytes.Buffer // The buffer for HPACK encoding.
hEnc *hpack.Encoder // HPACK encoder.
bdpEst *bdpEstimator
draining bool
// Side-specific handlers
ssGoAwayHandler func(*goAway) (bool, error)
}
func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimator) *loopyWriter {
var buf bytes.Buffer
l := &loopyWriter{
side: s,
cbuf: cbuf,
sendQuota: defaultWindowSize,
oiws: defaultWindowSize,
estdStreams: make(map[uint32]*outStream),
activeStreams: newOutStreamList(),
framer: fr,
hBuf: &buf,
hEnc: hpack.NewEncoder(&buf),
bdpEst: bdpEst,
}
return l
}
const minBatchSize = 1000
// run should be run in a separate goroutine.
// It reads control frames from controlBuf and processes them by:
// 1. Updating loopy's internal state, or/and
// 2. Writing out HTTP2 frames on the wire.
//
// Loopy keeps all active streams with data to send in a linked-list.
// All streams in the activeStreams linked-list must have both:
// 1. Data to send, and
// 2. Stream level flow control quota available.
//
// In each iteration of run loop, other than processing the incoming control
// frame, loopy calls processData, which processes one node from the activeStreams linked-list.
// This results in writing of HTTP2 frames into an underlying write buffer.
// When there's no more control frames to read from controlBuf, loopy flushes the write buffer.
// As an optimization, to increase the batch size for each flush, loopy yields the processor, once
// if the batch size is too low to give stream goroutines a chance to fill it up.
func (l *loopyWriter) run() (err error) {
defer func() {
if err == ErrConnClosing {
// Don't log ErrConnClosing as error since it happens
// 1. When the connection is closed by some other known issue.
// 2. User closed the connection.
// 3. A graceful close of connection.
infof("transport: loopyWriter.run returning. %v", err)
err = nil
}
}()
for {
it, err := l.cbuf.get(true)
if err != nil {
return err
}
if err = l.handle(it); err != nil {
return err
}
if _, err = l.processData(); err != nil {
return err
}
gosched := true
hasdata:
for {
it, err := l.cbuf.get(false)
if err != nil {
return err
}
if it != nil {
if err = l.handle(it); err != nil {
return err
}
if _, err = l.processData(); err != nil {
return err
}
continue hasdata
}
isEmpty, err := l.processData()
if err != nil {
return err
}
if !isEmpty {
continue hasdata
}
if gosched {
gosched = false
if l.framer.writer.offset < minBatchSize {
runtime.Gosched()
continue hasdata
}
}
l.framer.writer.Flush()
break hasdata
}
}
}
func (l *loopyWriter) outgoingWindowUpdateHandler(w *outgoingWindowUpdate) error {
return l.framer.fr.WriteWindowUpdate(w.streamID, w.increment)
}
func (l *loopyWriter) incomingWindowUpdateHandler(w *incomingWindowUpdate) error {
// Otherwise update the quota.
if w.streamID == 0 {
l.sendQuota += w.increment
return nil
}
// Find the stream and update it.
if str, ok := l.estdStreams[w.streamID]; ok {
str.bytesOutStanding -= int(w.increment)
if strQuota := int(l.oiws) - str.bytesOutStanding; strQuota > 0 && str.state == waitingOnStreamQuota {
str.state = active
l.activeStreams.enqueue(str)
return nil
}
}
return nil
}
func (l *loopyWriter) outgoingSettingsHandler(s *outgoingSettings) error {
return l.framer.fr.WriteSettings(s.ss...)
}
func (l *loopyWriter) incomingSettingsHandler(s *incomingSettings) error {
if err := l.applySettings(s.ss); err != nil {
return err
}
return l.framer.fr.WriteSettingsAck()
}
func (l *loopyWriter) registerStreamHandler(h *registerStream) error {
str := &outStream{
id: h.streamID,
state: empty,
itl: &itemList{},
wq: h.wq,
}
l.estdStreams[h.streamID] = str
return nil
}
func (l *loopyWriter) headerHandler(h *headerFrame) error {
if l.side == serverSide {
str, ok := l.estdStreams[h.streamID]
if !ok {
warningf("transport: loopy doesn't recognize the stream: %d", h.streamID)
return nil
}
// Case 1.A: Server is responding back with headers.
if !h.endStream {
return l.writeHeader(h.streamID, h.endStream, h.hf, h.onWrite)
}
// else: Case 1.B: Server wants to close stream.
if str.state != empty { // either active or waiting on stream quota.
// add it str's list of items.
str.itl.enqueue(h)
return nil
}
if err := l.writeHeader(h.streamID, h.endStream, h.hf, h.onWrite); err != nil {
return err
}
return l.cleanupStreamHandler(h.cleanup)
}
// Case 2: Client wants to originate stream.
str := &outStream{
id: h.streamID,
state: empty,
itl: &itemList{},
wq: h.wq,
}
str.itl.enqueue(h)
return l.originateStream(str)
}
func (l *loopyWriter) originateStream(str *outStream) error {
hdr := str.itl.dequeue().(*headerFrame)
sendPing, err := hdr.initStream(str.id)
if err != nil {
if err == ErrConnClosing {
return err
}
// Other errors(errStreamDrain) need not close transport.
return nil
}
if err = l.writeHeader(str.id, hdr.endStream, hdr.hf, hdr.onWrite); err != nil {
return err
}
l.estdStreams[str.id] = str
if sendPing {
return l.pingHandler(&ping{data: [8]byte{}})
}
return nil
}
func (l *loopyWriter) writeHeader(streamID uint32, endStream bool, hf []hpack.HeaderField, onWrite func()) error {
if onWrite != nil {
onWrite()
}
l.hBuf.Reset()
for _, f := range hf {
if err := l.hEnc.WriteField(f); err != nil {
warningf("transport: loopyWriter.writeHeader encountered error while encoding headers:", err)
}
}
var (
err error
endHeaders, first bool
)
first = true
for !endHeaders {
size := l.hBuf.Len()
if size > http2MaxFrameLen {
size = http2MaxFrameLen
} else {
endHeaders = true
}
if first {
first = false
err = l.framer.fr.WriteHeaders(http2.HeadersFrameParam{
StreamID: streamID,
BlockFragment: l.hBuf.Next(size),
EndStream: endStream,
EndHeaders: endHeaders,
})
} else {
err = l.framer.fr.WriteContinuation(
streamID,
endHeaders,
l.hBuf.Next(size),
)
}
if err != nil {
return err
}
}
return nil
}
func (l *loopyWriter) preprocessData(df *dataFrame) error {
str, ok := l.estdStreams[df.streamID]
if !ok {
return nil
}
// If we got data for a stream it means that
// stream was originated and the headers were sent out.
str.itl.enqueue(df)
if str.state == empty {
str.state = active
l.activeStreams.enqueue(str)
}
return nil
}
func (l *loopyWriter) pingHandler(p *ping) error {
if !p.ack {
l.bdpEst.timesnap(p.data)
}
return l.framer.fr.WritePing(p.ack, p.data)
}
func (l *loopyWriter) outFlowControlSizeRequestHandler(o *outFlowControlSizeRequest) error {
o.resp <- l.sendQuota
return nil
}
func (l *loopyWriter) cleanupStreamHandler(c *cleanupStream) error {
c.onWrite()
if str, ok := l.estdStreams[c.streamID]; ok {
// On the server side it could be a trailers-only response or
// a RST_STREAM before stream initialization thus the stream might
// not be established yet.
delete(l.estdStreams, c.streamID)
str.deleteSelf()
}
if c.rst { // If RST_STREAM needs to be sent.
if err := l.framer.fr.WriteRSTStream(c.streamID, c.rstCode); err != nil {
return err
}
}
if l.side == clientSide && l.draining && len(l.estdStreams) == 0 {
return ErrConnClosing
}
return nil
}
func (l *loopyWriter) incomingGoAwayHandler(*incomingGoAway) error {
if l.side == clientSide {
l.draining = true
if len(l.estdStreams) == 0 {
return ErrConnClosing
}
}
return nil
}
func (l *loopyWriter) goAwayHandler(g *goAway) error {
// Handling of outgoing GoAway is very specific to side.
if l.ssGoAwayHandler != nil {
draining, err := l.ssGoAwayHandler(g)
if err != nil {
return err
}
l.draining = draining
}
return nil
}
func (l *loopyWriter) handle(i interface{}) error {
switch i := i.(type) {
case *incomingWindowUpdate:
return l.incomingWindowUpdateHandler(i)
case *outgoingWindowUpdate:
return l.outgoingWindowUpdateHandler(i)
case *incomingSettings:
return l.incomingSettingsHandler(i)
case *outgoingSettings:
return l.outgoingSettingsHandler(i)
case *headerFrame:
return l.headerHandler(i)
case *registerStream:
return l.registerStreamHandler(i)
case *cleanupStream:
return l.cleanupStreamHandler(i)
case *incomingGoAway:
return l.incomingGoAwayHandler(i)
case *dataFrame:
return l.preprocessData(i)
case *ping:
return l.pingHandler(i)
case *goAway:
return l.goAwayHandler(i)
case *outFlowControlSizeRequest:
return l.outFlowControlSizeRequestHandler(i)
default:
return fmt.Errorf("transport: unknown control message type %T", i)
}
}
func (l *loopyWriter) applySettings(ss []http2.Setting) error {
for _, s := range ss {
switch s.ID {
case http2.SettingInitialWindowSize:
o := l.oiws
l.oiws = s.Val
if o < l.oiws {
// If the new limit is greater make all depleted streams active.
for _, stream := range l.estdStreams {
if stream.state == waitingOnStreamQuota {
stream.state = active
l.activeStreams.enqueue(stream)
}
}
}
case http2.SettingHeaderTableSize:
updateHeaderTblSize(l.hEnc, s.Val)
}
}
return nil
}
// processData removes the first stream from active streams, writes out at most 16KB
// of its data and then puts it at the end of activeStreams if there's still more data
// to be sent and stream has some stream-level flow control.
func (l *loopyWriter) processData() (bool, error) {
if l.sendQuota == 0 {
return true, nil
}
str := l.activeStreams.dequeue() // Remove the first stream.
if str == nil {
return true, nil
}
dataItem := str.itl.peek().(*dataFrame) // Peek at the first data item this stream.
// A data item is represented by a dataFrame, since it later translates into
// multiple HTTP2 data frames.
// Every dataFrame has two buffers; h that keeps grpc-message header and d that is acutal data.
// As an optimization to keep wire traffic low, data from d is copied to h to make as big as the
// maximum possilbe HTTP2 frame size.
if len(dataItem.h) == 0 && len(dataItem.d) == 0 { // Empty data frame
// Client sends out empty data frame with endStream = true
if err := l.framer.fr.WriteData(dataItem.streamID, dataItem.endStream, nil); err != nil {
return false, err
}
str.itl.dequeue() // remove the empty data item from stream
if str.itl.isEmpty() {
str.state = empty
} else if trailer, ok := str.itl.peek().(*headerFrame); ok { // the next item is trailers.
if err := l.writeHeader(trailer.streamID, trailer.endStream, trailer.hf, trailer.onWrite); err != nil {
return false, err
}
if err := l.cleanupStreamHandler(trailer.cleanup); err != nil {
return false, nil
}
} else {
l.activeStreams.enqueue(str)
}
return false, nil
}
var (
idx int
buf []byte
)
if len(dataItem.h) != 0 { // data header has not been written out yet.
buf = dataItem.h
} else {
idx = 1
buf = dataItem.d
}
size := http2MaxFrameLen
if len(buf) < size {
size = len(buf)
}
if strQuota := int(l.oiws) - str.bytesOutStanding; strQuota <= 0 { // stream-level flow control.
str.state = waitingOnStreamQuota
return false, nil
} else if strQuota < size {
size = strQuota
}
if l.sendQuota < uint32(size) { // connection-level flow control.
size = int(l.sendQuota)
}
// Now that outgoing flow controls are checked we can replenish str's write quota
str.wq.replenish(size)
var endStream bool
// If this is the last data message on this stream and all of it can be written in this iteration.
if dataItem.endStream && size == len(buf) {
// buf contains either data or it contains header but data is empty.
if idx == 1 || len(dataItem.d) == 0 {
endStream = true
}
}
if dataItem.onEachWrite != nil {
dataItem.onEachWrite()
}
if err := l.framer.fr.WriteData(dataItem.streamID, endStream, buf[:size]); err != nil {
return false, err
}
buf = buf[size:]
str.bytesOutStanding += size
l.sendQuota -= uint32(size)
if idx == 0 {
dataItem.h = buf
} else {
dataItem.d = buf
}
if len(dataItem.h) == 0 && len(dataItem.d) == 0 { // All the data from that message was written out.
str.itl.dequeue()
}
if str.itl.isEmpty() {
str.state = empty
} else if trailer, ok := str.itl.peek().(*headerFrame); ok { // The next item is trailers.
if err := l.writeHeader(trailer.streamID, trailer.endStream, trailer.hf, trailer.onWrite); err != nil {
return false, err
}
if err := l.cleanupStreamHandler(trailer.cleanup); err != nil {
return false, err
}
} else if int(l.oiws)-str.bytesOutStanding <= 0 { // Ran out of stream quota.
str.state = waitingOnStreamQuota
} else { // Otherwise add it back to the list of active streams.
l.activeStreams.enqueue(str)
}
return false, nil
}

View File

@ -0,0 +1,49 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package transport
import (
"math"
"time"
)
const (
// The default value of flow control window size in HTTP2 spec.
defaultWindowSize = 65535
// The initial window size for flow control.
initialWindowSize = defaultWindowSize // for an RPC
infinity = time.Duration(math.MaxInt64)
defaultClientKeepaliveTime = infinity
defaultClientKeepaliveTimeout = 20 * time.Second
defaultMaxStreamsClient = 100
defaultMaxConnectionIdle = infinity
defaultMaxConnectionAge = infinity
defaultMaxConnectionAgeGrace = infinity
defaultServerKeepaliveTime = 2 * time.Hour
defaultServerKeepaliveTimeout = 20 * time.Second
defaultKeepalivePolicyMinTime = 5 * time.Minute
// max window limit set by HTTP2 Specs.
maxWindowSize = math.MaxInt32
// defaultWriteQuota is the default value for number of data
// bytes that each stream can schedule before some of it being
// flushed out.
defaultWriteQuota = 64 * 1024
defaultClientMaxHeaderListSize = uint32(16 << 20)
defaultServerMaxHeaderListSize = uint32(16 << 20)
)

View File

@ -0,0 +1,218 @@
/*
*
* Copyright 2014 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package transport
import (
"fmt"
"math"
"sync"
"sync/atomic"
)
// writeQuota is a soft limit on the amount of data a stream can
// schedule before some of it is written out.
type writeQuota struct {
quota int32
// get waits on read from when quota goes less than or equal to zero.
// replenish writes on it when quota goes positive again.
ch chan struct{}
// done is triggered in error case.
done <-chan struct{}
// replenish is called by loopyWriter to give quota back to.
// It is implemented as a field so that it can be updated
// by tests.
replenish func(n int)
}
func newWriteQuota(sz int32, done <-chan struct{}) *writeQuota {
w := &writeQuota{
quota: sz,
ch: make(chan struct{}, 1),
done: done,
}
w.replenish = w.realReplenish
return w
}
func (w *writeQuota) get(sz int32) error {
for {
if atomic.LoadInt32(&w.quota) > 0 {
atomic.AddInt32(&w.quota, -sz)
return nil
}
select {
case <-w.ch:
continue
case <-w.done:
return errStreamDone
}
}
}
func (w *writeQuota) realReplenish(n int) {
sz := int32(n)
a := atomic.AddInt32(&w.quota, sz)
b := a - sz
if b <= 0 && a > 0 {
select {
case w.ch <- struct{}{}:
default:
}
}
}
type trInFlow struct {
limit uint32
unacked uint32
effectiveWindowSize uint32
}
func (f *trInFlow) newLimit(n uint32) uint32 {
d := n - f.limit
f.limit = n
f.updateEffectiveWindowSize()
return d
}
func (f *trInFlow) onData(n uint32) uint32 {
f.unacked += n
if f.unacked >= f.limit/4 {
w := f.unacked
f.unacked = 0
f.updateEffectiveWindowSize()
return w
}
f.updateEffectiveWindowSize()
return 0
}
func (f *trInFlow) reset() uint32 {
w := f.unacked
f.unacked = 0
f.updateEffectiveWindowSize()
return w
}
func (f *trInFlow) updateEffectiveWindowSize() {
atomic.StoreUint32(&f.effectiveWindowSize, f.limit-f.unacked)
}
func (f *trInFlow) getSize() uint32 {
return atomic.LoadUint32(&f.effectiveWindowSize)
}
// TODO(mmukhi): Simplify this code.
// inFlow deals with inbound flow control
type inFlow struct {
mu sync.Mutex
// The inbound flow control limit for pending data.
limit uint32
// pendingData is the overall data which have been received but not been
// consumed by applications.
pendingData uint32
// The amount of data the application has consumed but grpc has not sent
// window update for them. Used to reduce window update frequency.
pendingUpdate uint32
// delta is the extra window update given by receiver when an application
// is reading data bigger in size than the inFlow limit.
delta uint32
}
// newLimit updates the inflow window to a new value n.
// It assumes that n is always greater than the old limit.
func (f *inFlow) newLimit(n uint32) uint32 {
f.mu.Lock()
d := n - f.limit
f.limit = n
f.mu.Unlock()
return d
}
func (f *inFlow) maybeAdjust(n uint32) uint32 {
if n > uint32(math.MaxInt32) {
n = uint32(math.MaxInt32)
}
f.mu.Lock()
// estSenderQuota is the receiver's view of the maximum number of bytes the sender
// can send without a window update.
estSenderQuota := int32(f.limit - (f.pendingData + f.pendingUpdate))
// estUntransmittedData is the maximum number of bytes the sends might not have put
// on the wire yet. A value of 0 or less means that we have already received all or
// more bytes than the application is requesting to read.
estUntransmittedData := int32(n - f.pendingData) // Casting into int32 since it could be negative.
// This implies that unless we send a window update, the sender won't be able to send all the bytes
// for this message. Therefore we must send an update over the limit since there's an active read
// request from the application.
if estUntransmittedData > estSenderQuota {
// Sender's window shouldn't go more than 2^31 - 1 as specified in the HTTP spec.
if f.limit+n > maxWindowSize {
f.delta = maxWindowSize - f.limit
} else {
// Send a window update for the whole message and not just the difference between
// estUntransmittedData and estSenderQuota. This will be helpful in case the message
// is padded; We will fallback on the current available window(at least a 1/4th of the limit).
f.delta = n
}
f.mu.Unlock()
return f.delta
}
f.mu.Unlock()
return 0
}
// onData is invoked when some data frame is received. It updates pendingData.
func (f *inFlow) onData(n uint32) error {
f.mu.Lock()
f.pendingData += n
if f.pendingData+f.pendingUpdate > f.limit+f.delta {
limit := f.limit
rcvd := f.pendingData + f.pendingUpdate
f.mu.Unlock()
return fmt.Errorf("received %d-bytes data exceeding the limit %d bytes", rcvd, limit)
}
f.mu.Unlock()
return nil
}
// onRead is invoked when the application reads the data. It returns the window size
// to be sent to the peer.
func (f *inFlow) onRead(n uint32) uint32 {
f.mu.Lock()
if f.pendingData == 0 {
f.mu.Unlock()
return 0
}
f.pendingData -= n
if n > f.delta {
n -= f.delta
f.delta = 0
} else {
f.delta -= n
n = 0
}
f.pendingUpdate += n
if f.pendingUpdate >= f.limit/4 {
wu := f.pendingUpdate
f.pendingUpdate = 0
f.mu.Unlock()
return wu
}
f.mu.Unlock()
return 0
}

View File

@ -0,0 +1,52 @@
// +build go1.6,!go1.7
/*
*
* Copyright 2016 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package transport
import (
"net"
"net/http"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"golang.org/x/net/context"
)
// dialContext connects to the address on the named network.
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
return (&net.Dialer{Cancel: ctx.Done()}).Dial(network, address)
}
// ContextErr converts the error from context package into a status error.
func ContextErr(err error) error {
switch err {
case context.DeadlineExceeded:
return status.Error(codes.DeadlineExceeded, err.Error())
case context.Canceled:
return status.Error(codes.Canceled, err.Error())
}
return status.Errorf(codes.Internal, "Unexpected error from context packet: %v", err)
}
// contextFromRequest returns a background context.
func contextFromRequest(r *http.Request) context.Context {
return context.Background()
}

View File

@ -0,0 +1,53 @@
// +build go1.7
/*
*
* Copyright 2016 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package transport
import (
"context"
"net"
"net/http"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
netctx "golang.org/x/net/context"
)
// dialContext connects to the address on the named network.
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
return (&net.Dialer{}).DialContext(ctx, network, address)
}
// ContextErr converts the error from context package into a status error.
func ContextErr(err error) error {
switch err {
case context.DeadlineExceeded, netctx.DeadlineExceeded:
return status.Error(codes.DeadlineExceeded, err.Error())
case context.Canceled, netctx.Canceled:
return status.Error(codes.Canceled, err.Error())
}
return status.Errorf(codes.Internal, "Unexpected error from context packet: %v", err)
}
// contextFromRequest returns a context from the HTTP Request.
func contextFromRequest(r *http.Request) context.Context {
return r.Context()
}

View File

@ -0,0 +1,449 @@
/*
*
* Copyright 2016 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// This file is the implementation of a gRPC server using HTTP/2 which
// uses the standard Go http2 Server implementation (via the
// http.Handler interface), rather than speaking low-level HTTP/2
// frames itself. It is the implementation of *grpc.Server.ServeHTTP.
package transport
import (
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"time"
"github.com/golang/protobuf/proto"
"golang.org/x/net/context"
"golang.org/x/net/http2"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
)
// NewServerHandlerTransport returns a ServerTransport handling gRPC
// from inside an http.Handler. It requires that the http Server
// supports HTTP/2.
func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats stats.Handler) (ServerTransport, error) {
if r.ProtoMajor != 2 {
return nil, errors.New("gRPC requires HTTP/2")
}
if r.Method != "POST" {
return nil, errors.New("invalid gRPC request method")
}
contentType := r.Header.Get("Content-Type")
// TODO: do we assume contentType is lowercase? we did before
contentSubtype, validContentType := contentSubtype(contentType)
if !validContentType {
return nil, errors.New("invalid gRPC request content-type")
}
if _, ok := w.(http.Flusher); !ok {
return nil, errors.New("gRPC requires a ResponseWriter supporting http.Flusher")
}
if _, ok := w.(http.CloseNotifier); !ok {
return nil, errors.New("gRPC requires a ResponseWriter supporting http.CloseNotifier")
}
st := &serverHandlerTransport{
rw: w,
req: r,
closedCh: make(chan struct{}),
writes: make(chan func()),
contentType: contentType,
contentSubtype: contentSubtype,
stats: stats,
}
if v := r.Header.Get("grpc-timeout"); v != "" {
to, err := decodeTimeout(v)
if err != nil {
return nil, status.Errorf(codes.Internal, "malformed time-out: %v", err)
}
st.timeoutSet = true
st.timeout = to
}
metakv := []string{"content-type", contentType}
if r.Host != "" {
metakv = append(metakv, ":authority", r.Host)
}
for k, vv := range r.Header {
k = strings.ToLower(k)
if isReservedHeader(k) && !isWhitelistedHeader(k) {
continue
}
for _, v := range vv {
v, err := decodeMetadataHeader(k, v)
if err != nil {
return nil, status.Errorf(codes.Internal, "malformed binary metadata: %v", err)
}
metakv = append(metakv, k, v)
}
}
st.headerMD = metadata.Pairs(metakv...)
return st, nil
}
// serverHandlerTransport is an implementation of ServerTransport
// which replies to exactly one gRPC request (exactly one HTTP request),
// using the net/http.Handler interface. This http.Handler is guaranteed
// at this point to be speaking over HTTP/2, so it's able to speak valid
// gRPC.
type serverHandlerTransport struct {
rw http.ResponseWriter
req *http.Request
timeoutSet bool
timeout time.Duration
didCommonHeaders bool
headerMD metadata.MD
closeOnce sync.Once
closedCh chan struct{} // closed on Close
// writes is a channel of code to run serialized in the
// ServeHTTP (HandleStreams) goroutine. The channel is closed
// when WriteStatus is called.
writes chan func()
// block concurrent WriteStatus calls
// e.g. grpc/(*serverStream).SendMsg/RecvMsg
writeStatusMu sync.Mutex
// we just mirror the request content-type
contentType string
// we store both contentType and contentSubtype so we don't keep recreating them
// TODO make sure this is consistent across handler_server and http2_server
contentSubtype string
stats stats.Handler
}
func (ht *serverHandlerTransport) Close() error {
ht.closeOnce.Do(ht.closeCloseChanOnce)
return nil
}
func (ht *serverHandlerTransport) closeCloseChanOnce() { close(ht.closedCh) }
func (ht *serverHandlerTransport) RemoteAddr() net.Addr { return strAddr(ht.req.RemoteAddr) }
// strAddr is a net.Addr backed by either a TCP "ip:port" string, or
// the empty string if unknown.
type strAddr string
func (a strAddr) Network() string {
if a != "" {
// Per the documentation on net/http.Request.RemoteAddr, if this is
// set, it's set to the IP:port of the peer (hence, TCP):
// https://golang.org/pkg/net/http/#Request
//
// If we want to support Unix sockets later, we can
// add our own grpc-specific convention within the
// grpc codebase to set RemoteAddr to a different
// format, or probably better: we can attach it to the
// context and use that from serverHandlerTransport.RemoteAddr.
return "tcp"
}
return ""
}
func (a strAddr) String() string { return string(a) }
// do runs fn in the ServeHTTP goroutine.
func (ht *serverHandlerTransport) do(fn func()) error {
// Avoid a panic writing to closed channel. Imperfect but maybe good enough.
select {
case <-ht.closedCh:
return ErrConnClosing
default:
select {
case ht.writes <- fn:
return nil
case <-ht.closedCh:
return ErrConnClosing
}
}
}
func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) error {
ht.writeStatusMu.Lock()
defer ht.writeStatusMu.Unlock()
err := ht.do(func() {
ht.writeCommonHeaders(s)
// And flush, in case no header or body has been sent yet.
// This forces a separation of headers and trailers if this is the
// first call (for example, in end2end tests's TestNoService).
ht.rw.(http.Flusher).Flush()
h := ht.rw.Header()
h.Set("Grpc-Status", fmt.Sprintf("%d", st.Code()))
if m := st.Message(); m != "" {
h.Set("Grpc-Message", encodeGrpcMessage(m))
}
if p := st.Proto(); p != nil && len(p.Details) > 0 {
stBytes, err := proto.Marshal(p)
if err != nil {
// TODO: return error instead, when callers are able to handle it.
panic(err)
}
h.Set("Grpc-Status-Details-Bin", encodeBinHeader(stBytes))
}
if md := s.Trailer(); len(md) > 0 {
for k, vv := range md {
// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
if isReservedHeader(k) {
continue
}
for _, v := range vv {
// http2 ResponseWriter mechanism to send undeclared Trailers after
// the headers have possibly been written.
h.Add(http2.TrailerPrefix+k, encodeMetadataHeader(k, v))
}
}
}
})
if err == nil { // transport has not been closed
if ht.stats != nil {
ht.stats.HandleRPC(s.Context(), &stats.OutTrailer{})
}
close(ht.writes)
}
ht.Close()
return err
}
// writeCommonHeaders sets common headers on the first write
// call (Write, WriteHeader, or WriteStatus).
func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) {
if ht.didCommonHeaders {
return
}
ht.didCommonHeaders = true
h := ht.rw.Header()
h["Date"] = nil // suppress Date to make tests happy; TODO: restore
h.Set("Content-Type", ht.contentType)
// Predeclare trailers we'll set later in WriteStatus (after the body).
// This is a SHOULD in the HTTP RFC, and the way you add (known)
// Trailers per the net/http.ResponseWriter contract.
// See https://golang.org/pkg/net/http/#ResponseWriter
// and https://golang.org/pkg/net/http/#example_ResponseWriter_trailers
h.Add("Trailer", "Grpc-Status")
h.Add("Trailer", "Grpc-Message")
h.Add("Trailer", "Grpc-Status-Details-Bin")
if s.sendCompress != "" {
h.Set("Grpc-Encoding", s.sendCompress)
}
}
func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
return ht.do(func() {
ht.writeCommonHeaders(s)
ht.rw.Write(hdr)
ht.rw.Write(data)
ht.rw.(http.Flusher).Flush()
})
}
func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
err := ht.do(func() {
ht.writeCommonHeaders(s)
h := ht.rw.Header()
for k, vv := range md {
// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
if isReservedHeader(k) {
continue
}
for _, v := range vv {
v = encodeMetadataHeader(k, v)
h.Add(k, v)
}
}
ht.rw.WriteHeader(200)
ht.rw.(http.Flusher).Flush()
})
if err == nil {
if ht.stats != nil {
ht.stats.HandleRPC(s.Context(), &stats.OutHeader{})
}
}
return err
}
func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), traceCtx func(context.Context, string) context.Context) {
// With this transport type there will be exactly 1 stream: this HTTP request.
ctx := contextFromRequest(ht.req)
var cancel context.CancelFunc
if ht.timeoutSet {
ctx, cancel = context.WithTimeout(ctx, ht.timeout)
} else {
ctx, cancel = context.WithCancel(ctx)
}
// requestOver is closed when either the request's context is done
// or the status has been written via WriteStatus.
requestOver := make(chan struct{})
// clientGone receives a single value if peer is gone, either
// because the underlying connection is dead or because the
// peer sends an http2 RST_STREAM.
clientGone := ht.rw.(http.CloseNotifier).CloseNotify()
go func() {
select {
case <-requestOver:
case <-ht.closedCh:
case <-clientGone:
}
cancel()
ht.Close()
}()
req := ht.req
s := &Stream{
id: 0, // irrelevant
requestRead: func(int) {},
cancel: cancel,
buf: newRecvBuffer(),
st: ht,
method: req.URL.Path,
recvCompress: req.Header.Get("grpc-encoding"),
contentSubtype: ht.contentSubtype,
}
pr := &peer.Peer{
Addr: ht.RemoteAddr(),
}
if req.TLS != nil {
pr.AuthInfo = credentials.TLSInfo{State: *req.TLS}
}
ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
s.ctx = peer.NewContext(ctx, pr)
if ht.stats != nil {
s.ctx = ht.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
inHeader := &stats.InHeader{
FullMethod: s.method,
RemoteAddr: ht.RemoteAddr(),
Compression: s.recvCompress,
}
ht.stats.HandleRPC(s.ctx, inHeader)
}
s.trReader = &transportReader{
reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf},
windowHandler: func(int) {},
}
// readerDone is closed when the Body.Read-ing goroutine exits.
readerDone := make(chan struct{})
go func() {
defer close(readerDone)
// TODO: minimize garbage, optimize recvBuffer code/ownership
const readSize = 8196
for buf := make([]byte, readSize); ; {
n, err := req.Body.Read(buf)
if n > 0 {
s.buf.put(recvMsg{data: buf[:n:n]})
buf = buf[n:]
}
if err != nil {
s.buf.put(recvMsg{err: mapRecvMsgError(err)})
return
}
if len(buf) == 0 {
buf = make([]byte, readSize)
}
}
}()
// startStream is provided by the *grpc.Server's serveStreams.
// It starts a goroutine serving s and exits immediately.
// The goroutine that is started is the one that then calls
// into ht, calling WriteHeader, Write, WriteStatus, Close, etc.
startStream(s)
ht.runStream()
close(requestOver)
// Wait for reading goroutine to finish.
req.Body.Close()
<-readerDone
}
func (ht *serverHandlerTransport) runStream() {
for {
select {
case fn, ok := <-ht.writes:
if !ok {
return
}
fn()
case <-ht.closedCh:
return
}
}
}
func (ht *serverHandlerTransport) IncrMsgSent() {}
func (ht *serverHandlerTransport) IncrMsgRecv() {}
func (ht *serverHandlerTransport) Drain() {
panic("Drain() is not implemented")
}
// mapRecvMsgError returns the non-nil err into the appropriate
// error value as expected by callers of *grpc.parser.recvMsg.
// In particular, in can only be:
// * io.EOF
// * io.ErrUnexpectedEOF
// * of type transport.ConnectionError
// * an error from the status package
func mapRecvMsgError(err error) error {
if err == io.EOF || err == io.ErrUnexpectedEOF {
return err
}
if se, ok := err.(http2.StreamError); ok {
if code, ok := http2ErrConvTab[se.Code]; ok {
return status.Error(code, se.Error())
}
}
if strings.Contains(err.Error(), "body closed by handler") {
return status.Error(codes.Canceled, err.Error())
}
return connectionErrorf(true, err, err.Error())
}

View File

@ -0,0 +1,481 @@
/*
*
* Copyright 2016 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package transport
import (
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"sync"
"testing"
"time"
"github.com/golang/protobuf/proto"
dpb "github.com/golang/protobuf/ptypes/duration"
"golang.org/x/net/context"
epb "google.golang.org/genproto/googleapis/rpc/errdetails"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
func TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
type testCase struct {
name string
req *http.Request
wantErr string
modrw func(http.ResponseWriter) http.ResponseWriter
check func(*serverHandlerTransport, *testCase) error
}
tests := []testCase{
{
name: "http/1.1",
req: &http.Request{
ProtoMajor: 1,
ProtoMinor: 1,
},
wantErr: "gRPC requires HTTP/2",
},
{
name: "bad method",
req: &http.Request{
ProtoMajor: 2,
Method: "GET",
Header: http.Header{},
RequestURI: "/",
},
wantErr: "invalid gRPC request method",
},
{
name: "bad content type",
req: &http.Request{
ProtoMajor: 2,
Method: "POST",
Header: http.Header{
"Content-Type": {"application/foo"},
},
RequestURI: "/service/foo.bar",
},
wantErr: "invalid gRPC request content-type",
},
{
name: "not flusher",
req: &http.Request{
ProtoMajor: 2,
Method: "POST",
Header: http.Header{
"Content-Type": {"application/grpc"},
},
RequestURI: "/service/foo.bar",
},
modrw: func(w http.ResponseWriter) http.ResponseWriter {
// Return w without its Flush method
type onlyCloseNotifier interface {
http.ResponseWriter
http.CloseNotifier
}
return struct{ onlyCloseNotifier }{w.(onlyCloseNotifier)}
},
wantErr: "gRPC requires a ResponseWriter supporting http.Flusher",
},
{
name: "not closenotifier",
req: &http.Request{
ProtoMajor: 2,
Method: "POST",
Header: http.Header{
"Content-Type": {"application/grpc"},
},
RequestURI: "/service/foo.bar",
},
modrw: func(w http.ResponseWriter) http.ResponseWriter {
// Return w without its CloseNotify method
type onlyFlusher interface {
http.ResponseWriter
http.Flusher
}
return struct{ onlyFlusher }{w.(onlyFlusher)}
},
wantErr: "gRPC requires a ResponseWriter supporting http.CloseNotifier",
},
{
name: "valid",
req: &http.Request{
ProtoMajor: 2,
Method: "POST",
Header: http.Header{
"Content-Type": {"application/grpc"},
},
URL: &url.URL{
Path: "/service/foo.bar",
},
RequestURI: "/service/foo.bar",
},
check: func(t *serverHandlerTransport, tt *testCase) error {
if t.req != tt.req {
return fmt.Errorf("t.req = %p; want %p", t.req, tt.req)
}
if t.rw == nil {
return errors.New("t.rw = nil; want non-nil")
}
return nil
},
},
{
name: "with timeout",
req: &http.Request{
ProtoMajor: 2,
Method: "POST",
Header: http.Header{
"Content-Type": []string{"application/grpc"},
"Grpc-Timeout": {"200m"},
},
URL: &url.URL{
Path: "/service/foo.bar",
},
RequestURI: "/service/foo.bar",
},
check: func(t *serverHandlerTransport, tt *testCase) error {
if !t.timeoutSet {
return errors.New("timeout not set")
}
if want := 200 * time.Millisecond; t.timeout != want {
return fmt.Errorf("timeout = %v; want %v", t.timeout, want)
}
return nil
},
},
{
name: "with bad timeout",
req: &http.Request{
ProtoMajor: 2,
Method: "POST",
Header: http.Header{
"Content-Type": []string{"application/grpc"},
"Grpc-Timeout": {"tomorrow"},
},
URL: &url.URL{
Path: "/service/foo.bar",
},
RequestURI: "/service/foo.bar",
},
wantErr: `rpc error: code = Internal desc = malformed time-out: transport: timeout unit is not recognized: "tomorrow"`,
},
{
name: "with metadata",
req: &http.Request{
ProtoMajor: 2,
Method: "POST",
Header: http.Header{
"Content-Type": []string{"application/grpc"},
"meta-foo": {"foo-val"},
"meta-bar": {"bar-val1", "bar-val2"},
"user-agent": {"x/y a/b"},
},
URL: &url.URL{
Path: "/service/foo.bar",
},
RequestURI: "/service/foo.bar",
},
check: func(ht *serverHandlerTransport, tt *testCase) error {
want := metadata.MD{
"meta-bar": {"bar-val1", "bar-val2"},
"user-agent": {"x/y a/b"},
"meta-foo": {"foo-val"},
"content-type": {"application/grpc"},
}
if !reflect.DeepEqual(ht.headerMD, want) {
return fmt.Errorf("metdata = %#v; want %#v", ht.headerMD, want)
}
return nil
},
},
}
for _, tt := range tests {
rw := newTestHandlerResponseWriter()
if tt.modrw != nil {
rw = tt.modrw(rw)
}
got, gotErr := NewServerHandlerTransport(rw, tt.req, nil)
if (gotErr != nil) != (tt.wantErr != "") || (gotErr != nil && gotErr.Error() != tt.wantErr) {
t.Errorf("%s: error = %q; want %q", tt.name, gotErr.Error(), tt.wantErr)
continue
}
if gotErr != nil {
continue
}
if tt.check != nil {
if err := tt.check(got.(*serverHandlerTransport), &tt); err != nil {
t.Errorf("%s: %v", tt.name, err)
}
}
}
}
type testHandlerResponseWriter struct {
*httptest.ResponseRecorder
closeNotify chan bool
}
func (w testHandlerResponseWriter) CloseNotify() <-chan bool { return w.closeNotify }
func (w testHandlerResponseWriter) Flush() {}
func newTestHandlerResponseWriter() http.ResponseWriter {
return testHandlerResponseWriter{
ResponseRecorder: httptest.NewRecorder(),
closeNotify: make(chan bool, 1),
}
}
type handleStreamTest struct {
t *testing.T
bodyw *io.PipeWriter
rw testHandlerResponseWriter
ht *serverHandlerTransport
}
func newHandleStreamTest(t *testing.T) *handleStreamTest {
bodyr, bodyw := io.Pipe()
req := &http.Request{
ProtoMajor: 2,
Method: "POST",
Header: http.Header{
"Content-Type": {"application/grpc"},
},
URL: &url.URL{
Path: "/service/foo.bar",
},
RequestURI: "/service/foo.bar",
Body: bodyr,
}
rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
ht, err := NewServerHandlerTransport(rw, req, nil)
if err != nil {
t.Fatal(err)
}
return &handleStreamTest{
t: t,
bodyw: bodyw,
ht: ht.(*serverHandlerTransport),
rw: rw,
}
}
func TestHandlerTransport_HandleStreams(t *testing.T) {
st := newHandleStreamTest(t)
handleStream := func(s *Stream) {
if want := "/service/foo.bar"; s.method != want {
t.Errorf("stream method = %q; want %q", s.method, want)
}
st.bodyw.Close() // no body
st.ht.WriteStatus(s, status.New(codes.OK, ""))
}
st.ht.HandleStreams(
func(s *Stream) { go handleStream(s) },
func(ctx context.Context, method string) context.Context { return ctx },
)
wantHeader := http.Header{
"Date": nil,
"Content-Type": {"application/grpc"},
"Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
"Grpc-Status": {"0"},
}
if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) {
t.Errorf("Header+Trailer Map: %#v; want %#v", st.rw.HeaderMap, wantHeader)
}
}
// Tests that codes.Unimplemented will close the body, per comment in handler_server.go.
func TestHandlerTransport_HandleStreams_Unimplemented(t *testing.T) {
handleStreamCloseBodyTest(t, codes.Unimplemented, "thingy is unimplemented")
}
// Tests that codes.InvalidArgument will close the body, per comment in handler_server.go.
func TestHandlerTransport_HandleStreams_InvalidArgument(t *testing.T) {
handleStreamCloseBodyTest(t, codes.InvalidArgument, "bad arg")
}
func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) {
st := newHandleStreamTest(t)
handleStream := func(s *Stream) {
st.ht.WriteStatus(s, status.New(statusCode, msg))
}
st.ht.HandleStreams(
func(s *Stream) { go handleStream(s) },
func(ctx context.Context, method string) context.Context { return ctx },
)
wantHeader := http.Header{
"Date": nil,
"Content-Type": {"application/grpc"},
"Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
"Grpc-Status": {fmt.Sprint(uint32(statusCode))},
"Grpc-Message": {encodeGrpcMessage(msg)},
}
if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) {
t.Errorf("Header+Trailer mismatch.\n got: %#v\nwant: %#v", st.rw.HeaderMap, wantHeader)
}
}
func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
bodyr, bodyw := io.Pipe()
req := &http.Request{
ProtoMajor: 2,
Method: "POST",
Header: http.Header{
"Content-Type": {"application/grpc"},
"Grpc-Timeout": {"200m"},
},
URL: &url.URL{
Path: "/service/foo.bar",
},
RequestURI: "/service/foo.bar",
Body: bodyr,
}
rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
ht, err := NewServerHandlerTransport(rw, req, nil)
if err != nil {
t.Fatal(err)
}
runStream := func(s *Stream) {
defer bodyw.Close()
select {
case <-s.ctx.Done():
case <-time.After(5 * time.Second):
t.Errorf("timeout waiting for ctx.Done")
return
}
err := s.ctx.Err()
if err != context.DeadlineExceeded {
t.Errorf("ctx.Err = %v; want %v", err, context.DeadlineExceeded)
return
}
ht.WriteStatus(s, status.New(codes.DeadlineExceeded, "too slow"))
}
ht.HandleStreams(
func(s *Stream) { go runStream(s) },
func(ctx context.Context, method string) context.Context { return ctx },
)
wantHeader := http.Header{
"Date": nil,
"Content-Type": {"application/grpc"},
"Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
"Grpc-Status": {"4"},
"Grpc-Message": {encodeGrpcMessage("too slow")},
}
if !reflect.DeepEqual(rw.HeaderMap, wantHeader) {
t.Errorf("Header+Trailer Map mismatch.\n got: %#v\nwant: %#v", rw.HeaderMap, wantHeader)
}
}
// TestHandlerTransport_HandleStreams_MultiWriteStatus ensures that
// concurrent "WriteStatus"s do not panic writing to closed "writes" channel.
func TestHandlerTransport_HandleStreams_MultiWriteStatus(t *testing.T) {
testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *Stream) {
if want := "/service/foo.bar"; s.method != want {
t.Errorf("stream method = %q; want %q", s.method, want)
}
st.bodyw.Close() // no body
var wg sync.WaitGroup
wg.Add(5)
for i := 0; i < 5; i++ {
go func() {
defer wg.Done()
st.ht.WriteStatus(s, status.New(codes.OK, ""))
}()
}
wg.Wait()
})
}
// TestHandlerTransport_HandleStreams_WriteStatusWrite ensures that "Write"
// following "WriteStatus" does not panic writing to closed "writes" channel.
func TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *Stream) {
if want := "/service/foo.bar"; s.method != want {
t.Errorf("stream method = %q; want %q", s.method, want)
}
st.bodyw.Close() // no body
st.ht.WriteStatus(s, status.New(codes.OK, ""))
st.ht.Write(s, []byte("hdr"), []byte("data"), &Options{})
})
}
func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *Stream)) {
st := newHandleStreamTest(t)
st.ht.HandleStreams(
func(s *Stream) { go handleStream(st, s) },
func(ctx context.Context, method string) context.Context { return ctx },
)
}
func TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
errDetails := []proto.Message{
&epb.RetryInfo{
RetryDelay: &dpb.Duration{Seconds: 60},
},
&epb.ResourceInfo{
ResourceType: "foo bar",
ResourceName: "service.foo.bar",
Owner: "User",
},
}
statusCode := codes.ResourceExhausted
msg := "you are being throttled"
st, err := status.New(statusCode, msg).WithDetails(errDetails...)
if err != nil {
t.Fatal(err)
}
stBytes, err := proto.Marshal(st.Proto())
if err != nil {
t.Fatal(err)
}
hst := newHandleStreamTest(t)
handleStream := func(s *Stream) {
hst.ht.WriteStatus(s, st)
}
hst.ht.HandleStreams(
func(s *Stream) { go handleStream(s) },
func(ctx context.Context, method string) context.Context { return ctx },
)
wantHeader := http.Header{
"Date": nil,
"Content-Type": {"application/grpc"},
"Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
"Grpc-Status": {fmt.Sprint(uint32(statusCode))},
"Grpc-Message": {encodeGrpcMessage(msg)},
"Grpc-Status-Details-Bin": {encodeBinHeader(stBytes)},
}
if !reflect.DeepEqual(hst.rw.HeaderMap, wantHeader) {
t.Errorf("Header+Trailer mismatch.\n got: %#v\nwant: %#v", hst.rw.HeaderMap, wantHeader)
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,623 @@
/*
*
* Copyright 2014 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package transport
import (
"bufio"
"bytes"
"encoding/base64"
"fmt"
"io"
"math"
"net"
"net/http"
"strconv"
"strings"
"time"
"unicode/utf8"
"github.com/golang/protobuf/proto"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
spb "google.golang.org/genproto/googleapis/rpc/status"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
const (
// http2MaxFrameLen specifies the max length of a HTTP2 frame.
http2MaxFrameLen = 16384 // 16KB frame
// http://http2.github.io/http2-spec/#SettingValues
http2InitHeaderTableSize = 4096
// baseContentType is the base content-type for gRPC. This is a valid
// content-type on it's own, but can also include a content-subtype such as
// "proto" as a suffix after "+" or ";". See
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests
// for more details.
baseContentType = "application/grpc"
)
var (
clientPreface = []byte(http2.ClientPreface)
http2ErrConvTab = map[http2.ErrCode]codes.Code{
http2.ErrCodeNo: codes.Internal,
http2.ErrCodeProtocol: codes.Internal,
http2.ErrCodeInternal: codes.Internal,
http2.ErrCodeFlowControl: codes.ResourceExhausted,
http2.ErrCodeSettingsTimeout: codes.Internal,
http2.ErrCodeStreamClosed: codes.Internal,
http2.ErrCodeFrameSize: codes.Internal,
http2.ErrCodeRefusedStream: codes.Unavailable,
http2.ErrCodeCancel: codes.Canceled,
http2.ErrCodeCompression: codes.Internal,
http2.ErrCodeConnect: codes.Internal,
http2.ErrCodeEnhanceYourCalm: codes.ResourceExhausted,
http2.ErrCodeInadequateSecurity: codes.PermissionDenied,
http2.ErrCodeHTTP11Required: codes.Internal,
}
statusCodeConvTab = map[codes.Code]http2.ErrCode{
codes.Internal: http2.ErrCodeInternal,
codes.Canceled: http2.ErrCodeCancel,
codes.Unavailable: http2.ErrCodeRefusedStream,
codes.ResourceExhausted: http2.ErrCodeEnhanceYourCalm,
codes.PermissionDenied: http2.ErrCodeInadequateSecurity,
}
httpStatusConvTab = map[int]codes.Code{
// 400 Bad Request - INTERNAL.
http.StatusBadRequest: codes.Internal,
// 401 Unauthorized - UNAUTHENTICATED.
http.StatusUnauthorized: codes.Unauthenticated,
// 403 Forbidden - PERMISSION_DENIED.
http.StatusForbidden: codes.PermissionDenied,
// 404 Not Found - UNIMPLEMENTED.
http.StatusNotFound: codes.Unimplemented,
// 429 Too Many Requests - UNAVAILABLE.
http.StatusTooManyRequests: codes.Unavailable,
// 502 Bad Gateway - UNAVAILABLE.
http.StatusBadGateway: codes.Unavailable,
// 503 Service Unavailable - UNAVAILABLE.
http.StatusServiceUnavailable: codes.Unavailable,
// 504 Gateway timeout - UNAVAILABLE.
http.StatusGatewayTimeout: codes.Unavailable,
}
)
// Records the states during HPACK decoding. Must be reset once the
// decoding of the entire headers are finished.
type decodeState struct {
encoding string
// statusGen caches the stream status received from the trailer the server
// sent. Client side only. Do not access directly. After all trailers are
// parsed, use the status method to retrieve the status.
statusGen *status.Status
// rawStatusCode and rawStatusMsg are set from the raw trailer fields and are not
// intended for direct access outside of parsing.
rawStatusCode *int
rawStatusMsg string
httpStatus *int
// Server side only fields.
timeoutSet bool
timeout time.Duration
method string
// key-value metadata map from the peer.
mdata map[string][]string
statsTags []byte
statsTrace []byte
contentSubtype string
// whether decoding on server side or not
serverSide bool
}
// isReservedHeader checks whether hdr belongs to HTTP2 headers
// reserved by gRPC protocol. Any other headers are classified as the
// user-specified metadata.
func isReservedHeader(hdr string) bool {
if hdr != "" && hdr[0] == ':' {
return true
}
switch hdr {
case "content-type",
"user-agent",
"grpc-message-type",
"grpc-encoding",
"grpc-message",
"grpc-status",
"grpc-timeout",
"grpc-status-details-bin",
// Intentionally exclude grpc-previous-rpc-attempts and
// grpc-retry-pushback-ms, which are "reserved", but their API
// intentionally works via metadata.
"te":
return true
default:
return false
}
}
// isWhitelistedHeader checks whether hdr should be propagated into metadata
// visible to users, even though it is classified as "reserved", above.
func isWhitelistedHeader(hdr string) bool {
switch hdr {
case ":authority", "user-agent":
return true
default:
return false
}
}
// contentSubtype returns the content-subtype for the given content-type. The
// given content-type must be a valid content-type that starts with
// "application/grpc". A content-subtype will follow "application/grpc" after a
// "+" or ";". See
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details.
//
// If contentType is not a valid content-type for gRPC, the boolean
// will be false, otherwise true. If content-type == "application/grpc",
// "application/grpc+", or "application/grpc;", the boolean will be true,
// but no content-subtype will be returned.
//
// contentType is assumed to be lowercase already.
func contentSubtype(contentType string) (string, bool) {
if contentType == baseContentType {
return "", true
}
if !strings.HasPrefix(contentType, baseContentType) {
return "", false
}
// guaranteed since != baseContentType and has baseContentType prefix
switch contentType[len(baseContentType)] {
case '+', ';':
// this will return true for "application/grpc+" or "application/grpc;"
// which the previous validContentType function tested to be valid, so we
// just say that no content-subtype is specified in this case
return contentType[len(baseContentType)+1:], true
default:
return "", false
}
}
// contentSubtype is assumed to be lowercase
func contentType(contentSubtype string) string {
if contentSubtype == "" {
return baseContentType
}
return baseContentType + "+" + contentSubtype
}
func (d *decodeState) status() *status.Status {
if d.statusGen == nil {
// No status-details were provided; generate status using code/msg.
d.statusGen = status.New(codes.Code(int32(*(d.rawStatusCode))), d.rawStatusMsg)
}
return d.statusGen
}
const binHdrSuffix = "-bin"
func encodeBinHeader(v []byte) string {
return base64.RawStdEncoding.EncodeToString(v)
}
func decodeBinHeader(v string) ([]byte, error) {
if len(v)%4 == 0 {
// Input was padded, or padding was not necessary.
return base64.StdEncoding.DecodeString(v)
}
return base64.RawStdEncoding.DecodeString(v)
}
func encodeMetadataHeader(k, v string) string {
if strings.HasSuffix(k, binHdrSuffix) {
return encodeBinHeader(([]byte)(v))
}
return v
}
func decodeMetadataHeader(k, v string) (string, error) {
if strings.HasSuffix(k, binHdrSuffix) {
b, err := decodeBinHeader(v)
return string(b), err
}
return v, nil
}
func (d *decodeState) decodeHeader(frame *http2.MetaHeadersFrame) error {
// frame.Truncated is set to true when framer detects that the current header
// list size hits MaxHeaderListSize limit.
if frame.Truncated {
return status.Error(codes.Internal, "peer header list size exceeded limit")
}
for _, hf := range frame.Fields {
if err := d.processHeaderField(hf); err != nil {
return err
}
}
if d.serverSide {
return nil
}
// If grpc status exists, no need to check further.
if d.rawStatusCode != nil || d.statusGen != nil {
return nil
}
// If grpc status doesn't exist and http status doesn't exist,
// then it's a malformed header.
if d.httpStatus == nil {
return status.Error(codes.Internal, "malformed header: doesn't contain status(gRPC or HTTP)")
}
if *(d.httpStatus) != http.StatusOK {
code, ok := httpStatusConvTab[*(d.httpStatus)]
if !ok {
code = codes.Unknown
}
return status.Error(code, http.StatusText(*(d.httpStatus)))
}
// gRPC status doesn't exist and http status is OK.
// Set rawStatusCode to be unknown and return nil error.
// So that, if the stream has ended this Unknown status
// will be propagated to the user.
// Otherwise, it will be ignored. In which case, status from
// a later trailer, that has StreamEnded flag set, is propagated.
code := int(codes.Unknown)
d.rawStatusCode = &code
return nil
}
func (d *decodeState) addMetadata(k, v string) {
if d.mdata == nil {
d.mdata = make(map[string][]string)
}
d.mdata[k] = append(d.mdata[k], v)
}
func (d *decodeState) processHeaderField(f hpack.HeaderField) error {
switch f.Name {
case "content-type":
contentSubtype, validContentType := contentSubtype(f.Value)
if !validContentType {
return status.Errorf(codes.Internal, "transport: received the unexpected content-type %q", f.Value)
}
d.contentSubtype = contentSubtype
// TODO: do we want to propagate the whole content-type in the metadata,
// or come up with a way to just propagate the content-subtype if it was set?
// ie {"content-type": "application/grpc+proto"} or {"content-subtype": "proto"}
// in the metadata?
d.addMetadata(f.Name, f.Value)
case "grpc-encoding":
d.encoding = f.Value
case "grpc-status":
code, err := strconv.Atoi(f.Value)
if err != nil {
return status.Errorf(codes.Internal, "transport: malformed grpc-status: %v", err)
}
d.rawStatusCode = &code
case "grpc-message":
d.rawStatusMsg = decodeGrpcMessage(f.Value)
case "grpc-status-details-bin":
v, err := decodeBinHeader(f.Value)
if err != nil {
return status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err)
}
s := &spb.Status{}
if err := proto.Unmarshal(v, s); err != nil {
return status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err)
}
d.statusGen = status.FromProto(s)
case "grpc-timeout":
d.timeoutSet = true
var err error
if d.timeout, err = decodeTimeout(f.Value); err != nil {
return status.Errorf(codes.Internal, "transport: malformed time-out: %v", err)
}
case ":path":
d.method = f.Value
case ":status":
code, err := strconv.Atoi(f.Value)
if err != nil {
return status.Errorf(codes.Internal, "transport: malformed http-status: %v", err)
}
d.httpStatus = &code
case "grpc-tags-bin":
v, err := decodeBinHeader(f.Value)
if err != nil {
return status.Errorf(codes.Internal, "transport: malformed grpc-tags-bin: %v", err)
}
d.statsTags = v
d.addMetadata(f.Name, string(v))
case "grpc-trace-bin":
v, err := decodeBinHeader(f.Value)
if err != nil {
return status.Errorf(codes.Internal, "transport: malformed grpc-trace-bin: %v", err)
}
d.statsTrace = v
d.addMetadata(f.Name, string(v))
default:
if isReservedHeader(f.Name) && !isWhitelistedHeader(f.Name) {
break
}
v, err := decodeMetadataHeader(f.Name, f.Value)
if err != nil {
errorf("Failed to decode metadata header (%q, %q): %v", f.Name, f.Value, err)
return nil
}
d.addMetadata(f.Name, v)
}
return nil
}
type timeoutUnit uint8
const (
hour timeoutUnit = 'H'
minute timeoutUnit = 'M'
second timeoutUnit = 'S'
millisecond timeoutUnit = 'm'
microsecond timeoutUnit = 'u'
nanosecond timeoutUnit = 'n'
)
func timeoutUnitToDuration(u timeoutUnit) (d time.Duration, ok bool) {
switch u {
case hour:
return time.Hour, true
case minute:
return time.Minute, true
case second:
return time.Second, true
case millisecond:
return time.Millisecond, true
case microsecond:
return time.Microsecond, true
case nanosecond:
return time.Nanosecond, true
default:
}
return
}
const maxTimeoutValue int64 = 100000000 - 1
// div does integer division and round-up the result. Note that this is
// equivalent to (d+r-1)/r but has less chance to overflow.
func div(d, r time.Duration) int64 {
if m := d % r; m > 0 {
return int64(d/r + 1)
}
return int64(d / r)
}
// TODO(zhaoq): It is the simplistic and not bandwidth efficient. Improve it.
func encodeTimeout(t time.Duration) string {
if t <= 0 {
return "0n"
}
if d := div(t, time.Nanosecond); d <= maxTimeoutValue {
return strconv.FormatInt(d, 10) + "n"
}
if d := div(t, time.Microsecond); d <= maxTimeoutValue {
return strconv.FormatInt(d, 10) + "u"
}
if d := div(t, time.Millisecond); d <= maxTimeoutValue {
return strconv.FormatInt(d, 10) + "m"
}
if d := div(t, time.Second); d <= maxTimeoutValue {
return strconv.FormatInt(d, 10) + "S"
}
if d := div(t, time.Minute); d <= maxTimeoutValue {
return strconv.FormatInt(d, 10) + "M"
}
// Note that maxTimeoutValue * time.Hour > MaxInt64.
return strconv.FormatInt(div(t, time.Hour), 10) + "H"
}
func decodeTimeout(s string) (time.Duration, error) {
size := len(s)
if size < 2 {
return 0, fmt.Errorf("transport: timeout string is too short: %q", s)
}
if size > 9 {
// Spec allows for 8 digits plus the unit.
return 0, fmt.Errorf("transport: timeout string is too long: %q", s)
}
unit := timeoutUnit(s[size-1])
d, ok := timeoutUnitToDuration(unit)
if !ok {
return 0, fmt.Errorf("transport: timeout unit is not recognized: %q", s)
}
t, err := strconv.ParseInt(s[:size-1], 10, 64)
if err != nil {
return 0, err
}
const maxHours = math.MaxInt64 / int64(time.Hour)
if d == time.Hour && t > maxHours {
// This timeout would overflow math.MaxInt64; clamp it.
return time.Duration(math.MaxInt64), nil
}
return d * time.Duration(t), nil
}
const (
spaceByte = ' '
tildeByte = '~'
percentByte = '%'
)
// encodeGrpcMessage is used to encode status code in header field
// "grpc-message". It does percent encoding and also replaces invalid utf-8
// characters with Unicode replacement character.
//
// It checks to see if each individual byte in msg is an allowable byte, and
// then either percent encoding or passing it through. When percent encoding,
// the byte is converted into hexadecimal notation with a '%' prepended.
func encodeGrpcMessage(msg string) string {
if msg == "" {
return ""
}
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if !(c >= spaceByte && c <= tildeByte && c != percentByte) {
return encodeGrpcMessageUnchecked(msg)
}
}
return msg
}
func encodeGrpcMessageUnchecked(msg string) string {
var buf bytes.Buffer
for len(msg) > 0 {
r, size := utf8.DecodeRuneInString(msg)
for _, b := range []byte(string(r)) {
if size > 1 {
// If size > 1, r is not ascii. Always do percent encoding.
buf.WriteString(fmt.Sprintf("%%%02X", b))
continue
}
// The for loop is necessary even if size == 1. r could be
// utf8.RuneError.
//
// fmt.Sprintf("%%%02X", utf8.RuneError) gives "%FFFD".
if b >= spaceByte && b <= tildeByte && b != percentByte {
buf.WriteByte(b)
} else {
buf.WriteString(fmt.Sprintf("%%%02X", b))
}
}
msg = msg[size:]
}
return buf.String()
}
// decodeGrpcMessage decodes the msg encoded by encodeGrpcMessage.
func decodeGrpcMessage(msg string) string {
if msg == "" {
return ""
}
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
if msg[i] == percentByte && i+2 < lenMsg {
return decodeGrpcMessageUnchecked(msg)
}
}
return msg
}
func decodeGrpcMessageUnchecked(msg string) string {
var buf bytes.Buffer
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if c == percentByte && i+2 < lenMsg {
parsed, err := strconv.ParseUint(msg[i+1:i+3], 16, 8)
if err != nil {
buf.WriteByte(c)
} else {
buf.WriteByte(byte(parsed))
i += 2
}
} else {
buf.WriteByte(c)
}
}
return buf.String()
}
type bufWriter struct {
buf []byte
offset int
batchSize int
conn net.Conn
err error
onFlush func()
}
func newBufWriter(conn net.Conn, batchSize int) *bufWriter {
return &bufWriter{
buf: make([]byte, batchSize*2),
batchSize: batchSize,
conn: conn,
}
}
func (w *bufWriter) Write(b []byte) (n int, err error) {
if w.err != nil {
return 0, w.err
}
if w.batchSize == 0 { // Buffer has been disabled.
return w.conn.Write(b)
}
for len(b) > 0 {
nn := copy(w.buf[w.offset:], b)
b = b[nn:]
w.offset += nn
n += nn
if w.offset >= w.batchSize {
err = w.Flush()
}
}
return n, err
}
func (w *bufWriter) Flush() error {
if w.err != nil {
return w.err
}
if w.offset == 0 {
return nil
}
if w.onFlush != nil {
w.onFlush()
}
_, w.err = w.conn.Write(w.buf[:w.offset])
w.offset = 0
return w.err
}
type framer struct {
writer *bufWriter
fr *http2.Framer
}
func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, maxHeaderListSize uint32) *framer {
if writeBufferSize < 0 {
writeBufferSize = 0
}
var r io.Reader = conn
if readBufferSize > 0 {
r = bufio.NewReaderSize(r, readBufferSize)
}
w := newBufWriter(conn, writeBufferSize)
f := &framer{
writer: w,
fr: http2.NewFramer(w, r),
}
// Opt-in to Frame reuse API on framer to reduce garbage.
// Frames aren't safe to read from after a subsequent call to ReadFrame.
f.fr.SetReuseFrames()
f.fr.MaxHeaderListSize = maxHeaderListSize
f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil)
return f
}

View File

@ -0,0 +1,237 @@
/*
*
* Copyright 2014 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package transport
import (
"fmt"
"reflect"
"testing"
"time"
)
func TestTimeoutEncode(t *testing.T) {
for _, test := range []struct {
in string
out string
}{
{"12345678ns", "12345678n"},
{"123456789ns", "123457u"},
{"12345678us", "12345678u"},
{"123456789us", "123457m"},
{"12345678ms", "12345678m"},
{"123456789ms", "123457S"},
{"12345678s", "12345678S"},
{"123456789s", "2057614M"},
{"12345678m", "12345678M"},
{"123456789m", "2057614H"},
} {
d, err := time.ParseDuration(test.in)
if err != nil {
t.Fatalf("failed to parse duration string %s: %v", test.in, err)
}
out := encodeTimeout(d)
if out != test.out {
t.Fatalf("timeoutEncode(%s) = %s, want %s", test.in, out, test.out)
}
}
}
func TestTimeoutDecode(t *testing.T) {
for _, test := range []struct {
// input
s string
// output
d time.Duration
err error
}{
{"1234S", time.Second * 1234, nil},
{"1234x", 0, fmt.Errorf("transport: timeout unit is not recognized: %q", "1234x")},
{"1", 0, fmt.Errorf("transport: timeout string is too short: %q", "1")},
{"", 0, fmt.Errorf("transport: timeout string is too short: %q", "")},
} {
d, err := decodeTimeout(test.s)
if d != test.d || fmt.Sprint(err) != fmt.Sprint(test.err) {
t.Fatalf("timeoutDecode(%q) = %d, %v, want %d, %v", test.s, int64(d), err, int64(test.d), test.err)
}
}
}
func TestContentSubtype(t *testing.T) {
tests := []struct {
contentType string
want string
wantValid bool
}{
{"application/grpc", "", true},
{"application/grpc+", "", true},
{"application/grpc+blah", "blah", true},
{"application/grpc;", "", true},
{"application/grpc;blah", "blah", true},
{"application/grpcd", "", false},
{"application/grpd", "", false},
{"application/grp", "", false},
}
for _, tt := range tests {
got, gotValid := contentSubtype(tt.contentType)
if got != tt.want || gotValid != tt.wantValid {
t.Errorf("contentSubtype(%q) = (%v, %v); want (%v, %v)", tt.contentType, got, gotValid, tt.want, tt.wantValid)
}
}
}
func TestEncodeGrpcMessage(t *testing.T) {
for _, tt := range []struct {
input string
expected string
}{
{"", ""},
{"Hello", "Hello"},
{"\u0000", "%00"},
{"%", "%25"},
{"系统", "%E7%B3%BB%E7%BB%9F"},
{string([]byte{0xff, 0xfe, 0xfd}), "%EF%BF%BD%EF%BF%BD%EF%BF%BD"},
} {
actual := encodeGrpcMessage(tt.input)
if tt.expected != actual {
t.Errorf("encodeGrpcMessage(%q) = %q, want %q", tt.input, actual, tt.expected)
}
}
// make sure that all the visible ASCII chars except '%' are not percent encoded.
for i := ' '; i <= '~' && i != '%'; i++ {
output := encodeGrpcMessage(string(i))
if output != string(i) {
t.Errorf("encodeGrpcMessage(%v) = %v, want %v", string(i), output, string(i))
}
}
// make sure that all the invisible ASCII chars and '%' are percent encoded.
for i := rune(0); i == '%' || (i >= rune(0) && i < ' ') || (i > '~' && i <= rune(127)); i++ {
output := encodeGrpcMessage(string(i))
expected := fmt.Sprintf("%%%02X", i)
if output != expected {
t.Errorf("encodeGrpcMessage(%v) = %v, want %v", string(i), output, expected)
}
}
}
func TestDecodeGrpcMessage(t *testing.T) {
for _, tt := range []struct {
input string
expected string
}{
{"", ""},
{"Hello", "Hello"},
{"H%61o", "Hao"},
{"H%6", "H%6"},
{"%G0", "%G0"},
{"%E7%B3%BB%E7%BB%9F", "系统"},
{"%EF%BF%BD", "<22>"},
} {
actual := decodeGrpcMessage(tt.input)
if tt.expected != actual {
t.Errorf("decodeGrpcMessage(%q) = %q, want %q", tt.input, actual, tt.expected)
}
}
// make sure that all the visible ASCII chars except '%' are not percent decoded.
for i := ' '; i <= '~' && i != '%'; i++ {
output := decodeGrpcMessage(string(i))
if output != string(i) {
t.Errorf("decodeGrpcMessage(%v) = %v, want %v", string(i), output, string(i))
}
}
// make sure that all the invisible ASCII chars and '%' are percent decoded.
for i := rune(0); i == '%' || (i >= rune(0) && i < ' ') || (i > '~' && i <= rune(127)); i++ {
output := decodeGrpcMessage(fmt.Sprintf("%%%02X", i))
if output != string(i) {
t.Errorf("decodeGrpcMessage(%v) = %v, want %v", fmt.Sprintf("%%%02X", i), output, string(i))
}
}
}
// Decode an encoded string should get the same thing back, except for invalid
// utf8 chars.
func TestDecodeEncodeGrpcMessage(t *testing.T) {
testCases := []struct {
orig string
want string
}{
{"", ""},
{"hello", "hello"},
{"h%6", "h%6"},
{"%G0", "%G0"},
{"系统", "系统"},
{"Hello, 世界", "Hello, 世界"},
{string([]byte{0xff, 0xfe, 0xfd}), "<22><><EFBFBD>"},
{string([]byte{0xff}) + "Hello" + string([]byte{0xfe}) + "世界" + string([]byte{0xfd}), "<22>Hello<6C>世界<E4B896>"},
}
for _, tC := range testCases {
got := decodeGrpcMessage(encodeGrpcMessage(tC.orig))
if got != tC.want {
t.Errorf("decodeGrpcMessage(encodeGrpcMessage(%q)) = %q, want %q", tC.orig, got, tC.want)
}
}
}
const binaryValue = string(128)
func TestEncodeMetadataHeader(t *testing.T) {
for _, test := range []struct {
// input
kin string
vin string
// output
vout string
}{
{"key", "abc", "abc"},
{"KEY", "abc", "abc"},
{"key-bin", "abc", "YWJj"},
{"key-bin", binaryValue, "woA"},
} {
v := encodeMetadataHeader(test.kin, test.vin)
if !reflect.DeepEqual(v, test.vout) {
t.Fatalf("encodeMetadataHeader(%q, %q) = %q, want %q", test.kin, test.vin, v, test.vout)
}
}
}
func TestDecodeMetadataHeader(t *testing.T) {
for _, test := range []struct {
// input
kin string
vin string
// output
vout string
err error
}{
{"a", "abc", "abc", nil},
{"key-bin", "Zm9vAGJhcg==", "foo\x00bar", nil},
{"key-bin", "Zm9vAGJhcg", "foo\x00bar", nil},
{"key-bin", "woA=", binaryValue, nil},
{"a", "abc,efg", "abc,efg", nil},
} {
v, err := decodeMetadataHeader(test.kin, test.vin)
if !reflect.DeepEqual(v, test.vout) || !reflect.DeepEqual(err, test.err) {
t.Fatalf("decodeMetadataHeader(%q, %q) = %q, %v, want %q, %v", test.kin, test.vin, v, err, test.vout, test.err)
}
}
}

View File

@ -0,0 +1,44 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// This file contains wrappers for grpclog functions.
// The transport package only logs to verbose level 2 by default.
package transport
import "google.golang.org/grpc/grpclog"
const logLevel = 2
func infof(format string, args ...interface{}) {
if grpclog.V(logLevel) {
grpclog.Infof(format, args...)
}
}
func warningf(format string, args ...interface{}) {
if grpclog.V(logLevel) {
grpclog.Warningf(format, args...)
}
}
func errorf(format string, args ...interface{}) {
if grpclog.V(logLevel) {
grpclog.Errorf(format, args...)
}
}

View File

@ -0,0 +1,712 @@
/*
*
* Copyright 2014 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package transport defines and implements message oriented communication
// channel to complete various transactions (e.g., an RPC). It is meant for
// grpc-internal usage and is not intended to be imported directly by users.
package transport
import (
"errors"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
"google.golang.org/grpc/tap"
)
// recvMsg represents the received msg from the transport. All transport
// protocol specific info has been removed.
type recvMsg struct {
data []byte
// nil: received some data
// io.EOF: stream is completed. data is nil.
// other non-nil error: transport failure. data is nil.
err error
}
// recvBuffer is an unbounded channel of recvMsg structs.
// Note recvBuffer differs from controlBuffer only in that recvBuffer
// holds a channel of only recvMsg structs instead of objects implementing "item" interface.
// recvBuffer is written to much more often than
// controlBuffer and using strict recvMsg structs helps avoid allocation in "recvBuffer.put"
type recvBuffer struct {
c chan recvMsg
mu sync.Mutex
backlog []recvMsg
err error
}
func newRecvBuffer() *recvBuffer {
b := &recvBuffer{
c: make(chan recvMsg, 1),
}
return b
}
func (b *recvBuffer) put(r recvMsg) {
b.mu.Lock()
if b.err != nil {
b.mu.Unlock()
// An error had occurred earlier, don't accept more
// data or errors.
return
}
b.err = r.err
if len(b.backlog) == 0 {
select {
case b.c <- r:
b.mu.Unlock()
return
default:
}
}
b.backlog = append(b.backlog, r)
b.mu.Unlock()
}
func (b *recvBuffer) load() {
b.mu.Lock()
if len(b.backlog) > 0 {
select {
case b.c <- b.backlog[0]:
b.backlog[0] = recvMsg{}
b.backlog = b.backlog[1:]
default:
}
}
b.mu.Unlock()
}
// get returns the channel that receives a recvMsg in the buffer.
//
// Upon receipt of a recvMsg, the caller should call load to send another
// recvMsg onto the channel if there is any.
func (b *recvBuffer) get() <-chan recvMsg {
return b.c
}
//
// recvBufferReader implements io.Reader interface to read the data from
// recvBuffer.
type recvBufferReader struct {
ctx context.Context
ctxDone <-chan struct{} // cache of ctx.Done() (for performance).
recv *recvBuffer
last []byte // Stores the remaining data in the previous calls.
err error
}
// Read reads the next len(p) bytes from last. If last is drained, it tries to
// read additional data from recv. It blocks if there no additional data available
// in recv. If Read returns any non-nil error, it will continue to return that error.
func (r *recvBufferReader) Read(p []byte) (n int, err error) {
if r.err != nil {
return 0, r.err
}
n, r.err = r.read(p)
return n, r.err
}
func (r *recvBufferReader) read(p []byte) (n int, err error) {
if r.last != nil && len(r.last) > 0 {
// Read remaining data left in last call.
copied := copy(p, r.last)
r.last = r.last[copied:]
return copied, nil
}
select {
case <-r.ctxDone:
return 0, ContextErr(r.ctx.Err())
case m := <-r.recv.get():
r.recv.load()
if m.err != nil {
return 0, m.err
}
copied := copy(p, m.data)
r.last = m.data[copied:]
return copied, nil
}
}
type streamState uint32
const (
streamActive streamState = iota
streamWriteDone // EndStream sent
streamReadDone // EndStream received
streamDone // the entire stream is finished.
)
// Stream represents an RPC in the transport layer.
type Stream struct {
id uint32
st ServerTransport // nil for client side Stream
ctx context.Context // the associated context of the stream
cancel context.CancelFunc // always nil for client side Stream
done chan struct{} // closed at the end of stream to unblock writers. On the client side.
ctxDone <-chan struct{} // same as done chan but for server side. Cache of ctx.Done() (for performance)
method string // the associated RPC method of the stream
recvCompress string
sendCompress string
buf *recvBuffer
trReader io.Reader
fc *inFlow
wq *writeQuota
// Callback to state application's intentions to read data. This
// is used to adjust flow control, if needed.
requestRead func(int)
headerChan chan struct{} // closed to indicate the end of header metadata.
headerDone uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times.
// hdrMu protects header and trailer metadata on the server-side.
hdrMu sync.Mutex
header metadata.MD // the received header metadata.
trailer metadata.MD // the key-value map of trailer metadata.
noHeaders bool // set if the client never received headers (set only after the stream is done).
// On the server-side, headerSent is atomically set to 1 when the headers are sent out.
headerSent uint32
state streamState
// On client-side it is the status error received from the server.
// On server-side it is unused.
status *status.Status
bytesReceived uint32 // indicates whether any bytes have been received on this stream
unprocessed uint32 // set if the server sends a refused stream or GOAWAY including this stream
// contentSubtype is the content-subtype for requests.
// this must be lowercase or the behavior is undefined.
contentSubtype string
}
// isHeaderSent is only valid on the server-side.
func (s *Stream) isHeaderSent() bool {
return atomic.LoadUint32(&s.headerSent) == 1
}
// updateHeaderSent updates headerSent and returns true
// if it was alreay set. It is valid only on server-side.
func (s *Stream) updateHeaderSent() bool {
return atomic.SwapUint32(&s.headerSent, 1) == 1
}
func (s *Stream) swapState(st streamState) streamState {
return streamState(atomic.SwapUint32((*uint32)(&s.state), uint32(st)))
}
func (s *Stream) compareAndSwapState(oldState, newState streamState) bool {
return atomic.CompareAndSwapUint32((*uint32)(&s.state), uint32(oldState), uint32(newState))
}
func (s *Stream) getState() streamState {
return streamState(atomic.LoadUint32((*uint32)(&s.state)))
}
func (s *Stream) waitOnHeader() error {
if s.headerChan == nil {
// On the server headerChan is always nil since a stream originates
// only after having received headers.
return nil
}
select {
case <-s.ctx.Done():
return ContextErr(s.ctx.Err())
case <-s.headerChan:
return nil
}
}
// RecvCompress returns the compression algorithm applied to the inbound
// message. It is empty string if there is no compression applied.
func (s *Stream) RecvCompress() string {
if err := s.waitOnHeader(); err != nil {
return ""
}
return s.recvCompress
}
// SetSendCompress sets the compression algorithm to the stream.
func (s *Stream) SetSendCompress(str string) {
s.sendCompress = str
}
// Done returns a channel which is closed when it receives the final status
// from the server.
func (s *Stream) Done() <-chan struct{} {
return s.done
}
// Header acquires the key-value pairs of header metadata once it
// is available. It blocks until i) the metadata is ready or ii) there is no
// header metadata or iii) the stream is canceled/expired.
func (s *Stream) Header() (metadata.MD, error) {
err := s.waitOnHeader()
// Even if the stream is closed, header is returned if available.
select {
case <-s.headerChan:
if s.header == nil {
return nil, nil
}
return s.header.Copy(), nil
default:
}
return nil, err
}
// TrailersOnly blocks until a header or trailers-only frame is received and
// then returns true if the stream was trailers-only. If the stream ends
// before headers are received, returns true, nil. If a context error happens
// first, returns it as a status error. Client-side only.
func (s *Stream) TrailersOnly() (bool, error) {
err := s.waitOnHeader()
if err != nil {
return false, err
}
// if !headerDone, some other connection error occurred.
return s.noHeaders && atomic.LoadUint32(&s.headerDone) == 1, nil
}
// Trailer returns the cached trailer metedata. Note that if it is not called
// after the entire stream is done, it could return an empty MD. Client
// side only.
// It can be safely read only after stream has ended that is either read
// or write have returned io.EOF.
func (s *Stream) Trailer() metadata.MD {
c := s.trailer.Copy()
return c
}
// ContentSubtype returns the content-subtype for a request. For example, a
// content-subtype of "proto" will result in a content-type of
// "application/grpc+proto". This will always be lowercase. See
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details.
func (s *Stream) ContentSubtype() string {
return s.contentSubtype
}
// Context returns the context of the stream.
func (s *Stream) Context() context.Context {
return s.ctx
}
// Method returns the method for the stream.
func (s *Stream) Method() string {
return s.method
}
// Status returns the status received from the server.
// Status can be read safely only after the stream has ended,
// that is, after Done() is closed.
func (s *Stream) Status() *status.Status {
return s.status
}
// SetHeader sets the header metadata. This can be called multiple times.
// Server side only.
// This should not be called in parallel to other data writes.
func (s *Stream) SetHeader(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
if s.isHeaderSent() || s.getState() == streamDone {
return ErrIllegalHeaderWrite
}
s.hdrMu.Lock()
s.header = metadata.Join(s.header, md)
s.hdrMu.Unlock()
return nil
}
// SendHeader sends the given header metadata. The given metadata is
// combined with any metadata set by previous calls to SetHeader and
// then written to the transport stream.
func (s *Stream) SendHeader(md metadata.MD) error {
return s.st.WriteHeader(s, md)
}
// SetTrailer sets the trailer metadata which will be sent with the RPC status
// by the server. This can be called multiple times. Server side only.
// This should not be called parallel to other data writes.
func (s *Stream) SetTrailer(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
if s.getState() == streamDone {
return ErrIllegalHeaderWrite
}
s.hdrMu.Lock()
s.trailer = metadata.Join(s.trailer, md)
s.hdrMu.Unlock()
return nil
}
func (s *Stream) write(m recvMsg) {
s.buf.put(m)
}
// Read reads all p bytes from the wire for this stream.
func (s *Stream) Read(p []byte) (n int, err error) {
// Don't request a read if there was an error earlier
if er := s.trReader.(*transportReader).er; er != nil {
return 0, er
}
s.requestRead(len(p))
return io.ReadFull(s.trReader, p)
}
// tranportReader reads all the data available for this Stream from the transport and
// passes them into the decoder, which converts them into a gRPC message stream.
// The error is io.EOF when the stream is done or another non-nil error if
// the stream broke.
type transportReader struct {
reader io.Reader
// The handler to control the window update procedure for both this
// particular stream and the associated transport.
windowHandler func(int)
er error
}
func (t *transportReader) Read(p []byte) (n int, err error) {
n, err = t.reader.Read(p)
if err != nil {
t.er = err
return
}
t.windowHandler(n)
return
}
// BytesReceived indicates whether any bytes have been received on this stream.
func (s *Stream) BytesReceived() bool {
return atomic.LoadUint32(&s.bytesReceived) == 1
}
// Unprocessed indicates whether the server did not process this stream --
// i.e. it sent a refused stream or GOAWAY including this stream ID.
func (s *Stream) Unprocessed() bool {
return atomic.LoadUint32(&s.unprocessed) == 1
}
// GoString is implemented by Stream so context.String() won't
// race when printing %#v.
func (s *Stream) GoString() string {
return fmt.Sprintf("<stream: %p, %v>", s, s.method)
}
// state of transport
type transportState int
const (
reachable transportState = iota
closing
draining
)
// ServerConfig consists of all the configurations to establish a server transport.
type ServerConfig struct {
MaxStreams uint32
AuthInfo credentials.AuthInfo
InTapHandle tap.ServerInHandle
StatsHandler stats.Handler
KeepaliveParams keepalive.ServerParameters
KeepalivePolicy keepalive.EnforcementPolicy
InitialWindowSize int32
InitialConnWindowSize int32
WriteBufferSize int
ReadBufferSize int
ChannelzParentID int64
MaxHeaderListSize *uint32
}
// NewServerTransport creates a ServerTransport with conn or non-nil error
// if it fails.
func NewServerTransport(protocol string, conn net.Conn, config *ServerConfig) (ServerTransport, error) {
return newHTTP2Server(conn, config)
}
// ConnectOptions covers all relevant options for communicating with the server.
type ConnectOptions struct {
// UserAgent is the application user agent.
UserAgent string
// Dialer specifies how to dial a network address.
Dialer func(context.Context, string) (net.Conn, error)
// FailOnNonTempDialError specifies if gRPC fails on non-temporary dial errors.
FailOnNonTempDialError bool
// PerRPCCredentials stores the PerRPCCredentials required to issue RPCs.
PerRPCCredentials []credentials.PerRPCCredentials
// TransportCredentials stores the Authenticator required to setup a client
// connection. Only one of TransportCredentials and CredsBundle is non-nil.
TransportCredentials credentials.TransportCredentials
// CredsBundle is the credentials bundle to be used. Only one of
// TransportCredentials and CredsBundle is non-nil.
CredsBundle credentials.Bundle
// KeepaliveParams stores the keepalive parameters.
KeepaliveParams keepalive.ClientParameters
// StatsHandler stores the handler for stats.
StatsHandler stats.Handler
// InitialWindowSize sets the initial window size for a stream.
InitialWindowSize int32
// InitialConnWindowSize sets the initial window size for a connection.
InitialConnWindowSize int32
// WriteBufferSize sets the size of write buffer which in turn determines how much data can be batched before it's written on the wire.
WriteBufferSize int
// ReadBufferSize sets the size of read buffer, which in turn determines how much data can be read at most for one read syscall.
ReadBufferSize int
// ChannelzParentID sets the addrConn id which initiate the creation of this client transport.
ChannelzParentID int64
// MaxHeaderListSize sets the max (uncompressed) size of header list that is prepared to be received.
MaxHeaderListSize *uint32
}
// TargetInfo contains the information of the target such as network address and metadata.
type TargetInfo struct {
Addr string
Metadata interface{}
Authority string
}
// NewClientTransport establishes the transport with the required ConnectOptions
// and returns it to the caller.
func NewClientTransport(connectCtx, ctx context.Context, target TargetInfo, opts ConnectOptions, onSuccess func(), onGoAway func(GoAwayReason), onClose func()) (ClientTransport, error) {
return newHTTP2Client(connectCtx, ctx, target, opts, onSuccess, onGoAway, onClose)
}
// Options provides additional hints and information for message
// transmission.
type Options struct {
// Last indicates whether this write is the last piece for
// this stream.
Last bool
}
// CallHdr carries the information of a particular RPC.
type CallHdr struct {
// Host specifies the peer's host.
Host string
// Method specifies the operation to perform.
Method string
// SendCompress specifies the compression algorithm applied on
// outbound message.
SendCompress string
// Creds specifies credentials.PerRPCCredentials for a call.
Creds credentials.PerRPCCredentials
// ContentSubtype specifies the content-subtype for a request. For example, a
// content-subtype of "proto" will result in a content-type of
// "application/grpc+proto". The value of ContentSubtype must be all
// lowercase, otherwise the behavior is undefined. See
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests
// for more details.
ContentSubtype string
PreviousAttempts int // value of grpc-previous-rpc-attempts header to set
}
// ClientTransport is the common interface for all gRPC client-side transport
// implementations.
type ClientTransport interface {
// Close tears down this transport. Once it returns, the transport
// should not be accessed any more. The caller must make sure this
// is called only once.
Close() error
// GracefulClose starts to tear down the transport. It stops accepting
// new RPCs and wait the completion of the pending RPCs.
GracefulClose() error
// Write sends the data for the given stream. A nil stream indicates
// the write is to be performed on the transport as a whole.
Write(s *Stream, hdr []byte, data []byte, opts *Options) error
// NewStream creates a Stream for an RPC.
NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error)
// CloseStream clears the footprint of a stream when the stream is
// not needed any more. The err indicates the error incurred when
// CloseStream is called. Must be called when a stream is finished
// unless the associated transport is closing.
CloseStream(stream *Stream, err error)
// Error returns a channel that is closed when some I/O error
// happens. Typically the caller should have a goroutine to monitor
// this in order to take action (e.g., close the current transport
// and create a new one) in error case. It should not return nil
// once the transport is initiated.
Error() <-chan struct{}
// GoAway returns a channel that is closed when ClientTransport
// receives the draining signal from the server (e.g., GOAWAY frame in
// HTTP/2).
GoAway() <-chan struct{}
// GetGoAwayReason returns the reason why GoAway frame was received.
GetGoAwayReason() GoAwayReason
// IncrMsgSent increments the number of message sent through this transport.
IncrMsgSent()
// IncrMsgRecv increments the number of message received through this transport.
IncrMsgRecv()
}
// ServerTransport is the common interface for all gRPC server-side transport
// implementations.
//
// Methods may be called concurrently from multiple goroutines, but
// Write methods for a given Stream will be called serially.
type ServerTransport interface {
// HandleStreams receives incoming streams using the given handler.
HandleStreams(func(*Stream), func(context.Context, string) context.Context)
// WriteHeader sends the header metadata for the given stream.
// WriteHeader may not be called on all streams.
WriteHeader(s *Stream, md metadata.MD) error
// Write sends the data for the given stream.
// Write may not be called on all streams.
Write(s *Stream, hdr []byte, data []byte, opts *Options) error
// WriteStatus sends the status of a stream to the client. WriteStatus is
// the final call made on a stream and always occurs.
WriteStatus(s *Stream, st *status.Status) error
// Close tears down the transport. Once it is called, the transport
// should not be accessed any more. All the pending streams and their
// handlers will be terminated asynchronously.
Close() error
// RemoteAddr returns the remote network address.
RemoteAddr() net.Addr
// Drain notifies the client this ServerTransport stops accepting new RPCs.
Drain()
// IncrMsgSent increments the number of message sent through this transport.
IncrMsgSent()
// IncrMsgRecv increments the number of message received through this transport.
IncrMsgRecv()
}
// connectionErrorf creates an ConnectionError with the specified error description.
func connectionErrorf(temp bool, e error, format string, a ...interface{}) ConnectionError {
return ConnectionError{
Desc: fmt.Sprintf(format, a...),
temp: temp,
err: e,
}
}
// ConnectionError is an error that results in the termination of the
// entire connection and the retry of all the active streams.
type ConnectionError struct {
Desc string
temp bool
err error
}
func (e ConnectionError) Error() string {
return fmt.Sprintf("connection error: desc = %q", e.Desc)
}
// Temporary indicates if this connection error is temporary or fatal.
func (e ConnectionError) Temporary() bool {
return e.temp
}
// Origin returns the original error of this connection error.
func (e ConnectionError) Origin() error {
// Never return nil error here.
// If the original error is nil, return itself.
if e.err == nil {
return e
}
return e.err
}
var (
// ErrConnClosing indicates that the transport is closing.
ErrConnClosing = connectionErrorf(true, nil, "transport is closing")
// errStreamDrain indicates that the stream is rejected because the
// connection is draining. This could be caused by goaway or balancer
// removing the address.
errStreamDrain = status.Error(codes.Unavailable, "the connection is draining")
// errStreamDone is returned from write at the client side to indiacte application
// layer of an error.
errStreamDone = errors.New("the stream is done")
// StatusGoAway indicates that the server sent a GOAWAY that included this
// stream's ID in unprocessed RPCs.
statusGoAway = status.New(codes.Unavailable, "the stream is rejected because server is draining the connection")
)
// GoAwayReason contains the reason for the GoAway frame received.
type GoAwayReason uint8
const (
// GoAwayInvalid indicates that no GoAway frame is received.
GoAwayInvalid GoAwayReason = 0
// GoAwayNoReason is the default value when GoAway frame is received.
GoAwayNoReason GoAwayReason = 1
// GoAwayTooManyPings indicates that a GoAway frame with
// ErrCodeEnhanceYourCalm was received and that the debug data said
// "too_many_pings".
GoAwayTooManyPings GoAwayReason = 2
)
// channelzData is used to store channelz related data for http2Client and http2Server.
// These fields cannot be embedded in the original structs (e.g. http2Client), since to do atomic
// operation on int64 variable on 32-bit machine, user is responsible to enforce memory alignment.
// Here, by grouping those int64 fields inside a struct, we are enforcing the alignment.
type channelzData struct {
kpCount int64
// The number of streams that have started, including already finished ones.
streamsStarted int64
// Client side: The number of streams that have ended successfully by receiving
// EoS bit set frame from server.
// Server side: The number of streams that have ended successfully by sending
// frame with EoS bit set.
streamsSucceeded int64
streamsFailed int64
// lastStreamCreatedTime stores the timestamp that the last stream gets created. It is of int64 type
// instead of time.Time since it's more costly to atomically update time.Time variable than int64
// variable. The same goes for lastMsgSentTime and lastMsgRecvTime.
lastStreamCreatedTime int64
msgSent int64
msgRecv int64
lastMsgSentTime int64
lastMsgRecvTime int64
}

File diff suppressed because it is too large Load Diff