vendor updates

This commit is contained in:
Serguei Bezverkhi
2018-03-06 17:33:18 -05:00
parent 4b3ebc171b
commit e9033989a0
5854 changed files with 248382 additions and 119809 deletions

View File

146
vendor/google.golang.org/grpc/Documentation/encoding.md generated vendored Normal file
View File

@ -0,0 +1,146 @@
# Encoding
The gRPC API for sending and receiving is based upon *messages*. However,
messages cannot be transmitted directly over a network; they must first be
converted into *bytes*. This document describes how gRPC-Go converts messages
into bytes and vice-versa for the purposes of network transmission.
## Codecs (Serialization and Deserialization)
A `Codec` contains code to serialize a message into a byte slice (`Marshal`) and
deserialize a byte slice back into a message (`Unmarshal`). `Codec`s are
registered by name into a global registry maintained in the `encoding` package.
### Implementing a `Codec`
A typical `Codec` will be implemented in its own package with an `init` function
that registers itself, and is imported anonymously. For example:
```go
package proto
import "google.golang.org/grpc/encoding"
func init() {
encoding.RegisterCodec(protoCodec{})
}
// ... implementation of protoCodec ...
```
For an example, gRPC's implementation of the `proto` codec can be found in
[`encoding/proto`](https://godoc.org/google.golang.org/grpc/encoding/proto).
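As an illustration, a complete (if minimal) custom `Codec` might look like the
following sketch. It is hypothetical: the `json` name and package are
assumptions made for this example, not part of gRPC:
```go
// Package json is a hypothetical Codec that encodes messages as JSON.
package json

import (
	"encoding/json"

	"google.golang.org/grpc/encoding"
)

// Name is the content-subtype this codec registers under (an assumption
// made for this example).
const Name = "json"

func init() {
	encoding.RegisterCodec(codec{})
}

type codec struct{}

// Marshal returns the wire format of v.
func (codec) Marshal(v interface{}) ([]byte, error) { return json.Marshal(v) }

// Unmarshal parses the wire format into v.
func (codec) Unmarshal(data []byte, v interface{}) error { return json.Unmarshal(data, v) }

// Name returns the codec's registered name; the result must be static.
func (codec) Name() string { return Name }
```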
### Using a `Codec`
By default, gRPC registers and uses the "proto" codec, so it is not necessary to
register one in your own code to send and receive proto messages. To use another
`Codec` from a client or server:
```go
package myclient
import _ "path/to/another/codec"
```
`Codec`s, by definition, must be symmetric, so the same desired `Codec` should
be registered in both client and server binaries.
On the client-side, to specify a `Codec` to use for message transmission, the
`CallOption` `CallContentSubtype` should be used as follows:
```go
response, err := myclient.MyCall(ctx, request, grpc.CallContentSubtype("mycodec"))
```
As a reminder, all `CallOption`s may be converted into `DialOption`s that become
the default for all RPCs sent through a client using `grpc.WithDefaultCallOptions`:
```go
myclient, err := grpc.Dial(target, grpc.WithDefaultCallOptions(grpc.CallContentSubtype("mycodec")))
```
When specified in either of these ways, messages will be encoded using this
codec and sent along with headers indicating the codec (`content-type` set to
`application/grpc+<codec name>`).
On the server-side, using a `Codec` is as simple as registering it into the
global registry (i.e. `import`ing it). If a message is encoded with the content
sub-type supported by a registered `Codec`, it will be used automatically for
decoding the request and encoding the response. Otherwise, for
backward-compatibility reasons, gRPC will attempt to use the "proto" codec. In
an upcoming change (tracked in [this
issue](https://github.com/grpc/grpc-go/issues/1824)), such requests will be
rejected with status code `Unimplemented` instead.
## Compressors (Compression and Decompression)
Sometimes, the resulting serialization of a message is not space-efficient, and
it may be beneficial to compress this byte stream before transmitting it over
the network. To facilitate this operation, gRPC supports a mechanism for
performing compression and decompression.
A `Compressor` contains code to compress and decompress by wrapping `io.Writer`s
and `io.Reader`s, respectively. (The forms of `Compress` and `Decompress` were
chosen to most closely match Go's standard package
[implementations](https://golang.org/pkg/compress/) of compressors.) Like
`Codec`s, `Compressor`s are registered by name into a global registry maintained
in the `encoding` package.
### Implementing a `Compressor`
A typical `Compressor` will be implemented in its own package with an `init`
function that registers itself, and is imported anonymously. For example:
```go
package gzip
import "google.golang.org/grpc/encoding"
func init() {
encoding.RegisterCompressor(compressor{})
}
// ... implementation of compressor ...
```
An implementation of a `gzip` compressor can be found in
[`encoding/gzip`](https://godoc.org/google.golang.org/grpc/encoding/gzip).
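Fleshed out, a minimal gzip-based `Compressor` might look like the sketch
below. It is simplified for illustration; the real `encoding/gzip` package
also pools its writers and readers:
```go
// Package gzip sketches a Compressor backed by compress/gzip.
package gzip

import (
	"compress/gzip"
	"io"

	"google.golang.org/grpc/encoding"
)

func init() {
	encoding.RegisterCompressor(compressor{})
}

type compressor struct{}

// Compress wraps w; data written to the returned WriteCloser is
// gzip-compressed into w.
func (compressor) Compress(w io.Writer) (io.WriteCloser, error) {
	return gzip.NewWriter(w), nil
}

// Decompress wraps r; reads from the returned Reader yield the
// decompressed data.
func (compressor) Decompress(r io.Reader) (io.Reader, error) {
	return gzip.NewReader(r)
}

// Name is used to set the content coding header.
func (compressor) Name() string { return "gzip" }
```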
### Using a `Compressor`
By default, gRPC does not register or use any compressors. To use a
`Compressor` from a client or server:
```go
package myclient
import _ "google.golang.org/grpc/encoding/gzip"
```
`Compressor`s, by definition, must be symmetric, so the same desired
`Compressor` should be registered in both client and server binaries.
On the client-side, to specify a `Compressor` to use for message transmission,
the `CallOption` `UseCompressor` should be used as follows:
```go
response, err := myclient.MyCall(ctx, request, grpc.UseCompressor("gzip"))
```
As a reminder, all `CallOption`s may be converted into `DialOption`s that become
the default for all RPCs sent through a client using `grpc.WithDefaultCallOptions`:
```go
myclient, err := grpc.Dial(target, grpc.WithDefaultCallOptions(grpc.UseCompressor("gzip")))
```
When specified in either of these ways, messages will be compressed using this
compressor and sent along with headers indicating the compressor
(`content-coding` set to `<compressor name>`).
On the server-side, using a `Compressor` is as simple as registering it into the
global registry (i.e. `import`ing it). If a message is compressed with the
content coding supported by a registered `Compressor`, it will be used
automatically for decompressing the request and compressing the response.
Otherwise, the request will be rejected with status code `Unimplemented`.

View File

@ -82,13 +82,16 @@ func (s *server) SomeRPC(ctx context.Context, in *pb.SomeRequest) (*pb.SomeRespo
### Sending metadata
To send metadata to server, the client can wrap the metadata into a context using `NewOutgoingContext`, and make the RPC with this context:
There are two ways to send metadata to the server. The recommended way is to append kv pairs to the context using
`AppendToOutgoingContext`. This can be used with or without existing metadata on the context. When there is no prior
metadata, metadata is added; when metadata already exists on the context, kv pairs are merged in.
```go
md := metadata.Pairs("key", "val")
// create a new context with some metadata
ctx := metadata.AppendToOutgoingContext(ctx, "k1", "v1", "k1", "v2", "k2", "v3")
// create a new context with this metadata
ctx := metadata.NewOutgoingContext(context.Background(), md)
// later, add some more metadata to the context (e.g. in an interceptor)
ctx := metadata.AppendToOutgoingContext(ctx, "k3", "v4")
// make unary RPC
response, err := client.SomeRPC(ctx, someRequest)
@ -97,7 +100,27 @@ response, err := client.SomeRPC(ctx, someRequest)
stream, err := client.SomeStreamingRPC(ctx)
```
To read this back from the context on the client (e.g. in an interceptor) before the RPC is sent, use `FromOutgoingContext`.
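For example, a unary client interceptor could inspect and extend the pending
metadata as in this sketch (the interceptor name is hypothetical, and the
usual `grpc`, `metadata`, and `log` imports are assumed):
```go
// mdInterceptor is a hypothetical unary client interceptor that logs the
// outgoing metadata and appends one more pair before invoking the RPC.
func mdInterceptor(ctx context.Context, method string, req, reply interface{},
	cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
	if md, ok := metadata.FromOutgoingContext(ctx); ok {
		log.Printf("outgoing metadata for %s: %v", method, md)
	}
	ctx = metadata.AppendToOutgoingContext(ctx, "k3", "v4")
	return invoker(ctx, method, req, reply, cc, opts...)
}
```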
Alternatively, metadata may be attached to the context using `NewOutgoingContext`. However, this
replaces any existing metadata in the context, so care must be taken to preserve the existing
metadata if desired. This is slower than using `AppendToOutgoingContext`. An example of this
is below:
```go
// create a new context with some metadata
md := metadata.Pairs("k1", "v1", "k1", "v2", "k2", "v3")
ctx := metadata.NewOutgoingContext(context.Background(), md)
// later, add some more metadata to the context (e.g. in an interceptor)
md, _ := metadata.FromOutgoingContext(ctx)
newMD := metadata.Pairs("k3", "v3")
ctx = metadata.NewOutgoingContext(ctx, metadata.Join(md, newMD))
// make unary RPC
response, err := client.SomeRPC(ctx, someRequest)
// or make streaming RPC
stream, err := client.SomeStreamingRPC(ctx)
```
### Receiving metadata

View File

@ -277,55 +277,82 @@ func BenchmarkAtomicTimePointerStore(b *testing.B) {
b.StopTimer()
}
func BenchmarkValueStoreWithContention(b *testing.B) {
func BenchmarkStoreContentionWithAtomic(b *testing.B) {
t := 123
for _, n := range []int{10, 100, 1000, 10000, 100000} {
b.Run(fmt.Sprintf("Atomic/%v", n), func(b *testing.B) {
var wg sync.WaitGroup
var c atomic.Value
for i := 0; i < n; i++ {
wg.Add(1)
go func() {
for j := 0; j < b.N; j++ {
c.Store(t)
var c unsafe.Pointer
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
atomic.StorePointer(&c, unsafe.Pointer(&t))
}
})
}
func BenchmarkStoreContentionWithMutex(b *testing.B) {
t := 123
var mu sync.Mutex
var c int
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
mu.Lock()
c = t
mu.Unlock()
}
})
_ = c
}
type dummyStruct struct {
a int64
b time.Time
}
func BenchmarkStructStoreContention(b *testing.B) {
d := dummyStruct{}
dp := unsafe.Pointer(&d)
t := time.Now()
for _, j := range []int{100000000, 10000, 0} {
for _, i := range []int{100000, 10} {
b.Run(fmt.Sprintf("CAS/%v/%v", j, i), func(b *testing.B) {
b.SetParallelism(i)
b.RunParallel(func(pb *testing.PB) {
n := &dummyStruct{
b: t,
}
wg.Done()
}()
}
wg.Wait()
})
b.Run(fmt.Sprintf("AtomicStorePointer/%v", n), func(b *testing.B) {
var wg sync.WaitGroup
var up unsafe.Pointer
for i := 0; i < n; i++ {
wg.Add(1)
go func() {
for j := 0; j < b.N; j++ {
atomic.StorePointer(&up, unsafe.Pointer(&t))
for pb.Next() {
for y := 0; y < j; y++ {
}
for {
v := (*dummyStruct)(atomic.LoadPointer(&dp))
n.a = v.a + 1
if atomic.CompareAndSwapPointer(&dp, unsafe.Pointer(v), unsafe.Pointer(n)) {
n = v
break
}
}
}
wg.Done()
}()
}
wg.Wait()
})
b.Run(fmt.Sprintf("Mutex/%v", n), func(b *testing.B) {
var wg sync.WaitGroup
var c int
mu := sync.Mutex{}
for i := 0; i < n; i++ {
wg.Add(1)
go func() {
for j := 0; j < b.N; j++ {
})
})
}
}
var mu sync.Mutex
for _, j := range []int{100000000, 10000, 0} {
for _, i := range []int{100000, 10} {
b.Run(fmt.Sprintf("Mutex/%v/%v", j, i), func(b *testing.B) {
b.SetParallelism(i)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
for y := 0; y < j; y++ {
}
mu.Lock()
c = t
d.a++
d.b = t
mu.Unlock()
}
wg.Done()
}()
}
_ = c
wg.Wait()
})
})
})
}
}
}

View File

@ -139,7 +139,7 @@ func createConns(config *testpb.ClientConfig) ([]*grpc.ClientConn, func(), error
if config.PayloadConfig != nil {
switch config.PayloadConfig.Payload.(type) {
case *testpb.PayloadConfig_BytebufParams:
opts = append(opts, grpc.WithCodec(byteBufCodec{}))
opts = append(opts, grpc.WithDefaultCallOptions(grpc.CallCustomCodec(byteBufCodec{})))
case *testpb.PayloadConfig_SimpleParams:
default:
return nil, nil, status.Errorf(codes.InvalidArgument, "unknown payload config: %v", config.PayloadConfig)

312
vendor/google.golang.org/grpc/call.go generated vendored
View File

@ -19,131 +19,13 @@
package grpc
import (
"io"
"time"
"golang.org/x/net/context"
"golang.org/x/net/trace"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/encoding"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
"google.golang.org/grpc/transport"
)
// recvResponse receives and parses an RPC response.
// On error, it returns the error and indicates whether the call should be retried.
//
// TODO(zhaoq): Check whether the received message sequence is valid.
// TODO ctx is used for stats collection and processing. It is the context passed from the application.
func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) (err error) {
// Try to acquire header metadata from the server if there is any.
defer func() {
if err != nil {
if _, ok := err.(transport.ConnectionError); !ok {
t.CloseStream(stream, err)
}
}
}()
c.headerMD, err = stream.Header()
if err != nil {
return
}
p := &parser{r: stream}
var inPayload *stats.InPayload
if dopts.copts.StatsHandler != nil {
inPayload = &stats.InPayload{
Client: true,
}
}
for {
if c.maxReceiveMessageSize == nil {
return status.Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)")
}
// Set dc if it exists and matches the message compression type used,
// otherwise set comp if a registered compressor exists for it.
var comp encoding.Compressor
var dc Decompressor
if rc := stream.RecvCompress(); dopts.dc != nil && dopts.dc.Type() == rc {
dc = dopts.dc
} else if rc != "" && rc != encoding.Identity {
comp = encoding.GetCompressor(rc)
}
if err = recv(p, dopts.codec, stream, dc, reply, *c.maxReceiveMessageSize, inPayload, comp); err != nil {
if err == io.EOF {
break
}
return
}
}
if inPayload != nil && err == io.EOF && stream.Status().Code() == codes.OK {
// TODO in the current implementation, inTrailer may be handled before inPayload in some cases.
// Fix the order if necessary.
dopts.copts.StatsHandler.HandleRPC(ctx, inPayload)
}
c.trailerMD = stream.Trailer()
return nil
}
// sendRequest writes out various information of an RPC such as Context and Message.
func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, c *callInfo, callHdr *transport.CallHdr, stream *transport.Stream, t transport.ClientTransport, args interface{}, opts *transport.Options) (err error) {
defer func() {
if err != nil {
// If err is connection error, t will be closed, no need to close stream here.
if _, ok := err.(transport.ConnectionError); !ok {
t.CloseStream(stream, err)
}
}
}()
var (
outPayload *stats.OutPayload
)
if dopts.copts.StatsHandler != nil {
outPayload = &stats.OutPayload{
Client: true,
}
}
// Set comp and clear compressor if a registered compressor matches the type
// specified via UseCompressor. (And error if a matching compressor is not
// registered.)
var comp encoding.Compressor
if ct := c.compressorType; ct != "" && ct != encoding.Identity {
compressor = nil // Disable the legacy compressor.
comp = encoding.GetCompressor(ct)
if comp == nil {
return status.Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", ct)
}
}
hdr, data, err := encode(dopts.codec, args, compressor, outPayload, comp)
if err != nil {
return err
}
if c.maxSendMessageSize == nil {
return status.Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)")
}
if len(data) > *c.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), *c.maxSendMessageSize)
}
err = t.Write(stream, hdr, data, opts)
if err == nil && outPayload != nil {
outPayload.SentTime = time.Now()
dopts.copts.StatsHandler.HandleRPC(ctx, outPayload)
}
// t.NewStream(...) could lead to an early rejection of the RPC (e.g., the service/method
// does not exist.) so that t.Write could get io.EOF from wait(...). Leave the following
// recvResponse to get the final status.
if err != nil && err != io.EOF {
return err
}
// Sent successfully.
return nil
}
// Invoke sends the RPC request on the wire and returns after response is
// received. This is typically called by generated code.
//
// All errors returned by Invoke are compatible with the status package.
func (cc *ClientConn) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...CallOption) error {
if cc.dopts.unaryInt != nil {
return cc.dopts.unaryInt(ctx, method, args, reply, cc, invoke, opts...)
@ -159,188 +41,34 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
return cc.Invoke(ctx, method, args, reply, opts...)
}
func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (e error) {
c := defaultCallInfo()
mc := cc.GetMethodConfig(method)
if mc.WaitForReady != nil {
c.failFast = !*mc.WaitForReady
}
var unaryStreamDesc = &StreamDesc{ServerStreams: false, ClientStreams: false}
if mc.Timeout != nil && *mc.Timeout >= 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, *mc.Timeout)
defer cancel()
}
opts = append(cc.dopts.callOptions, opts...)
for _, o := range opts {
if err := o.before(c); err != nil {
return toRPCErr(err)
}
}
defer func() {
for _, o := range opts {
o.after(c)
}
}()
c.maxSendMessageSize = getMaxSize(mc.MaxReqSize, c.maxSendMessageSize, defaultClientMaxSendMessageSize)
c.maxReceiveMessageSize = getMaxSize(mc.MaxRespSize, c.maxReceiveMessageSize, defaultClientMaxReceiveMessageSize)
if EnableTracing {
c.traceInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
defer c.traceInfo.tr.Finish()
c.traceInfo.firstLine.client = true
if deadline, ok := ctx.Deadline(); ok {
c.traceInfo.firstLine.deadline = deadline.Sub(time.Now())
}
c.traceInfo.tr.LazyLog(&c.traceInfo.firstLine, false)
// TODO(dsymonds): Arrange for c.traceInfo.firstLine.remoteAddr to be set.
defer func() {
if e != nil {
c.traceInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{e}}, true)
c.traceInfo.tr.SetError()
}
}()
}
ctx = newContextWithRPCInfo(ctx, c.failFast)
sh := cc.dopts.copts.StatsHandler
if sh != nil {
ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: c.failFast})
begin := &stats.Begin{
Client: true,
BeginTime: time.Now(),
FailFast: c.failFast,
}
sh.HandleRPC(ctx, begin)
defer func() {
end := &stats.End{
Client: true,
EndTime: time.Now(),
Error: e,
}
sh.HandleRPC(ctx, end)
}()
}
topts := &transport.Options{
Last: true,
Delay: false,
}
callHdr := &transport.CallHdr{
Host: cc.authority,
Method: method,
}
if c.creds != nil {
callHdr.Creds = c.creds
}
if c.compressorType != "" {
callHdr.SendCompress = c.compressorType
} else if cc.dopts.cp != nil {
callHdr.SendCompress = cc.dopts.cp.Type()
}
func invoke(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, opts ...CallOption) error {
// TODO: implement retries in clientStream and make this simply
// newClientStream, SendMsg, RecvMsg.
firstAttempt := true
for {
// Check to make sure the context has expired. This will prevent us from
// looping forever if an error occurs for wait-for-ready RPCs where no data
// is sent on the wire.
select {
case <-ctx.Done():
return toRPCErr(ctx.Err())
default:
}
// Record the done handler from Balancer.Get(...). It is called once the
// RPC has completed or failed.
t, done, err := cc.getTransport(ctx, c.failFast)
csInt, err := newClientStream(ctx, unaryStreamDesc, cc, method, opts...)
if err != nil {
return err
}
stream, err := t.NewStream(ctx, callHdr)
if err != nil {
if done != nil {
done(balancer.DoneInfo{Err: err})
}
// In the event of any error from NewStream, we never attempted to write
// anything to the wire, so we can retry indefinitely for non-fail-fast
// RPCs.
if !c.failFast {
continue
}
return toRPCErr(err)
}
if peer, ok := peer.FromContext(stream.Context()); ok {
c.peer = peer
}
if c.traceInfo.tr != nil {
c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true)
}
err = sendRequest(ctx, cc.dopts, cc.dopts.cp, c, callHdr, stream, t, args, topts)
if err != nil {
if done != nil {
done(balancer.DoneInfo{
Err: err,
BytesSent: true,
BytesReceived: stream.BytesReceived(),
})
}
// Retry a non-failfast RPC when
// i) the server started to drain before this RPC was initiated.
// ii) the server refused the stream.
if !c.failFast && stream.Unprocessed() {
// In this case, the server did not receive the data, but we still
// created wire traffic, so we should not retry indefinitely.
if firstAttempt {
// TODO: Add a field to header for grpc-transparent-retry-attempts
firstAttempt = false
continue
}
// Otherwise, give up and return an error anyway.
}
return toRPCErr(err)
}
err = recvResponse(ctx, cc.dopts, t, c, stream, reply)
if err != nil {
if done != nil {
done(balancer.DoneInfo{
Err: err,
BytesSent: true,
BytesReceived: stream.BytesReceived(),
})
}
if !c.failFast && stream.Unprocessed() {
// In these cases, the server did not receive the data, but we still
// created wire traffic, so we should not retry indefinitely.
if firstAttempt {
// TODO: Add a field to header for grpc-transparent-retry-attempts
firstAttempt = false
continue
}
// Otherwise, give up and return an error anyway.
}
return toRPCErr(err)
}
if c.traceInfo.tr != nil {
c.traceInfo.tr.LazyLog(&payload{sent: false, msg: reply}, true)
}
t.CloseStream(stream, nil)
err = stream.Status().Err()
if done != nil {
done(balancer.DoneInfo{
Err: err,
BytesSent: true,
BytesReceived: stream.BytesReceived(),
})
}
if !c.failFast && stream.Unprocessed() {
// In these cases, the server did not receive the data, but we still
// created wire traffic, so we should not retry indefinitely.
if firstAttempt {
cs := csInt.(*clientStream)
if err := cs.SendMsg(req); err != nil {
if !cs.c.failFast && cs.s.Unprocessed() && firstAttempt {
// TODO: Add a field to header for grpc-transparent-retry-attempts
firstAttempt = false
continue
}
return err
}
return err
if err := cs.RecvMsg(reply); err != nil {
if !cs.c.failFast && cs.s.Unprocessed() && firstAttempt {
// TODO: Add a field to header for grpc-transparent-retry-attempts
firstAttempt = false
continue
}
return err
}
return nil
}
}

View File

@ -32,6 +32,7 @@ import (
"golang.org/x/net/trace"
"google.golang.org/grpc/balancer"
_ "google.golang.org/grpc/balancer/roundrobin" // To register roundrobin.
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
@ -40,17 +41,17 @@ import (
_ "google.golang.org/grpc/resolver/dns" // To register dns resolver.
_ "google.golang.org/grpc/resolver/passthrough" // To register passthrough resolver.
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
"google.golang.org/grpc/transport"
)
var (
// ErrClientConnClosing indicates that the operation is illegal because
// the ClientConn is closing.
ErrClientConnClosing = errors.New("grpc: the client connection is closing")
// ErrClientConnTimeout indicates that the ClientConn cannot establish the
// underlying connections within the specified timeout.
// DEPRECATED: Please use context.DeadlineExceeded instead.
ErrClientConnTimeout = errors.New("grpc: timed out when dialing")
//
// Deprecated: this error should not be relied upon by users; use the status
// code of Canceled instead.
ErrClientConnClosing = status.Error(codes.Canceled, "grpc: the client connection is closing")
// errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs.
errConnDrain = errors.New("grpc: the connection is drained")
// errConnClosing indicates that the connection is closing.
@ -85,7 +86,6 @@ var (
type dialOptions struct {
unaryInt UnaryClientInterceptor
streamInt StreamClientInterceptor
codec Codec
cp Compressor
dc Decompressor
bs backoffStrategy
@ -99,10 +99,8 @@ type dialOptions struct {
// balancer, and also by WithBalancerName dial option.
balancerBuilder balancer.Builder
// This is to support grpclb.
resolverBuilder resolver.Builder
// Custom user options for resolver.Build.
resolverBuildUserOptions interface{}
waitForHandshake bool
resolverBuilder resolver.Builder
waitForHandshake bool
}
const (
@ -167,10 +165,10 @@ func WithDefaultCallOptions(cos ...CallOption) DialOption {
}
// WithCodec returns a DialOption which sets a codec for message marshaling and unmarshaling.
//
// Deprecated: use WithDefaultCallOptions(CallCustomCodec(c)) instead.
func WithCodec(c Codec) DialOption {
return func(o *dialOptions) {
o.codec = c
}
return WithDefaultCallOptions(CallCustomCodec(c))
}
// WithCompressor returns a DialOption which sets a Compressor to use for
@ -236,14 +234,6 @@ func withResolverBuilder(b resolver.Builder) DialOption {
}
}
// WithResolverUserOptions returns a DialOption which sets the UserOptions
// field of resolver's BuildOption.
func WithResolverUserOptions(userOpt interface{}) DialOption {
return func(o *dialOptions) {
o.resolverBuildUserOptions = userOpt
}
}
// WithServiceConfig returns a DialOption which has a channel to read the service configuration.
// DEPRECATED: service config should be received through name resolver, as specified here.
// https://github.com/grpc/grpc/blob/master/doc/service_config.md
@ -407,6 +397,10 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
// cancel or expire the pending connection. Once this function returns, the
// cancellation and expiration of ctx will be noop. Users should call ClientConn.Close
// to terminate all the pending operations after this function returns.
//
// The target name syntax is defined in
// https://github.com/grpc/grpc/blob/master/doc/naming.md.
// e.g. to use dns resolver, a "dns:///" prefix should be applied to the target.
func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) {
cc := &ClientConn{
target: target,
@ -482,10 +476,6 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
default:
}
}
// Set defaults.
if cc.dopts.codec == nil {
cc.dopts.codec = protoCodec{}
}
if cc.dopts.bs == nil {
cc.dopts.bs = DefaultBackoffConfig
}
@ -1119,8 +1109,8 @@ func (ac *addrConn) createTransport(connectRetryNum, ridx int, backoffDeadline,
}
done := make(chan struct{})
onPrefaceReceipt := func() {
close(done)
ac.mu.Lock()
close(done)
if !ac.backoffDeadline.IsZero() {
// If we haven't already started reconnecting to
// other backends.
@ -1185,10 +1175,16 @@ func (ac *addrConn) createTransport(connectRetryNum, ridx int, backoffDeadline,
close(ac.ready)
ac.ready = nil
}
ac.connectRetryNum = connectRetryNum
ac.backoffDeadline = backoffDeadline
ac.connectDeadline = connectDeadline
ac.reconnectIdx = i + 1 // Start reconnecting from the next backend in the list.
select {
case <-done:
// If the server has responded back with preface already,
// don't set the reconnect parameters.
default:
ac.connectRetryNum = connectRetryNum
ac.backoffDeadline = backoffDeadline
ac.connectDeadline = connectDeadline
ac.reconnectIdx = i + 1 // Start reconnecting from the next backend in the list.
}
ac.mu.Unlock()
return true, nil
}
@ -1379,3 +1375,10 @@ func (ac *addrConn) getState() connectivity.State {
defer ac.mu.Unlock()
return ac.state
}
// ErrClientConnTimeout indicates that the ClientConn cannot establish the
// underlying connections within the specified timeout.
//
// Deprecated: This error is never returned by grpc and should not be
// referenced by users.
var ErrClientConnTimeout = errors.New("grpc: timed out when dialing")

View File

@ -599,6 +599,41 @@ func TestNonblockingDialWithEmptyBalancer(t *testing.T) {
}
}
func TestResolverServiceConfigBeforeAddressNotPanic(t *testing.T) {
defer leakcheck.Check(t)
r, rcleanup := manual.GenerateAndRegisterManualResolver()
defer rcleanup()
cc, err := Dial(r.Scheme()+":///test.server", WithInsecure())
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer cc.Close()
// SwitchBalancer before NewAddress. There was no balancer created, this
// makes sure we don't call close on nil balancerWrapper.
r.NewServiceConfig(`{"loadBalancingPolicy": "round_robin"}`) // This should not panic.
time.Sleep(time.Second) // Sleep to make sure the service config is handled by ClientConn.
}
func TestResolverEmptyUpdateNotPanic(t *testing.T) {
defer leakcheck.Check(t)
r, rcleanup := manual.GenerateAndRegisterManualResolver()
defer rcleanup()
cc, err := Dial(r.Scheme()+":///test.server", WithInsecure())
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer cc.Close()
// This makes sure we don't create addrConn with empty address list.
r.NewAddress([]resolver.Address{}) // This should not panic.
time.Sleep(time.Second) // Sleep to make sure the service config is handled by ClientConn.
}
func TestClientUpdatesParamsAfterGoAway(t *testing.T) {
defer leakcheck.Check(t)
lis, err := net.Listen("tcp", "localhost:0")

View File

@ -19,96 +19,32 @@
package grpc
import (
"math"
"sync"
"github.com/golang/protobuf/proto"
"google.golang.org/grpc/encoding"
_ "google.golang.org/grpc/encoding/proto" // to register the Codec for "proto"
)
// baseCodec contains the functionality of both Codec and encoding.Codec, but
// omits the name/string, which vary between the two and are not needed for
// anything besides the registry in the encoding package.
type baseCodec interface {
Marshal(v interface{}) ([]byte, error)
Unmarshal(data []byte, v interface{}) error
}
var _ baseCodec = Codec(nil)
var _ baseCodec = encoding.Codec(nil)
// Codec defines the interface gRPC uses to encode and decode messages.
// Note that implementations of this interface must be thread safe;
// a Codec's methods can be called from concurrent goroutines.
//
// Deprecated: use encoding.Codec instead.
type Codec interface {
// Marshal returns the wire format of v.
Marshal(v interface{}) ([]byte, error)
// Unmarshal parses the wire format into v.
Unmarshal(data []byte, v interface{}) error
// String returns the name of the Codec implementation. The returned
// string will be used as part of content type in transmission.
// String returns the name of the Codec implementation. This is unused by
// gRPC.
String() string
}
// protoCodec is a Codec implementation with protobuf. It is the default codec for gRPC.
type protoCodec struct {
}
type cachedProtoBuffer struct {
lastMarshaledSize uint32
proto.Buffer
}
func capToMaxInt32(val int) uint32 {
if val > math.MaxInt32 {
return uint32(math.MaxInt32)
}
return uint32(val)
}
func (p protoCodec) marshal(v interface{}, cb *cachedProtoBuffer) ([]byte, error) {
protoMsg := v.(proto.Message)
newSlice := make([]byte, 0, cb.lastMarshaledSize)
cb.SetBuf(newSlice)
cb.Reset()
if err := cb.Marshal(protoMsg); err != nil {
return nil, err
}
out := cb.Bytes()
cb.lastMarshaledSize = capToMaxInt32(len(out))
return out, nil
}
func (p protoCodec) Marshal(v interface{}) ([]byte, error) {
if pm, ok := v.(proto.Marshaler); ok {
// object can marshal itself, no need for buffer
return pm.Marshal()
}
cb := protoBufferPool.Get().(*cachedProtoBuffer)
out, err := p.marshal(v, cb)
// put back buffer and lose the ref to the slice
cb.SetBuf(nil)
protoBufferPool.Put(cb)
return out, err
}
func (p protoCodec) Unmarshal(data []byte, v interface{}) error {
protoMsg := v.(proto.Message)
protoMsg.Reset()
if pu, ok := protoMsg.(proto.Unmarshaler); ok {
// object can unmarshal itself, no need for buffer
return pu.Unmarshal(data)
}
cb := protoBufferPool.Get().(*cachedProtoBuffer)
cb.SetBuf(data)
err := cb.Unmarshal(protoMsg)
cb.SetBuf(nil)
protoBufferPool.Put(cb)
return err
}
func (protoCodec) String() string {
return "proto"
}
var protoBufferPool = &sync.Pool{
New: func() interface{} {
return &cachedProtoBuffer{
Buffer: proto.Buffer{},
lastMarshaledSize: 16,
}
},
}

View File

@ -19,110 +19,14 @@
package grpc
import (
"bytes"
"sync"
"testing"
"google.golang.org/grpc/test/codec_perf"
"google.golang.org/grpc/encoding"
"google.golang.org/grpc/encoding/proto"
)
func marshalAndUnmarshal(t *testing.T, protoCodec Codec, expectedBody []byte) {
p := &codec_perf.Buffer{}
p.Body = expectedBody
marshalledBytes, err := protoCodec.Marshal(p)
if err != nil {
t.Errorf("protoCodec.Marshal(_) returned an error")
}
if err := protoCodec.Unmarshal(marshalledBytes, p); err != nil {
t.Errorf("protoCodec.Unmarshal(_) returned an error")
}
if bytes.Compare(p.GetBody(), expectedBody) != 0 {
t.Errorf("Unexpected body; got %v; want %v", p.GetBody(), expectedBody)
}
}
func TestBasicProtoCodecMarshalAndUnmarshal(t *testing.T) {
marshalAndUnmarshal(t, protoCodec{}, []byte{1, 2, 3})
}
// Try to catch possible race conditions around use of pools
func TestConcurrentUsage(t *testing.T) {
const (
numGoRoutines = 100
numMarshUnmarsh = 1000
)
// small, arbitrary byte slices
protoBodies := [][]byte{
[]byte("one"),
[]byte("two"),
[]byte("three"),
[]byte("four"),
[]byte("five"),
}
var wg sync.WaitGroup
codec := protoCodec{}
for i := 0; i < numGoRoutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for k := 0; k < numMarshUnmarsh; k++ {
marshalAndUnmarshal(t, codec, protoBodies[k%len(protoBodies)])
}
}()
}
wg.Wait()
}
// TestStaggeredMarshalAndUnmarshalUsingSamePool tries to catch potential errors in which slices get
// stomped on during reuse of a proto.Buffer.
func TestStaggeredMarshalAndUnmarshalUsingSamePool(t *testing.T) {
codec1 := protoCodec{}
codec2 := protoCodec{}
expectedBody1 := []byte{1, 2, 3}
expectedBody2 := []byte{4, 5, 6}
proto1 := codec_perf.Buffer{Body: expectedBody1}
proto2 := codec_perf.Buffer{Body: expectedBody2}
var m1, m2 []byte
var err error
if m1, err = codec1.Marshal(&proto1); err != nil {
t.Errorf("protoCodec.Marshal(%v) failed", proto1)
}
if m2, err = codec2.Marshal(&proto2); err != nil {
t.Errorf("protoCodec.Marshal(%v) failed", proto2)
}
if err = codec1.Unmarshal(m1, &proto1); err != nil {
t.Errorf("protoCodec.Unmarshal(%v) failed", m1)
}
if err = codec2.Unmarshal(m2, &proto2); err != nil {
t.Errorf("protoCodec.Unmarshal(%v) failed", m2)
}
b1 := proto1.GetBody()
b2 := proto2.GetBody()
for i, v := range b1 {
if expectedBody1[i] != v {
t.Errorf("expected %v at index %v but got %v", i, expectedBody1[i], v)
}
}
for i, v := range b2 {
if expectedBody2[i] != v {
t.Errorf("expected %v at index %v but got %v", i, expectedBody2[i], v)
}
func TestGetCodecForProtoIsNotNil(t *testing.T) {
if encoding.GetCodec(proto.Name) == nil {
t.Fatalf("encoding.GetCodec(%q) must not be nil by default", proto.Name)
}
}

View File

@ -19,6 +19,7 @@
// Package codes defines the canonical error codes used by gRPC. It is
// consistent across various languages.
package codes // import "google.golang.org/grpc/codes"
import (
"fmt"
)
@ -33,9 +34,9 @@ const (
// Canceled indicates the operation was canceled (typically by the caller).
Canceled Code = 1
// Unknown error. An example of where this error may be returned is
// Unknown error. An example of where this error may be returned is
// if a Status value received from another address space belongs to
// an error-space that is not known in this address space. Also
// an error-space that is not known in this address space. Also
// errors raised by APIs that do not return enough error information
// may be converted to this error.
Unknown Code = 2
@ -64,15 +65,11 @@ const (
// PermissionDenied indicates the caller does not have permission to
// execute the specified operation. It must not be used for rejections
// caused by exhausting some resource (use ResourceExhausted
// instead for those errors). It must not be
// instead for those errors). It must not be
// used if the caller cannot be identified (use Unauthenticated
// instead for those errors).
PermissionDenied Code = 7
// Unauthenticated indicates the request does not have valid
// authentication credentials for the operation.
Unauthenticated Code = 16
// ResourceExhausted indicates some resource has been exhausted, perhaps
// a per-user quota, or perhaps the entire file system is out of space.
ResourceExhausted Code = 8
@ -88,7 +85,7 @@ const (
// (b) Use Aborted if the client should retry at a higher-level
// (e.g., restarting a read-modify-write sequence).
// (c) Use FailedPrecondition if the client should not retry until
// the system state has been explicitly fixed. E.g., if an "rmdir"
// the system state has been explicitly fixed. E.g., if an "rmdir"
// fails because the directory is non-empty, FailedPrecondition
// should be returned since the client should not retry unless
// they have first fixed up the directory by deleting files from it.
@ -117,7 +114,7 @@ const (
// file size.
//
// There is a fair bit of overlap between FailedPrecondition and
// OutOfRange. We recommend using OutOfRange (the more specific
// OutOfRange. We recommend using OutOfRange (the more specific
// error) when it applies so that callers who are iterating through
// a space can easily look for an OutOfRange error to detect when
// they are done.
@ -127,8 +124,8 @@ const (
// supported/enabled in this service.
Unimplemented Code = 12
// Internal errors. Means some invariants expected by underlying
// system has been broken. If you see one of these errors,
// Internal errors. Means some invariants expected by underlying
// system has been broken. If you see one of these errors,
// something is very broken.
Internal Code = 13
@ -142,6 +139,10 @@ const (
// DataLoss indicates unrecoverable data loss or corruption.
DataLoss Code = 15
// Unauthenticated indicates the request does not have valid
// authentication credentials for the operation.
Unauthenticated Code = 16
)
var strToCode = map[string]Code{

View File

@ -43,8 +43,9 @@ type PerRPCCredentials interface {
// GetRequestMetadata gets the current request metadata, refreshing
// tokens if required. This should be called by the transport layer on
// each request, and the data should be populated in headers or other
// context. uri is the URI of the entry point for the request. When
// supported by the underlying implementation, ctx can be used for
// context. If a status code is returned, it will be used as the status
// for the RPC. uri is the URI of the entry point for the request.
// When supported by the underlying implementation, ctx can be used for
// timeout and cancellation.
// TODO(zhaoq): Define the set of the qualified keys instead of leaving
// it as an arbitrary string.
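For illustration, a minimal `PerRPCCredentials` implementation consistent with
this comment might look like the following sketch (a hypothetical static
bearer-token credential, not part of this change); it would be installed with
`grpc.WithPerRPCCredentials(tokenCreds{token: ...})`:
```go
// tokenCreds is a hypothetical PerRPCCredentials carrying a static token.
type tokenCreds struct {
	token string // assumed to be obtained out of band
}

// GetRequestMetadata returns the headers to attach to each request.
func (c tokenCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
	return map[string]string{"authorization": "Bearer " + c.token}, nil
}

// RequireTransportSecurity reports whether these credentials require a
// secure transport.
func (c tokenCreds) RequireTransportSecurity() bool { return true }
```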

View File

@ -16,46 +16,103 @@
*
*/
// Package encoding defines the interface for the compressor and the functions
// to register and get the compossor.
// Package encoding defines the interface for the compressor and codec, and
// functions to register and retrieve compressors and codecs.
//
// This package is EXPERIMENTAL.
package encoding
import (
"io"
"strings"
)
var registerCompressor = make(map[string]Compressor)
// Compressor is used for compressing and decompressing when sending or receiving messages.
type Compressor interface {
// Compress writes the data written to wc to w after compressing it. If an error
// occurs while initializing the compressor, that error is returned instead.
Compress(w io.Writer) (io.WriteCloser, error)
// Decompress reads data from r, decompresses it, and provides the uncompressed data
// via the returned io.Reader. If an error occurs while initializing the decompressor, that error
// is returned instead.
Decompress(r io.Reader) (io.Reader, error)
// Name is the name of the compression codec and is used to set the content coding header.
Name() string
}
// RegisterCompressor registers the compressor with gRPC by its name. It can be activated when
// sending an RPC via grpc.UseCompressor(). It will be automatically accessed when receiving a
// message based on the content coding header. Servers also use it to send a response with the
// same encoding as the request.
//
// NOTE: this function must only be called during initialization time (i.e. in an init() function). If
// multiple Compressors are registered with the same name, the one registered last will take effect.
func RegisterCompressor(c Compressor) {
registerCompressor[c.Name()] = c
}
// GetCompressor returns Compressor for the given compressor name.
func GetCompressor(name string) Compressor {
return registerCompressor[name]
}
// Identity specifies the optional encoding for uncompressed streams.
// It is intended for grpc internal use only.
const Identity = "identity"
// Compressor is used for compressing and decompressing when sending or
// receiving messages.
type Compressor interface {
// Compress writes the data written to wc to w after compressing it. If an
// error occurs while initializing the compressor, that error is returned
// instead.
Compress(w io.Writer) (io.WriteCloser, error)
// Decompress reads data from r, decompresses it, and provides the
// uncompressed data via the returned io.Reader. If an error occurs while
// initializing the decompressor, that error is returned instead.
Decompress(r io.Reader) (io.Reader, error)
// Name is the name of the compression codec and is used to set the content
// coding header. The result must be static; the result cannot change
// between calls.
Name() string
}
var registeredCompressor = make(map[string]Compressor)
// RegisterCompressor registers the compressor with gRPC by its name. It can
// be activated when sending an RPC via grpc.UseCompressor(). It will be
// automatically accessed when receiving a message based on the content coding
// header. Servers also use it to send a response with the same encoding as
// the request.
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple Compressors are
// registered with the same name, the one registered last will take effect.
func RegisterCompressor(c Compressor) {
registeredCompressor[c.Name()] = c
}
// GetCompressor returns Compressor for the given compressor name.
func GetCompressor(name string) Compressor {
return registeredCompressor[name]
}
// Codec defines the interface gRPC uses to encode and decode messages. Note
// that implementations of this interface must be thread safe; a Codec's
// methods can be called from concurrent goroutines.
type Codec interface {
// Marshal returns the wire format of v.
Marshal(v interface{}) ([]byte, error)
// Unmarshal parses the wire format into v.
Unmarshal(data []byte, v interface{}) error
// Name returns the name of the Codec implementation. The returned string
// will be used as part of content type in transmission. The result must be
// static; the result cannot change between calls.
Name() string
}
var registeredCodecs = make(map[string]Codec, 0)
// RegisterCodec registers the provided Codec for use with all gRPC clients and
// servers.
//
// The Codec will be stored and looked up by result of its Name() method, which
// should match the content-subtype of the encoding handled by the Codec. This
// is case-insensitive, and is stored and looked up as lowercase. If the
// result of calling Name() is an empty string, RegisterCodec will panic. See
// Content-Type on
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details.
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple Codecs are
// registered with the same name, the one registered last will take effect.
func RegisterCodec(codec Codec) {
if codec == nil {
panic("cannot register a nil Codec")
}
contentSubtype := strings.ToLower(codec.Name())
if contentSubtype == "" {
panic("cannot register Codec with empty string result for String()")
}
registeredCodecs[contentSubtype] = codec
}
// GetCodec gets a registered Codec by content-subtype, or nil if no Codec is
// registered for the content-subtype.
//
// The content-subtype is expected to be lowercase.
func GetCodec(contentSubtype string) Codec {
return registeredCodecs[contentSubtype]
}

View File

@ -30,6 +30,9 @@ import (
"google.golang.org/grpc/encoding"
)
// Name is the name registered for the gzip compressor.
const Name = "gzip"
func init() {
c := &compressor{}
c.poolCompressor.New = func() interface{} {
@ -84,7 +87,7 @@ func (z *reader) Read(p []byte) (n int, err error) {
}
func (c *compressor) Name() string {
return "gzip"
return Name
}
type compressor struct {

110
vendor/google.golang.org/grpc/encoding/proto/proto.go generated vendored Normal file
View File

@ -0,0 +1,110 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package proto defines the protobuf codec. Importing this package will
// register the codec.
package proto
import (
"math"
"sync"
"github.com/golang/protobuf/proto"
"google.golang.org/grpc/encoding"
)
// Name is the name registered for the proto codec.
const Name = "proto"
func init() {
encoding.RegisterCodec(codec{})
}
// codec is a Codec implementation with protobuf. It is the default codec for gRPC.
type codec struct{}
type cachedProtoBuffer struct {
lastMarshaledSize uint32
proto.Buffer
}
func capToMaxInt32(val int) uint32 {
if val > math.MaxInt32 {
return uint32(math.MaxInt32)
}
return uint32(val)
}
func marshal(v interface{}, cb *cachedProtoBuffer) ([]byte, error) {
protoMsg := v.(proto.Message)
newSlice := make([]byte, 0, cb.lastMarshaledSize)
cb.SetBuf(newSlice)
cb.Reset()
if err := cb.Marshal(protoMsg); err != nil {
return nil, err
}
out := cb.Bytes()
cb.lastMarshaledSize = capToMaxInt32(len(out))
return out, nil
}
func (codec) Marshal(v interface{}) ([]byte, error) {
if pm, ok := v.(proto.Marshaler); ok {
// object can marshal itself, no need for buffer
return pm.Marshal()
}
cb := protoBufferPool.Get().(*cachedProtoBuffer)
out, err := marshal(v, cb)
// put back buffer and lose the ref to the slice
cb.SetBuf(nil)
protoBufferPool.Put(cb)
return out, err
}
func (codec) Unmarshal(data []byte, v interface{}) error {
protoMsg := v.(proto.Message)
protoMsg.Reset()
if pu, ok := protoMsg.(proto.Unmarshaler); ok {
// object can unmarshal itself, no need for buffer
return pu.Unmarshal(data)
}
cb := protoBufferPool.Get().(*cachedProtoBuffer)
cb.SetBuf(data)
err := cb.Unmarshal(protoMsg)
cb.SetBuf(nil)
protoBufferPool.Put(cb)
return err
}
func (codec) Name() string {
return Name
}
var protoBufferPool = &sync.Pool{
New: func() interface{} {
return &cachedProtoBuffer{
Buffer: proto.Buffer{},
lastMarshaledSize: 16,
}
},
}

View File

@ -18,13 +18,14 @@
*
*/
package grpc
package proto
import (
"fmt"
"testing"
"github.com/golang/protobuf/proto"
"google.golang.org/grpc/encoding"
"google.golang.org/grpc/test/codec_perf"
)
@ -68,7 +69,7 @@ func BenchmarkProtoCodec(b *testing.B) {
protoStructs := setupBenchmarkProtoCodecInputs(s)
name := fmt.Sprintf("MinPayloadSize:%v/SetParallelism(%v)", s, p)
b.Run(name, func(b *testing.B) {
codec := &protoCodec{}
codec := &codec{}
b.SetParallelism(p)
b.RunParallel(func(pb *testing.PB) {
benchmarkProtoCodec(codec, protoStructs, pb, b)
@ -78,7 +79,7 @@ func BenchmarkProtoCodec(b *testing.B) {
}
}
func benchmarkProtoCodec(codec *protoCodec, protoStructs []proto.Message, pb *testing.PB, b *testing.B) {
func benchmarkProtoCodec(codec *codec, protoStructs []proto.Message, pb *testing.PB, b *testing.B) {
counter := 0
for pb.Next() {
counter++
@ -87,13 +88,13 @@ func benchmarkProtoCodec(codec *protoCodec, protoStructs []proto.Message, pb *te
}
}
func fastMarshalAndUnmarshal(protoCodec Codec, protoStruct proto.Message, b *testing.B) {
marshaledBytes, err := protoCodec.Marshal(protoStruct)
func fastMarshalAndUnmarshal(codec encoding.Codec, protoStruct proto.Message, b *testing.B) {
marshaledBytes, err := codec.Marshal(protoStruct)
if err != nil {
b.Errorf("protoCodec.Marshal(_) returned an error")
b.Errorf("codec.Marshal(_) returned an error")
}
res := codec_perf.Buffer{}
if err := protoCodec.Unmarshal(marshaledBytes, &res); err != nil {
b.Errorf("protoCodec.Unmarshal(_) returned an error")
if err := codec.Unmarshal(marshaledBytes, &res); err != nil {
b.Errorf("codec.Unmarshal(_) returned an error")
}
}

View File

@ -0,0 +1,129 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package proto
import (
"bytes"
"sync"
"testing"
"google.golang.org/grpc/encoding"
"google.golang.org/grpc/test/codec_perf"
)
func marshalAndUnmarshal(t *testing.T, codec encoding.Codec, expectedBody []byte) {
p := &codec_perf.Buffer{}
p.Body = expectedBody
marshalledBytes, err := codec.Marshal(p)
if err != nil {
t.Errorf("codec.Marshal(_) returned an error")
}
if err := codec.Unmarshal(marshalledBytes, p); err != nil {
t.Errorf("codec.Unmarshal(_) returned an error")
}
if !bytes.Equal(p.GetBody(), expectedBody) {
t.Errorf("Unexpected body; got %v; want %v", p.GetBody(), expectedBody)
}
}
func TestBasicProtoCodecMarshalAndUnmarshal(t *testing.T) {
marshalAndUnmarshal(t, codec{}, []byte{1, 2, 3})
}
// Try to catch possible race conditions around use of pools
func TestConcurrentUsage(t *testing.T) {
const (
numGoRoutines = 100
numMarshUnmarsh = 1000
)
// small, arbitrary byte slices
protoBodies := [][]byte{
[]byte("one"),
[]byte("two"),
[]byte("three"),
[]byte("four"),
[]byte("five"),
}
var wg sync.WaitGroup
codec := codec{}
for i := 0; i < numGoRoutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for k := 0; k < numMarshUnmarsh; k++ {
marshalAndUnmarshal(t, codec, protoBodies[k%len(protoBodies)])
}
}()
}
wg.Wait()
}
// TestStaggeredMarshalAndUnmarshalUsingSamePool tries to catch potential errors in which slices get
// stomped on during reuse of a proto.Buffer.
func TestStaggeredMarshalAndUnmarshalUsingSamePool(t *testing.T) {
codec1 := codec{}
codec2 := codec{}
expectedBody1 := []byte{1, 2, 3}
expectedBody2 := []byte{4, 5, 6}
proto1 := codec_perf.Buffer{Body: expectedBody1}
proto2 := codec_perf.Buffer{Body: expectedBody2}
var m1, m2 []byte
var err error
if m1, err = codec1.Marshal(&proto1); err != nil {
t.Errorf("codec.Marshal(%v) failed", proto1)
}
if m2, err = codec2.Marshal(&proto2); err != nil {
t.Errorf("codec.Marshal(%v) failed", proto2)
}
if err = codec1.Unmarshal(m1, &proto1); err != nil {
t.Errorf("codec.Unmarshal(%v) failed", m1)
}
if err = codec2.Unmarshal(m2, &proto2); err != nil {
t.Errorf("codec.Unmarshal(%v) failed", m2)
}
b1 := proto1.GetBody()
b2 := proto2.GetBody()
for i, v := range b1 {
if expectedBody1[i] != v {
t.Errorf("expected %v at index %v but got %v", i, expectedBody1[i], v)
}
}
for i, v := range b2 {
if expectedBody2[i] != v {
t.Errorf("expected %v at index %v but got %v", i, expectedBody2[i], v)
}
}
}

View File

@ -101,12 +101,12 @@ func runRecordRoute(client pb.RouteGuideClient) {
// runRouteChat receives a sequence of route notes, while sending notes for various locations.
func runRouteChat(client pb.RouteGuideClient) {
notes := []*pb.RouteNote{
{&pb.Point{Latitude: 0, Longitude: 1}, "First message"},
{&pb.Point{Latitude: 0, Longitude: 2}, "Second message"},
{&pb.Point{Latitude: 0, Longitude: 3}, "Third message"},
{&pb.Point{Latitude: 0, Longitude: 1}, "Fourth message"},
{&pb.Point{Latitude: 0, Longitude: 2}, "Fifth message"},
{&pb.Point{Latitude: 0, Longitude: 3}, "Sixth message"},
{Location: &pb.Point{Latitude: 0, Longitude: 1}, Message: "First message"},
{Location: &pb.Point{Latitude: 0, Longitude: 2}, Message: "Second message"},
{Location: &pb.Point{Latitude: 0, Longitude: 3}, Message: "Third message"},
{Location: &pb.Point{Latitude: 0, Longitude: 1}, Message: "Fourth message"},
{Location: &pb.Point{Latitude: 0, Longitude: 2}, Message: "Fifth message"},
{Location: &pb.Point{Latitude: 0, Longitude: 3}, Message: "Sixth message"},
}
stream, err := client.RouteChat(context.Background())
if err != nil {

View File

@ -48,6 +48,9 @@ func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) erro
// toRPCErr converts an error into an error from the status package.
func toRPCErr(err error) error {
if err == nil || err == io.EOF {
return err
}
if _, ok := status.FromError(err); ok {
return err
}
@ -62,8 +65,6 @@ func toRPCErr(err error) error {
return status.Error(codes.DeadlineExceeded, err.Error())
case context.Canceled:
return status.Error(codes.Canceled, err.Error())
case ErrClientConnClosing:
return status.Error(codes.FailedPrecondition, err.Error())
}
}
return status.Error(codes.Unknown, err.Error())

View File

@ -49,6 +49,9 @@ func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) erro
// toRPCErr converts an error into an error from the status package.
func toRPCErr(err error) error {
if err == nil || err == io.EOF {
return err
}
if _, ok := status.FromError(err); ok {
return err
}
@ -63,8 +66,6 @@ func toRPCErr(err error) error {
return status.Error(codes.DeadlineExceeded, err.Error())
case context.Canceled, netctx.Canceled:
return status.Error(codes.Canceled, err.Error())
case ErrClientConnClosing:
return status.Error(codes.FailedPrecondition, err.Error())
}
}
return status.Error(codes.Unknown, err.Error())

View File

@ -241,10 +241,8 @@ func DoTimeoutOnSleepingServer(tc testpb.TestServiceClient, args ...grpc.CallOpt
ResponseType: testpb.PayloadType_COMPRESSABLE,
Payload: pl,
}
if err := stream.Send(req); err != nil {
if status.Code(err) != codes.DeadlineExceeded {
grpclog.Fatalf("%v.Send(_) = %v", stream, err)
}
if err := stream.Send(req); err != nil && err != io.EOF {
grpclog.Fatalf("%v.Send(_) = %v", stream, err)
}
if _, err := stream.Recv(); status.Code(err) != codes.DeadlineExceeded {
grpclog.Fatalf("%v.Recv() = _, %v, want error code %d", stream, err, codes.DeadlineExceeded)

View File

@ -17,7 +17,8 @@
*/
// Package metadata defines the structure of the metadata supported by the gRPC library.
// Please refer to https://grpc.io/docs/guides/wire.html for more information about custom-metadata.
// Please refer to https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md
// for more information about custom-metadata.
package metadata // import "google.golang.org/grpc/metadata"
import (
@ -115,9 +116,22 @@ func NewIncomingContext(ctx context.Context, md MD) context.Context {
return context.WithValue(ctx, mdIncomingKey{}, md)
}
// NewOutgoingContext creates a new context with outgoing md attached.
// NewOutgoingContext creates a new context with outgoing md attached. If used
// in conjunction with AppendToOutgoingContext, NewOutgoingContext will
// overwrite any previously-appended metadata.
func NewOutgoingContext(ctx context.Context, md MD) context.Context {
return context.WithValue(ctx, mdOutgoingKey{}, md)
return context.WithValue(ctx, mdOutgoingKey{}, rawMD{md: md})
}
// AppendToOutgoingContext returns a new context with the provided kv merged
// with any existing metadata in the context. Please refer to the
// documentation of Pairs for a description of kv.
func AppendToOutgoingContext(ctx context.Context, kv ...string) context.Context {
if len(kv)%2 == 1 {
panic(fmt.Sprintf("metadata: AppendToOutgoingContext got an odd number of input pairs for metadata: %d", len(kv)))
}
md, _ := ctx.Value(mdOutgoingKey{}).(rawMD)
return context.WithValue(ctx, mdOutgoingKey{}, rawMD{md: md.md, added: append(md.added, kv)})
}
// FromIncomingContext returns the incoming metadata in ctx if it exists. The
@ -128,10 +142,39 @@ func FromIncomingContext(ctx context.Context) (md MD, ok bool) {
return
}
// FromOutgoingContextRaw returns the un-merged, intermediary contents
// of rawMD. Remember to perform strings.ToLower on the keys. The returned
// MD should not be modified. Writing to it may cause races. Modification
// should be made to copies of the returned MD.
//
// This is intended for gRPC-internal use ONLY.
func FromOutgoingContextRaw(ctx context.Context) (MD, [][]string, bool) {
raw, ok := ctx.Value(mdOutgoingKey{}).(rawMD)
if !ok {
return nil, nil, false
}
return raw.md, raw.added, true
}
// FromOutgoingContext returns the outgoing metadata in ctx if it exists. The
// returned MD should not be modified. Writing to it may cause races.
// Modification should be made to the copies of the returned MD.
func FromOutgoingContext(ctx context.Context) (md MD, ok bool) {
md, ok = ctx.Value(mdOutgoingKey{}).(MD)
return
func FromOutgoingContext(ctx context.Context) (MD, bool) {
raw, ok := ctx.Value(mdOutgoingKey{}).(rawMD)
if !ok {
return nil, false
}
mds := make([]MD, 0, len(raw.added)+1)
mds = append(mds, raw.md)
for _, vv := range raw.added {
mds = append(mds, Pairs(vv...))
}
return Join(mds...), ok
}
type rawMD struct {
md MD
added [][]string
}

View File

@ -21,6 +21,8 @@ package metadata
import (
"reflect"
"testing"
"golang.org/x/net/context"
)
func TestPairsMD(t *testing.T) {
@ -69,3 +71,55 @@ func TestJoin(t *testing.T) {
}
}
}
func TestAppendToOutgoingContext(t *testing.T) {
// Pre-existing metadata
ctx := NewOutgoingContext(context.Background(), Pairs("k1", "v1", "k2", "v2"))
ctx = AppendToOutgoingContext(ctx, "k1", "v3")
ctx = AppendToOutgoingContext(ctx, "k1", "v4")
md, ok := FromOutgoingContext(ctx)
if !ok {
t.Errorf("Expected MD to exist in ctx, but got none")
}
want := Pairs("k1", "v1", "k1", "v3", "k1", "v4", "k2", "v2")
if !reflect.DeepEqual(md, want) {
t.Errorf("context's metadata is %v, want %v", md, want)
}
// No existing metadata
ctx = AppendToOutgoingContext(context.Background(), "k1", "v1")
md, ok = FromOutgoingContext(ctx)
if !ok {
t.Errorf("Expected MD to exist in ctx, but got none")
}
want = Pairs("k1", "v1")
if !reflect.DeepEqual(md, want) {
t.Errorf("context's metadata is %v, want %v", md, want)
}
}
// Old/slow approach to adding metadata to context
func Benchmark_AddingMetadata_ContextManipulationApproach(b *testing.B) {
for n := 0; n < b.N; n++ {
ctx := context.Background()
md, _ := FromOutgoingContext(ctx)
NewOutgoingContext(ctx, Join(Pairs("k1", "v1", "k2", "v2"), md))
}
}
// Newer/faster approach to adding metadata to context
func BenchmarkAppendToOutgoingContext(b *testing.B) {
for n := 0; n < b.N; n++ {
AppendToOutgoingContext(context.Background(), "k1", "v1", "k2", "v2")
}
}
func BenchmarkFromOutgoingContext(b *testing.B) {
ctx := context.Background()
ctx = NewOutgoingContext(ctx, MD{"k3": {"v3", "v4"}})
ctx = AppendToOutgoingContext(ctx, "k1", "v1", "k2", "v2")
for n := 0; n < b.N; n++ {
FromOutgoingContext(ctx)
}
}

View File

@ -90,9 +90,6 @@ type Address struct {
// BuildOption includes additional information for the builder to create
// the resolver.
type BuildOption struct {
// UserOptions can be used to pass configuration between DialOptions and the
// resolver.
UserOptions interface{}
}
// ClientConn contains the callbacks for resolver to notify any updates

View File

@ -83,9 +83,7 @@ func newCCResolverWrapper(cc *ClientConn) (*ccResolverWrapper, error) {
}
var err error
ccr.resolver, err = rb.Build(cc.parsedTarget, ccr, resolver.BuildOption{
UserOptions: cc.dopts.resolverBuildUserOptions,
})
ccr.resolver, err = rb.Build(cc.parsedTarget, ccr, resolver.BuildOption{})
if err != nil {
return nil, err
}

View File

@ -1,99 +0,0 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"fmt"
"strings"
"testing"
"time"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
"google.golang.org/grpc/test/leakcheck"
)
func TestResolverServiceConfigBeforeAddressNotPanic(t *testing.T) {
defer leakcheck.Check(t)
r, rcleanup := manual.GenerateAndRegisterManualResolver()
defer rcleanup()
cc, err := Dial(r.Scheme()+":///test.server", WithInsecure())
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer cc.Close()
// SwitchBalancer before NewAddress. No balancer has been created at this
// point, so this makes sure we don't call close on a nil balancerWrapper.
r.NewServiceConfig(`{"loadBalancingPolicy": "round_robin"}`) // This should not panic.
time.Sleep(time.Second) // Sleep to make sure the service config is handled by ClientConn.
}
func TestResolverEmptyUpdateNotPanic(t *testing.T) {
defer leakcheck.Check(t)
r, rcleanup := manual.GenerateAndRegisterManualResolver()
defer rcleanup()
cc, err := Dial(r.Scheme()+":///test.server", WithInsecure())
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer cc.Close()
// This makes sure we don't create an addrConn with an empty address list.
r.NewAddress([]resolver.Address{}) // This should not panic.
time.Sleep(time.Second) // Sleep to make sure the service config is handled by ClientConn.
}
var (
errTestResolverFailBuild = fmt.Errorf("test resolver build error")
)
type testResolverFailBuilder struct {
buildOpt resolver.BuildOption
}
func (r *testResolverFailBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOption) (resolver.Resolver, error) {
r.buildOpt = opts
return nil, errTestResolverFailBuild
}
func (r *testResolverFailBuilder) Scheme() string {
return "testResolverFailBuilderScheme"
}
// Tests that options in WithResolverUserOptions are passed to resolver.Build().
func TestResolverUserOptions(t *testing.T) {
r := &testResolverFailBuilder{}
userOpt := "testUserOpt"
_, err := Dial("scheme:///test.server", WithInsecure(),
withResolverBuilder(r),
WithResolverUserOptions(userOpt),
)
if err == nil || !strings.Contains(err.Error(), errTestResolverFailBuild.Error()) {
t.Fatalf("Dial with testResolverFailBuilder returns err: %v, want: %v", err, errTestResolverFailBuild)
}
if r.buildOpt.UserOptions != userOpt {
t.Fatalf("buildOpt.UserOptions = %T %+v, want %v", r.buildOpt.UserOptions, r.buildOpt.UserOptions, userOpt)
}
}

View File

@ -25,6 +25,7 @@ import (
"io"
"io/ioutil"
"math"
"strings"
"sync"
"time"
@ -32,6 +33,7 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/encoding"
"google.golang.org/grpc/encoding/proto"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/stats"
@ -125,13 +127,13 @@ func (d *gzipDecompressor) Type() string {
type callInfo struct {
compressorType string
failFast bool
headerMD metadata.MD
trailerMD metadata.MD
peer *peer.Peer
stream *transport.Stream
traceInfo traceInfo // in trace.go
maxReceiveMessageSize *int
maxSendMessageSize *int
creds credentials.PerRPCCredentials
contentSubtype string
codec baseCodec
}
func defaultCallInfo() *callInfo {
@ -172,7 +174,9 @@ func (o afterCall) after(c *callInfo) { o(c) }
// for a unary RPC.
func Header(md *metadata.MD) CallOption {
return afterCall(func(c *callInfo) {
*md = c.headerMD
if c.stream != nil {
*md, _ = c.stream.Header()
}
})
}
@ -180,16 +184,20 @@ func Header(md *metadata.MD) CallOption {
// for a unary RPC.
func Trailer(md *metadata.MD) CallOption {
return afterCall(func(c *callInfo) {
*md = c.trailerMD
if c.stream != nil {
*md = c.stream.Trailer()
}
})
}
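As a sketch of how these options are consumed, a unary call can capture both header and trailer metadata; the method path "/pkg.Service/Method" below is a hypothetical placeholder.

```go
package example

import (
	"golang.org/x/net/context"

	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"
)

// callWithMD issues a unary RPC and captures its header and trailer
// metadata via the Header and Trailer CallOptions.
func callWithMD(ctx context.Context, cc *grpc.ClientConn, req, reply interface{}) (metadata.MD, metadata.MD, error) {
	var header, trailer metadata.MD
	err := grpc.Invoke(ctx, "/pkg.Service/Method", req, reply, cc,
		grpc.Header(&header), grpc.Trailer(&trailer))
	return header, trailer, err
}
```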
// Peer returns a CallOption that retrieves peer information for a
// unary RPC.
func Peer(peer *peer.Peer) CallOption {
func Peer(p *peer.Peer) CallOption {
return afterCall(func(c *callInfo) {
if c.peer != nil {
*peer = *c.peer
if c.stream != nil {
if x, ok := peer.FromContext(c.stream.Context()); ok {
*p = *x
}
}
})
}
@ -248,6 +256,49 @@ func UseCompressor(name string) CallOption {
})
}
// CallContentSubtype returns a CallOption that will set the content-subtype
// for a call. For example, if content-subtype is "json", the Content-Type over
// the wire will be "application/grpc+json". The content-subtype is converted
// to lowercase before being included in Content-Type. See Content-Type on
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details.
//
// If CallCustomCodec is not also used, the content-subtype will be used to
// look up the Codec to use in the registry controlled by RegisterCodec. See
// the documentation on RegisterCodec for details on registration. The lookup
// of content-subtype is case-insensitive. If no such Codec is found, the call
// will result in an error with code codes.Internal.
//
// If CallCustomCodec is also used, that Codec will be used for all request and
// response messages, with the content-subtype set to the given contentSubtype
// here for requests.
func CallContentSubtype(contentSubtype string) CallOption {
contentSubtype = strings.ToLower(contentSubtype)
return beforeCall(func(c *callInfo) error {
c.contentSubtype = contentSubtype
return nil
})
}
// CallCustomCodec returns a CallOption that will set the given Codec to be
// used for all request and response messages for a call. The result of calling
// String() will be used as the content-subtype in a case-insensitive manner.
//
// See Content-Type on
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details. Also see the documentation on RegisterCodec and
// CallContentSubtype for more details on the interaction between Codec and
// content-subtype.
//
// This function is provided for advanced users; prefer to use only
// CallContentSubtype to select a registered codec instead.
func CallCustomCodec(codec Codec) CallOption {
return beforeCall(func(c *callInfo) error {
c.codec = codec
return nil
})
}
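A hedged sketch of this advanced path: a pass-through `Codec` for raw byte payloads, supplied per call instead of through the registry. The `rawCodec` type and its name are hypothetical, not an API of this package.

```go
package example

import "google.golang.org/grpc"

// rawCodec is a hypothetical Codec that passes []byte payloads through
// unmodified. String() doubles as the content-subtype on the wire.
type rawCodec struct{}

func (rawCodec) Marshal(v interface{}) ([]byte, error) { return v.([]byte), nil }

func (rawCodec) Unmarshal(data []byte, v interface{}) error {
	*(v.(*[]byte)) = data
	return nil
}

func (rawCodec) String() string { return "raw" }

// callOpt supplies the codec for a single RPC, e.g.
// grpc.Invoke(ctx, method, req, reply, cc, callOpt).
var callOpt grpc.CallOption = grpc.CallCustomCodec(rawCodec{})
```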
// The format of the payload: compressed or not?
type payloadFormat uint8
@ -263,8 +314,8 @@ type parser struct {
// error types.
r io.Reader
// The header of a gRPC message. Find more detail
// at https://grpc.io/docs/guides/wire.html.
// The header of a gRPC message. Find more detail at
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md
header [5]byte
}
@ -313,7 +364,7 @@ func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byt
// encode serializes msg and returns a buffer of message header and a buffer of msg.
// If msg is nil, it generates the message header and an empty msg buffer.
// TODO(ddyihai): eliminate extra Compressor parameter.
func encode(c Codec, msg interface{}, cp Compressor, outPayload *stats.OutPayload, compressor encoding.Compressor) ([]byte, []byte, error) {
func encode(c baseCodec, msg interface{}, cp Compressor, outPayload *stats.OutPayload, compressor encoding.Compressor) ([]byte, []byte, error) {
var (
b []byte
cbuf *bytes.Buffer
@ -390,7 +441,7 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool
// For the two compressor parameters, both should not be set, but if they are,
// dc takes precedence over compressor.
// TODO(dfawley): wrap the old compressor/decompressor using the new API?
func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, inPayload *stats.InPayload, compressor encoding.Compressor) error {
func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, inPayload *stats.InPayload, compressor encoding.Compressor) error {
pf, d, err := p.recvMsg(maxReceiveMessageSize)
if err != nil {
return err
@ -485,6 +536,27 @@ func Errorf(c codes.Code, format string, a ...interface{}) error {
return status.Errorf(c, format, a...)
}
// setCallInfoCodec should only be called after CallOptions have been applied.
func setCallInfoCodec(c *callInfo) error {
if c.codec != nil {
// codec was already set by a CallOption; use it.
return nil
}
if c.contentSubtype == "" {
// No codec specified in CallOptions; use proto by default.
c.codec = encoding.GetCodec(proto.Name)
return nil
}
// c.contentSubtype is already lowercased in CallContentSubtype
c.codec = encoding.GetCodec(c.contentSubtype)
if c.codec == nil {
return status.Errorf(codes.Internal, "no codec registered for content-subtype %s", c.contentSubtype)
}
return nil
}
// The SupportPackageIsVersion variables are referenced from generated protocol
// buffer files to ensure compatibility with the gRPC version used. The latest
// support package version is 5.
@ -500,6 +572,6 @@ const (
)
// Version is the current grpc version.
const Version = "1.9.0"
const Version = "1.10.0"
const grpcUA = "grpc-go/" + Version

View File

@ -27,6 +27,8 @@ import (
"github.com/golang/protobuf/proto"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/encoding"
protoenc "google.golang.org/grpc/encoding/proto"
"google.golang.org/grpc/status"
perfpb "google.golang.org/grpc/test/codec_perf"
"google.golang.org/grpc/transport"
@ -110,7 +112,7 @@ func TestEncode(t *testing.T) {
}{
{nil, nil, []byte{0, 0, 0, 0, 0}, []byte{}, nil},
} {
hdr, data, err := encode(protoCodec{}, test.msg, nil, nil, nil)
hdr, data, err := encode(encoding.GetCodec(protoenc.Name), test.msg, nil, nil, nil)
if err != test.err || !bytes.Equal(hdr, test.hdr) || !bytes.Equal(data, test.data) {
t.Fatalf("encode(_, _, %v, _) = %v, %v, %v\nwant %v, %v, %v", test.cp, hdr, data, err, test.hdr, test.data, test.err)
}
@ -164,13 +166,14 @@ func TestToRPCErr(t *testing.T) {
// bmEncode benchmarks encoding a Protocol Buffer message containing mSize
// bytes.
func bmEncode(b *testing.B, mSize int) {
cdc := encoding.GetCodec(protoenc.Name)
msg := &perfpb.Buffer{Body: make([]byte, mSize)}
encodeHdr, encodeData, _ := encode(protoCodec{}, msg, nil, nil, nil)
encodeHdr, encodeData, _ := encode(cdc, msg, nil, nil, nil)
encodedSz := int64(len(encodeHdr) + len(encodeData))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
encode(protoCodec{}, msg, nil, nil, nil)
encode(cdc, msg, nil, nil, nil)
}
b.SetBytes(encodedSz)
}

View File

@ -40,6 +40,7 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/encoding"
"google.golang.org/grpc/encoding/proto"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/keepalive"
@ -105,7 +106,7 @@ type Server struct {
type options struct {
creds credentials.TransportCredentials
codec Codec
codec baseCodec
cp Compressor
dc Decompressor
unaryInt UnaryServerInterceptor
@ -182,6 +183,8 @@ func KeepaliveEnforcementPolicy(kep keepalive.EnforcementPolicy) ServerOption {
}
// CustomCodec returns a ServerOption that sets a codec for message marshaling and unmarshaling.
//
// This will override any lookups by content-subtype for Codecs registered with RegisterCodec.
func CustomCodec(codec Codec) ServerOption {
return func(o *options) {
o.codec = codec
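For contrast with per-call codecs, a brief sketch of installing one server-wide; the codec argument is assumed to implement the legacy `grpc.Codec` interface.

```go
package example

import "google.golang.org/grpc"

// newServerWithCodec installs a single codec for every RPC on the
// server, overriding any registry lookup by content-subtype.
func newServerWithCodec(c grpc.Codec) *grpc.Server {
	return grpc.NewServer(grpc.CustomCodec(c))
}
```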
@ -327,10 +330,6 @@ func NewServer(opt ...ServerOption) *Server {
for _, o := range opt {
o(&opts)
}
if opts.codec == nil {
// Set the default codec.
opts.codec = protoCodec{}
}
s := &Server{
lis: make(map[net.Listener]bool),
opts: opts,
@ -695,7 +694,7 @@ func (s *Server) serveUsingHandler(conn net.Conn) {
// available through grpc-go's HTTP/2 server, and it is currently EXPERIMENTAL
// and subject to change.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
st, err := transport.NewServerHandlerTransport(w, r)
st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandler)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
@ -759,7 +758,7 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str
if s.opts.statsHandler != nil {
outPayload = &stats.OutPayload{}
}
hdr, data, err := encode(s.opts.codec, msg, cp, outPayload, comp)
hdr, data, err := encode(s.getCodec(stream.ContentSubtype()), msg, cp, outPayload, comp)
if err != nil {
grpclog.Errorln("grpc: server failed to encode response: ", err)
return err
@ -904,7 +903,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
// java implementation.
return status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(req), s.opts.maxReceiveMessageSize)
}
if err := s.opts.codec.Unmarshal(req, v); err != nil {
if err := s.getCodec(stream.ContentSubtype()).Unmarshal(req, v); err != nil {
return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
}
if inPayload != nil {
@ -996,7 +995,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
t: t,
s: stream,
p: &parser{r: stream},
codec: s.opts.codec,
codec: s.getCodec(stream.ContentSubtype()),
maxReceiveMessageSize: s.opts.maxReceiveMessageSize,
maxSendMessageSize: s.opts.maxSendMessageSize,
trInfo: trInfo,
@ -1262,6 +1261,22 @@ func init() {
}
}
// getCodec expects contentSubtype to already be lowercased and never
// returns nil.
func (s *Server) getCodec(contentSubtype string) baseCodec {
if s.opts.codec != nil {
return s.opts.codec
}
if contentSubtype == "" {
return encoding.GetCodec(proto.Name)
}
codec := encoding.GetCodec(contentSubtype)
if codec == nil {
return encoding.GetCodec(proto.Name)
}
return codec
}
// SetHeader sets the header metadata.
// When called multiple times, all the provided metadata will be merged.
// All the metadata will be sent out when one of the following happens:

View File

@ -256,11 +256,10 @@ const (
)
type rpcConfig struct {
count int // Number of requests and responses for streaming RPCs.
success bool // Whether the RPC should succeed or return error.
failfast bool
callType rpcType // Type of RPC.
noLastRecv bool // Whether to call recv for io.EOF. When true, last recv won't be called. Only valid for streaming RPCs.
count int // Number of requests and responses for streaming RPCs.
success bool // Whether the RPC should succeed or return error.
failfast bool
callType rpcType // Type of RPC.
}
func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.SimpleResponse, error) {
@ -313,14 +312,8 @@ func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]*testpb.SimpleRequest
if err = stream.CloseSend(); err != nil && err != io.EOF {
return reqs, resps, err
}
if !c.noLastRecv {
if _, err = stream.Recv(); err != io.EOF {
return reqs, resps, err
}
} else {
// In the case of not calling the last recv, sleep to avoid
// returning too fast to miss the remaining stats (InTrailer and End).
time.Sleep(time.Second)
if _, err = stream.Recv(); err != io.EOF {
return reqs, resps, err
}
return reqs, resps, nil
@ -651,7 +644,7 @@ func checkEnd(t *testing.T, d *gotData, e *expectedData) {
actual, ok := status.FromError(st.Error)
if !ok {
t.Fatalf("expected st.Error to be a statusError, got %T", st.Error)
t.Fatalf("expected st.Error to be a statusError, got %v (type %T)", st.Error, st.Error)
}
expectedStatus, _ := status.FromError(e.err)
@ -1222,20 +1215,6 @@ func TestClientStatsFullDuplexRPCError(t *testing.T) {
})
}
// If the user doesn't call the last recv() on clientStream.
func TestClientStatsFullDuplexRPCNotCallingLastRecv(t *testing.T) {
count := 1
testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: fullDuplexStreamRPC, noLastRecv: true}, map[int]*checkFuncWithCount{
begin: {checkBegin, 1},
outHeader: {checkOutHeader, 1},
outPayload: {checkOutPayload, count},
inHeader: {checkInHeader, 1},
inPayload: {checkInPayload, count},
inTrailer: {checkInTrailer, 1},
end: {checkEnd, 1},
})
}
func TestTags(t *testing.T) {
b := []byte{5, 2, 4, 3, 1}
ctx := stats.SetTags(context.Background(), b)

View File

@ -120,7 +120,8 @@ func FromProto(s *spb.Status) *Status {
}
// FromError returns a Status representing err if it was produced from this
// package, otherwise it returns nil, false.
// package. Otherwise, ok is false and a Status is returned with codes.Unknown
// and the original error message.
func FromError(err error) (s *Status, ok bool) {
if err == nil {
return &Status{s: &spb.Status{Code: int32(codes.OK)}}, true
@ -128,7 +129,14 @@ func FromError(err error) (s *Status, ok bool) {
if se, ok := err.(*statusError); ok {
return se.status(), true
}
return nil, false
return New(codes.Unknown, err.Error()), false
}
// Convert is a convenience function which removes the need to handle the
// boolean return value from FromError.
func Convert(err error) *Status {
s, _ := FromError(err)
return s
}
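A small sketch of the pattern `Convert` enables (the `codeOf` helper is illustrative):

```go
package example

import (
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// codeOf classifies any error: status errors keep their code, nil maps
// to OK, and everything else becomes Unknown with its message preserved.
func codeOf(err error) codes.Code {
	return status.Convert(err).Code()
}
```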
// WithDetails returns a new status with the provided details messages appended to the status.

View File

@ -119,6 +119,33 @@ func TestFromErrorOK(t *testing.T) {
}
}
func TestFromErrorUnknownError(t *testing.T) {
code, message := codes.Unknown, "unknown error"
err := errors.New("unknown error")
s, ok := FromError(err)
if ok || s.Code() != code || s.Message() != message {
t.Fatalf("FromError(%v) = %v, %v; want <Code()=%s, Message()=%q>, false", err, s, ok, code, message)
}
}
func TestConvertKnownError(t *testing.T) {
code, message := codes.Internal, "test description"
err := Error(code, message)
s := Convert(err)
if s.Code() != code || s.Message() != message {
t.Fatalf("Convert(%v) = %v; want <Code()=%s, Message()=%q>", err, s, code, message)
}
}
func TestConvertUnknownError(t *testing.T) {
code, message := codes.Unknown, "unknown error"
err := errors.New("unknown error")
s := Convert(err)
if s.Code() != code || s.Message() != message {
t.Fatalf("Convert(%v) = %v; want <Code()=%s, Message()=%q>", err, s, code, message)
}
}
func TestStatus_ErrorDetails(t *testing.T) {
tests := []struct {
code codes.Code

View File

@ -30,7 +30,6 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/encoding"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
"google.golang.org/grpc/transport"
@ -51,6 +50,8 @@ type StreamDesc struct {
}
// Stream defines the common interface a client or server stream has to satisfy.
//
// All errors returned from Stream are compatible with the status package.
type Stream interface {
// Context returns the context for this stream.
Context() context.Context
@ -89,8 +90,9 @@ type ClientStream interface {
// Stream.SendMsg() may return a non-nil error when something wrong happens sending
// the request. The returned error indicates the status of this sending, not the final
// status of the RPC.
// Always call Stream.RecvMsg() to get the final status if you care about the status of
// the RPC.
//
// Always call Stream.RecvMsg() to drain the stream and get the final
// status, otherwise there could be leaked resources.
Stream
}
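A sketch of the drain loop this comment prescribes; `recvMsg` stands in for a generated stream's `Recv` (which wraps `RecvMsg`).

```go
package example

import "io"

// drain reads until the stream terminates, guaranteeing the final
// status is observed and the stream's resources (including its
// context) are released.
func drain(recvMsg func() (interface{}, error), use func(interface{})) error {
	for {
		msg, err := recvMsg()
		if err == io.EOF {
			return nil // clean end of stream
		}
		if err != nil {
			return err // the RPC's final status
		}
		use(msg)
	}
}
```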
@ -112,26 +114,28 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
}
func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
var (
t transport.ClientTransport
s *transport.Stream
done func(balancer.DoneInfo)
cancel context.CancelFunc
)
c := defaultCallInfo()
mc := cc.GetMethodConfig(method)
if mc.WaitForReady != nil {
c.failFast = !*mc.WaitForReady
}
// Possible context leak:
// The cancel function for the child context we create will only be called
// when RecvMsg returns a non-nil error, if the ClientConn is closed, or if
// an error is generated by SendMsg.
// https://github.com/grpc/grpc-go/issues/1818.
var cancel context.CancelFunc
if mc.Timeout != nil && *mc.Timeout >= 0 {
ctx, cancel = context.WithTimeout(ctx, *mc.Timeout)
defer func() {
if err != nil {
cancel()
}
}()
} else {
ctx, cancel = context.WithCancel(ctx)
}
defer func() {
if err != nil {
cancel()
}
}()
opts = append(cc.dopts.callOptions, opts...)
for _, o := range opts {
@ -141,6 +145,9 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
}
c.maxSendMessageSize = getMaxSize(mc.MaxReqSize, c.maxSendMessageSize, defaultClientMaxSendMessageSize)
c.maxReceiveMessageSize = getMaxSize(mc.MaxRespSize, c.maxReceiveMessageSize, defaultClientMaxReceiveMessageSize)
if err := setCallInfoCodec(c); err != nil {
return nil, err
}
callHdr := &transport.CallHdr{
Host: cc.authority,
@ -149,7 +156,8 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
// so we don't flush the header.
// If it's client streaming, the user may never send a request or send it any
// time soon, so we ask the transport to flush the header.
Flush: desc.ClientStreams,
Flush: desc.ClientStreams,
ContentSubtype: c.contentSubtype,
}
// Set our outgoing compression according to the UseCompressor CallOption, if
@ -214,6 +222,11 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
}()
}
var (
t transport.ClientTransport
s *transport.Stream
done func(balancer.DoneInfo)
)
for {
// Check to make sure the context has expired. This will prevent us from
// looping forever if an error occurs for wait-for-ready RPCs where no data
@ -232,14 +245,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
s, err = t.NewStream(ctx, callHdr)
if err != nil {
if done != nil {
doneInfo := balancer.DoneInfo{Err: err}
if _, ok := err.(transport.ConnectionError); ok {
// If error is connection error, transport was sending data on wire,
// and we are not sure if anything has been sent on wire.
// If error is not connection error, we are sure nothing has been sent.
doneInfo.BytesSent = true
}
done(doneInfo)
done(balancer.DoneInfo{Err: err})
done = nil
}
// In the event of any error from NewStream, we never attempted to write
@ -253,15 +259,12 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
break
}
// Set callInfo.peer object from stream's context.
if peer, ok := peer.FromContext(s.Context()); ok {
c.peer = peer
}
c.stream = s
cs := &clientStream{
opts: opts,
c: c,
desc: desc,
codec: cc.dopts.codec,
codec: c.codec,
cp: cp,
dc: cc.dopts.dc,
comp: comp,
@ -278,29 +281,21 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
statsCtx: ctx,
statsHandler: cc.dopts.copts.StatsHandler,
}
// Listen on s.Context().Done() to detect cancellation and s.Done() to detect
// normal termination when there is no pending I/O operations on this stream.
go func() {
select {
case <-t.Error():
// Incur transport error, simply exit.
case <-cc.ctx.Done():
cs.finish(ErrClientConnClosing)
cs.closeTransportStream(ErrClientConnClosing)
case <-s.Done():
// TODO: The trace of the RPC is terminated here when there is no pending
// I/O, which is probably not the optimal solution.
cs.finish(s.Status().Err())
cs.closeTransportStream(nil)
case <-s.GoAway():
cs.finish(errConnDrain)
cs.closeTransportStream(errConnDrain)
case <-s.Context().Done():
err := s.Context().Err()
cs.finish(err)
cs.closeTransportStream(transport.ContextErr(err))
}
}()
if desc != unaryStreamDesc {
// Listen on cc and stream contexts to cleanup when the user closes the
// ClientConn or cancels the stream context. In all other cases, an error
// should already be injected into the recv buffer by the transport, which
// the client will eventually receive, and then we will cancel the stream's
// context in clientStream.finish.
go func() {
select {
case <-cc.ctx.Done():
cs.finish(ErrClientConnClosing)
case <-ctx.Done():
cs.finish(toRPCErr(s.Context().Err()))
}
}()
}
return cs, nil
}
@ -313,20 +308,22 @@ type clientStream struct {
p *parser
desc *StreamDesc
codec Codec
codec baseCodec
cp Compressor
dc Decompressor
comp encoding.Compressor
decomp encoding.Compressor
decompSet bool
// cancel is only called when RecvMsg() returns non-nil error, which means
// the stream finishes with error or with io.EOF.
cancel context.CancelFunc
tracing bool // set to EnableTracing when the clientStream is created.
mu sync.Mutex
done func(balancer.DoneInfo)
closed bool
sentLast bool // sent an end stream
finished bool
// trInfo.tr is set when the clientStream is created (if EnableTracing is true),
// and is set to nil when the clientStream's finish method is called.
@ -346,9 +343,8 @@ func (cs *clientStream) Context() context.Context {
func (cs *clientStream) Header() (metadata.MD, error) {
m, err := cs.s.Header()
if err != nil {
if _, ok := err.(transport.ConnectionError); !ok {
cs.closeTransportStream(err)
}
err = toRPCErr(err)
cs.finish(err)
}
return m, err
}
@ -358,6 +354,7 @@ func (cs *clientStream) Trailer() metadata.MD {
}
func (cs *clientStream) SendMsg(m interface{}) (err error) {
// TODO: Check cs.sentLast and error if we already ended the stream.
if cs.tracing {
cs.mu.Lock()
if cs.trInfo.tr != nil {
@ -368,26 +365,18 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
// TODO Investigate how to signal the stats handling party.
// generate error stats if err != nil && err != io.EOF?
defer func() {
if err != nil {
// For non-client-streaming RPCs, we return nil instead of EOF on success
// because the generated code requires it. finish is not called; RecvMsg()
// will call it with the stream's status independently.
if err == io.EOF && !cs.desc.ClientStreams {
err = nil
}
if err != nil && err != io.EOF {
// Call finish for errors generated by this SendMsg call. (Transport
// errors are converted to an io.EOF error below; the real error will be
// returned from RecvMsg eventually in that case.)
cs.finish(err)
}
if err == nil {
return
}
if err == io.EOF {
// Specialize the process for server streaming. SendMsg is only called
// once when creating the stream object. io.EOF needs to be skipped when
// the rpc is early finished (before the stream object is created.).
// TODO: It is probably better to move this into the generated code.
if !cs.desc.ClientStreams && cs.desc.ServerStreams {
err = nil
}
return
}
if _, ok := err.(transport.ConnectionError); !ok {
cs.closeTransportStream(err)
}
err = toRPCErr(err)
}()
var outPayload *stats.OutPayload
if cs.statsHandler != nil {
@ -399,30 +388,36 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
if err != nil {
return err
}
if cs.c.maxSendMessageSize == nil {
return status.Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)")
}
if len(data) > *cs.c.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), *cs.c.maxSendMessageSize)
}
err = cs.t.Write(cs.s, hdr, data, &transport.Options{Last: false})
if err == nil && outPayload != nil {
outPayload.SentTime = time.Now()
cs.statsHandler.HandleRPC(cs.statsCtx, outPayload)
if !cs.desc.ClientStreams {
cs.sentLast = true
}
return err
err = cs.t.Write(cs.s, hdr, data, &transport.Options{Last: !cs.desc.ClientStreams})
if err == nil {
if outPayload != nil {
outPayload.SentTime = time.Now()
cs.statsHandler.HandleRPC(cs.statsCtx, outPayload)
}
return nil
}
return io.EOF
}
func (cs *clientStream) RecvMsg(m interface{}) (err error) {
defer func() {
if err != nil || !cs.desc.ServerStreams {
// err != nil or non-server-streaming indicates end of stream.
cs.finish(err)
}
}()
var inPayload *stats.InPayload
if cs.statsHandler != nil {
inPayload = &stats.InPayload{
Client: true,
}
}
if cs.c.maxReceiveMessageSize == nil {
return status.Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)")
}
if !cs.decompSet {
// Block until we receive headers containing received message encoding.
if ct := cs.s.RecvCompress(); ct != "" && ct != encoding.Identity {
@ -440,98 +435,67 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
cs.decompSet = true
}
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, inPayload, cs.decomp)
defer func() {
// err != nil indicates the termination of the stream.
if err != nil {
cs.finish(err)
}
}()
if err == nil {
if cs.tracing {
cs.mu.Lock()
if cs.trInfo.tr != nil {
cs.trInfo.tr.LazyLog(&payload{sent: false, msg: m}, true)
}
cs.mu.Unlock()
}
if inPayload != nil {
cs.statsHandler.HandleRPC(cs.statsCtx, inPayload)
}
if !cs.desc.ClientStreams || cs.desc.ServerStreams {
return
}
// Special handling for client streaming rpc.
// This recv expects EOF or errors, so we don't collect inPayload.
if cs.c.maxReceiveMessageSize == nil {
return status.Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)")
}
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, nil, cs.decomp)
cs.closeTransportStream(err)
if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
}
if err != nil {
if err == io.EOF {
if se := cs.s.Status().Err(); se != nil {
return se
if statusErr := cs.s.Status().Err(); statusErr != nil {
return statusErr
}
cs.finish(err)
return nil
return io.EOF // indicates successful end of stream.
}
return toRPCErr(err)
}
if _, ok := err.(transport.ConnectionError); !ok {
cs.closeTransportStream(err)
if cs.tracing {
cs.mu.Lock()
if cs.trInfo.tr != nil {
cs.trInfo.tr.LazyLog(&payload{sent: false, msg: m}, true)
}
cs.mu.Unlock()
}
if inPayload != nil {
cs.statsHandler.HandleRPC(cs.statsCtx, inPayload)
}
if cs.desc.ServerStreams {
// Subsequent messages should be received by subsequent RecvMsg calls.
return nil
}
// Special handling for non-server-streaming RPCs.
// This recv expects EOF or errors, so we don't collect inPayload.
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, nil, cs.decomp)
if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
}
if err == io.EOF {
if statusErr := cs.s.Status().Err(); statusErr != nil {
return statusErr
}
// Returns io.EOF to indicate the end of the stream.
return
return cs.s.Status().Err() // non-server streaming Recv returns nil on success
}
return toRPCErr(err)
}
func (cs *clientStream) CloseSend() (err error) {
err = cs.t.Write(cs.s, nil, nil, &transport.Options{Last: true})
defer func() {
if err != nil {
cs.finish(err)
}
}()
if err == nil || err == io.EOF {
func (cs *clientStream) CloseSend() error {
if cs.sentLast {
return nil
}
if _, ok := err.(transport.ConnectionError); !ok {
cs.closeTransportStream(err)
}
err = toRPCErr(err)
return
}
func (cs *clientStream) closeTransportStream(err error) {
cs.mu.Lock()
if cs.closed {
cs.mu.Unlock()
return
}
cs.closed = true
cs.mu.Unlock()
cs.t.CloseStream(cs.s, err)
cs.sentLast = true
cs.t.Write(cs.s, nil, nil, &transport.Options{Last: true})
// We ignore errors from Write and always return nil here. Any error it
// would return would also be returned by a subsequent RecvMsg call, and the
// user is supposed to always finish the stream by calling RecvMsg until it
// returns err != nil.
return nil
}
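In practice this means `CloseSend`'s return value carries no information; a sketch of ending the send side on a raw `grpc.ClientStream` (the helper name is illustrative):

```go
package example

import "google.golang.org/grpc"

// finishSend half-closes the stream and then receives the final reply;
// any error swallowed by CloseSend resurfaces from RecvMsg.
func finishSend(stream grpc.ClientStream, reply interface{}) error {
	_ = stream.CloseSend() // always nil per the implementation above
	return stream.RecvMsg(reply)
}
```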
func (cs *clientStream) finish(err error) {
if err == io.EOF {
// Ending a stream with EOF indicates a success.
err = nil
}
cs.mu.Lock()
defer cs.mu.Unlock()
if cs.finished {
return
}
cs.finished = true
defer func() {
if cs.cancel != nil {
cs.cancel()
}
}()
cs.t.CloseStream(cs.s, err)
for _, o := range cs.opts {
o.after(cs.c)
}
@ -547,18 +511,16 @@ func (cs *clientStream) finish(err error) {
end := &stats.End{
Client: true,
EndTime: time.Now(),
}
if err != io.EOF {
// end.Error is nil if the RPC finished successfully.
end.Error = toRPCErr(err)
Error: err,
}
cs.statsHandler.HandleRPC(cs.statsCtx, end)
}
cs.cancel()
if !cs.tracing {
return
}
if cs.trInfo.tr != nil {
if err == nil || err == io.EOF {
if err == nil {
cs.trInfo.tr.LazyPrintf("RPC: [OK]")
} else {
cs.trInfo.tr.LazyPrintf("RPC: [%v]", err)
@ -593,7 +555,7 @@ type serverStream struct {
t transport.ServerTransport
s *transport.Stream
p *parser
codec Codec
codec baseCodec
cp Compressor
dc Decompressor

View File

@ -705,7 +705,7 @@ func (te *test) clientConn() *grpc.ClientConn {
opts = append(opts, grpc.WithPerRPCCredentials(te.perRPCCreds))
}
if te.customCodec != nil {
opts = append(opts, grpc.WithCodec(te.customCodec))
opts = append(opts, grpc.WithDefaultCallOptions(grpc.CallCustomCodec(te.customCodec)))
}
if !te.nonBlockingDial && te.srvAddr != "" {
// Only do a blocking dial if server is up.
@ -925,7 +925,7 @@ func testServerGoAwayPendingRPC(t *testing.T, e env) {
cc := te.clientConn()
tc := testpb.NewTestServiceClient(cc)
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
stream, err := tc.FullDuplexCall(ctx, grpc.FailFast(false))
if err != nil {
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
@ -1164,10 +1164,22 @@ func testConcurrentServerStopAndGoAway(t *testing.T, e env) {
ResponseParameters: respParam,
Payload: payload,
}
if err := stream.Send(req); err == nil {
if _, err := stream.Recv(); err == nil {
t.Fatalf("%v.Recv() = _, %v, want _, <nil>", stream, err)
sendStart := time.Now()
for {
if err := stream.Send(req); err == io.EOF {
// stream.Send should eventually return io.EOF
break
} else if err != nil {
// Send should never return a transport-level error.
t.Fatalf("stream.Send(%v) = %v; want <nil or io.EOF>", req, err)
}
if time.Since(sendStart) > 2*time.Second {
t.Fatalf("stream.Send(_) did not return io.EOF after 2s")
}
time.Sleep(time.Millisecond)
}
if _, err := stream.Recv(); err == nil || err == io.EOF {
t.Fatalf("%v.Recv() = _, %v, want _, <non-nil, non-EOF>", stream, err)
}
<-ch
awaitNewConnLogOutput()
@ -1190,7 +1202,9 @@ func testClientConnCloseAfterGoAwayWithActiveStream(t *testing.T, e env) {
cc := te.clientConn()
tc := testpb.NewTestServiceClient(cc)
if _, err := tc.FullDuplexCall(context.Background()); err != nil {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if _, err := tc.FullDuplexCall(ctx); err != nil {
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want _, <nil>", tc, err)
}
done := make(chan struct{})
@ -1808,6 +1822,80 @@ func TestServiceConfigMaxMsgSize(t *testing.T) {
}
}
// Reading from a streaming RPC may fail with context canceled if timeout was
// set by service config (https://github.com/grpc/grpc-go/issues/1818). This
// test makes sure reads from a streaming RPC don't fail in this case.
func TestStreamingRPCWithTimeoutInServiceConfigRecv(t *testing.T) {
te := testServiceConfigSetup(t, tcpClearRREnv)
te.startServer(&testServer{security: tcpClearRREnv.security})
defer te.tearDown()
r, rcleanup := manual.GenerateAndRegisterManualResolver()
defer rcleanup()
te.resolverScheme = r.Scheme()
te.nonBlockingDial = true
fmt.Println("1")
cc := te.clientConn()
fmt.Println("10")
tc := testpb.NewTestServiceClient(cc)
r.NewAddress([]resolver.Address{{Addr: te.srvAddr}})
r.NewServiceConfig(`{
"methodConfig": [
{
"name": [
{
"service": "grpc.testing.TestService",
"method": "FullDuplexCall"
}
],
"waitForReady": true,
"timeout": "10s"
}
]
}`)
// Make sure service config has been processed by grpc.
for {
if cc.GetMethodConfig("/grpc.testing.TestService/FullDuplexCall").Timeout != nil {
break
}
time.Sleep(time.Millisecond)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
stream, err := tc.FullDuplexCall(ctx, grpc.FailFast(false))
if err != nil {
t.Fatalf("TestService/FullDuplexCall(_) = _, %v, want <nil>", err)
}
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, 0)
if err != nil {
t.Fatalf("failed to newPayload: %v", err)
}
req := &testpb.StreamingOutputCallRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE,
ResponseParameters: []*testpb.ResponseParameters{{Size: 0}},
Payload: payload,
}
if err := stream.Send(req); err != nil {
t.Fatalf("stream.Send(%v) = %v, want <nil>", req, err)
}
stream.CloseSend()
// Sleep one second before the recv so the final status arrives first.
time.Sleep(time.Second)
if _, err := stream.Recv(); err != nil {
t.Fatalf("stream.Recv = _, %v, want _, <nil>", err)
}
// Keep reading to drain the stream.
for {
if _, err := stream.Recv(); err != nil {
break
}
}
}
func TestMaxMsgSizeClientDefault(t *testing.T) {
defer leakcheck.Check(t)
for _, e := range listTestEnv() {
@ -2260,24 +2348,6 @@ func testHealthCheckServingStatus(t *testing.T, e env) {
}
func TestErrorChanNoIO(t *testing.T) {
defer leakcheck.Check(t)
for _, e := range listTestEnv() {
testErrorChanNoIO(t, e)
}
}
func testErrorChanNoIO(t *testing.T, e env) {
te := newTest(t, e)
te.startServer(&testServer{security: e.security})
defer te.tearDown()
tc := testpb.NewTestServiceClient(te.clientConn())
if _, err := tc.FullDuplexCall(context.Background()); err != nil {
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
}
}
func TestEmptyUnaryWithUserAgent(t *testing.T) {
defer leakcheck.Check(t)
for _, e := range listTestEnv() {
@ -2607,6 +2677,7 @@ func testMetadataUnaryRPC(t *testing.T, e env) {
delete(header, "trailer") // RFC 2616 says server SHOULD (but optional) declare trailers
delete(header, "date") // the Date header is also optional
delete(header, "user-agent")
delete(header, "content-type")
}
if !reflect.DeepEqual(header, testMetadata) {
t.Fatalf("Received header metadata %v, want %v", header, testMetadata)
@ -2723,6 +2794,7 @@ func testSetAndSendHeaderUnaryRPC(t *testing.T, e env) {
t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <nil>", ctx, err)
}
delete(header, "user-agent")
delete(header, "content-type")
expectedHeader := metadata.Join(testMetadata, testMetadata2)
if !reflect.DeepEqual(header, expectedHeader) {
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
@ -2767,6 +2839,7 @@ func testMultipleSetHeaderUnaryRPC(t *testing.T, e env) {
t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <nil>", ctx, err)
}
delete(header, "user-agent")
delete(header, "content-type")
expectedHeader := metadata.Join(testMetadata, testMetadata2)
if !reflect.DeepEqual(header, expectedHeader) {
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
@ -2810,6 +2883,7 @@ func testMultipleSetHeaderUnaryRPCError(t *testing.T, e env) {
t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <non-nil>", ctx, err)
}
delete(header, "user-agent")
delete(header, "content-type")
expectedHeader := metadata.Join(testMetadata, testMetadata2)
if !reflect.DeepEqual(header, expectedHeader) {
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
@ -2854,6 +2928,7 @@ func testSetAndSendHeaderStreamingRPC(t *testing.T, e env) {
t.Fatalf("%v.Header() = _, %v, want _, <nil>", stream, err)
}
delete(header, "user-agent")
delete(header, "content-type")
expectedHeader := metadata.Join(testMetadata, testMetadata2)
if !reflect.DeepEqual(header, expectedHeader) {
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
@ -2917,6 +2992,7 @@ func testMultipleSetHeaderStreamingRPC(t *testing.T, e env) {
t.Fatalf("%v.Header() = _, %v, want _, <nil>", stream, err)
}
delete(header, "user-agent")
delete(header, "content-type")
expectedHeader := metadata.Join(testMetadata, testMetadata2)
if !reflect.DeepEqual(header, expectedHeader) {
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
@ -2975,6 +3051,7 @@ func testMultipleSetHeaderStreamingRPCError(t *testing.T, e env) {
t.Fatalf("%v.Header() = _, %v, want _, <nil>", stream, err)
}
delete(header, "user-agent")
delete(header, "content-type")
expectedHeader := metadata.Join(testMetadata, testMetadata2)
if !reflect.DeepEqual(header, expectedHeader) {
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
@ -3335,6 +3412,7 @@ func testMetadataStreamingRPC(t *testing.T, e env) {
}
delete(headerMD, "trailer") // ignore if present
delete(headerMD, "user-agent")
delete(headerMD, "content-type")
if err != nil || !reflect.DeepEqual(testMetadata, headerMD) {
t.Errorf("#1 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
}
@ -3342,6 +3420,7 @@ func testMetadataStreamingRPC(t *testing.T, e env) {
headerMD, err = stream.Header()
delete(headerMD, "trailer") // ignore if present
delete(headerMD, "user-agent")
delete(headerMD, "content-type")
if err != nil || !reflect.DeepEqual(testMetadata, headerMD) {
t.Errorf("#2 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
}
@ -3728,22 +3807,24 @@ func testStreamsQuotaRecovery(t *testing.T, e env) {
cc := te.clientConn()
tc := testpb.NewTestServiceClient(cc)
if _, err := tc.StreamingInputCall(context.Background()); err != nil {
t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, <nil>", tc, err)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if _, err := tc.StreamingInputCall(ctx); err != nil {
t.Fatalf("tc.StreamingInputCall(_) = _, %v, want _, <nil>", err)
}
// Loop until the new max stream setting is effective.
for {
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
_, err := tc.StreamingInputCall(ctx)
cancel()
if err == nil {
time.Sleep(50 * time.Millisecond)
time.Sleep(5 * time.Millisecond)
continue
}
if status.Code(err) == codes.DeadlineExceeded {
break
}
t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, %s", tc, err, codes.DeadlineExceeded)
t.Fatalf("tc.StreamingInputCall(_) = _, %v, want _, %s", err, codes.DeadlineExceeded)
}
var wg sync.WaitGroup
@ -3765,11 +3846,19 @@ func testStreamsQuotaRecovery(t *testing.T, e env) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
if _, err := tc.UnaryCall(ctx, req, grpc.FailFast(false)); status.Code(err) != codes.DeadlineExceeded {
t.Errorf("TestService/UnaryCall(_, _) = _, %v, want _, %s", err, codes.DeadlineExceeded)
t.Errorf("tc.UnaryCall(_, _) = _, %v, want _, %s", err, codes.DeadlineExceeded)
}
}()
}
wg.Wait()
cancel()
// A new stream should be allowed after canceling the first one.
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if _, err := tc.StreamingInputCall(ctx); err != nil {
t.Fatalf("tc.StreamingInputCall(_) = _, %v, want _, %v", err, nil)
}
}
func TestCompressServerHasNoSupport(t *testing.T) {
@ -3807,23 +3896,6 @@ func testCompressServerHasNoSupport(t *testing.T, e env) {
if err != nil {
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
}
respParam := []*testpb.ResponseParameters{
{
Size: 31415,
},
}
payload, err = newPayload(testpb.PayloadType_COMPRESSABLE, int32(31415))
if err != nil {
t.Fatal(err)
}
sreq := &testpb.StreamingOutputCallRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE,
ResponseParameters: respParam,
Payload: payload,
}
if err := stream.Send(sreq); err != nil {
t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
}
if _, err := stream.Recv(); err == nil || status.Code(err) != codes.Unimplemented {
t.Fatalf("%v.Recv() = %v, want error code %s", stream, err, codes.Unimplemented)
}
@ -4107,6 +4179,7 @@ type funcServer struct {
testpb.TestServiceServer
unaryCall func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error)
streamingInputCall func(stream testpb.TestService_StreamingInputCallServer) error
fullDuplexCall func(stream testpb.TestService_FullDuplexCallServer) error
}
func (s *funcServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
@ -4117,6 +4190,10 @@ func (s *funcServer) StreamingInputCall(stream testpb.TestService_StreamingInput
return s.streamingInputCall(stream)
}
func (s *funcServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
return s.fullDuplexCall(stream)
}
func TestClientRequestBodyErrorUnexpectedEOF(t *testing.T) {
defer leakcheck.Check(t)
for _, e := range listTestEnv() {
@ -4238,6 +4315,76 @@ func testClientRequestBodyErrorCancelStreamingInput(t *testing.T, e env) {
})
}
func TestClientResourceExhaustedCancelFullDuplex(t *testing.T) {
defer leakcheck.Check(t)
for _, e := range listTestEnv() {
if e.httpHandler {
// httpHandler writes won't be blocked on the flow control window.
continue
}
testClientResourceExhaustedCancelFullDuplex(t, e)
}
}
func testClientResourceExhaustedCancelFullDuplex(t *testing.T, e env) {
te := newTest(t, e)
recvErr := make(chan error, 1)
ts := &funcServer{fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error {
defer close(recvErr)
_, err := stream.Recv()
if err != nil {
return status.Errorf(codes.Internal, "stream.Recv() got error: %v, want <nil>", err)
}
// Create a payload; the send loop below repeats it until the flow
// control window fills.
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, 10)
if err != nil {
return err
}
resp := &testpb.StreamingOutputCallResponse{
Payload: payload,
}
ce := make(chan error)
go func() {
var err error
for {
if err = stream.Send(resp); err != nil {
break
}
}
ce <- err
}()
select {
case err = <-ce:
case <-time.After(10 * time.Second):
err = errors.New("10s timeout reached")
}
recvErr <- err
return err
}}
te.startServer(ts)
defer te.tearDown()
// set a low limit on receive message size to error with Resource Exhausted on
// client side when the server sends a large message.
te.maxClientReceiveMsgSize = newInt(10)
cc := te.clientConn()
tc := testpb.NewTestServiceClient(cc)
stream, err := tc.FullDuplexCall(context.Background())
if err != nil {
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
}
req := &testpb.StreamingOutputCallRequest{}
if err := stream.Send(req); err != nil {
t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, req, err)
}
if _, err := stream.Recv(); status.Code(err) != codes.ResourceExhausted {
t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.ResourceExhausted)
}
err = <-recvErr
if status.Code(err) != codes.Canceled {
t.Fatalf("server got error %v, want error code: %s", err, codes.Canceled)
}
}
type clientTimeoutCreds struct {
timeoutReturned bool
}
@ -4924,6 +5071,36 @@ func TestTapTimeout(t *testing.T) {
t.Fatalf("ss.client.EmptyCall(context.Background(), _) = %v, %v; want nil, <status with Code()=Canceled>", res, err)
}
}
}
func TestClientWriteFailsAfterServerClosesStream(t *testing.T) {
ss := &stubServer{
fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error {
return status.Errorf(codes.Internal, "")
},
}
sopts := []grpc.ServerOption{}
if err := ss.Start(sopts); err != nil {
t.Fatalf("Error starting endpoing server: %v", err)
}
defer ss.Stop()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
stream, err := ss.client.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("Error while creating stream: %v", err)
}
for {
if err := stream.Send(&testpb.StreamingOutputCallRequest{}); err == nil {
time.Sleep(5 * time.Millisecond)
} else if err == io.EOF {
break // Success.
} else {
t.Fatalf("stream.Send(_) = %v, want io.EOF", err)
}
}
}
type windowSizeConfig struct {
@ -5819,3 +5996,47 @@ func TestServeExitsWhenListenerClosed(t *testing.T) {
t.Fatalf("Serve did not return after %v", timeout)
}
}
func TestClientDoesntDeadlockWhileWritingErrornousLargeMessages(t *testing.T) {
defer leakcheck.Check(t)
for _, e := range listTestEnv() {
if e.httpHandler {
continue
}
testClientDoesntDeadlockWhileWritingErrornousLargeMessages(t, e)
}
}
func testClientDoesntDeadlockWhileWritingErrornousLargeMessages(t *testing.T, e env) {
te := newTest(t, e)
te.userAgent = testAppUA
smallSize := 1024
te.maxServerReceiveMsgSize = &smallSize
te.startServer(&testServer{security: e.security})
defer te.tearDown()
tc := testpb.NewTestServiceClient(te.clientConn())
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, 1048576)
if err != nil {
t.Fatal(err)
}
req := &testpb.SimpleRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE,
Payload: payload,
}
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 100; j++ {
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second*10))
defer cancel()
if _, err := tc.UnaryCall(ctx, req); status.Code(err) != codes.ResourceExhausted {
t.Errorf("TestService/UnaryCall(_,_) = _. %v, want code: %s", err, codes.ResourceExhausted)
return
}
}
}()
}
wg.Wait()
}

View File

@ -40,20 +40,24 @@ import (
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
)
// NewServerHandlerTransport returns a ServerTransport handling gRPC
// from inside an http.Handler. It requires that the http Server
// supports HTTP/2.
func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTransport, error) {
func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats stats.Handler) (ServerTransport, error) {
if r.ProtoMajor != 2 {
return nil, errors.New("gRPC requires HTTP/2")
}
if r.Method != "POST" {
return nil, errors.New("invalid gRPC request method")
}
if !validContentType(r.Header.Get("Content-Type")) {
contentType := r.Header.Get("Content-Type")
// TODO: do we assume contentType is lowercase? we did before
contentSubtype, validContentType := contentSubtype(contentType)
if !validContentType {
return nil, errors.New("invalid gRPC request content-type")
}
if _, ok := w.(http.Flusher); !ok {
@ -64,10 +68,13 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
}
st := &serverHandlerTransport{
rw: w,
req: r,
closedCh: make(chan struct{}),
writes: make(chan func()),
rw: w,
req: r,
closedCh: make(chan struct{}),
writes: make(chan func()),
contentType: contentType,
contentSubtype: contentSubtype,
stats: stats,
}
if v := r.Header.Get("grpc-timeout"); v != "" {
@ -79,7 +86,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
st.timeout = to
}
var metakv []string
metakv := []string{"content-type", contentType}
if r.Host != "" {
metakv = append(metakv, ":authority", r.Host)
}
@ -126,6 +133,14 @@ type serverHandlerTransport struct {
// block concurrent WriteStatus calls
// e.g. grpc/(*serverStream).SendMsg/RecvMsg
writeStatusMu sync.Mutex
// we just mirror the request content-type
contentType string
// we store both contentType and contentSubtype so we don't keep recreating them
// TODO make sure this is consistent across handler_server and http2_server
contentSubtype string
stats stats.Handler
}
func (ht *serverHandlerTransport) Close() error {
@ -219,6 +234,9 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro
})
if err == nil { // transport has not been closed
if ht.stats != nil {
ht.stats.HandleRPC(s.Context(), &stats.OutTrailer{})
}
ht.Close()
close(ht.writes)
}
@ -235,7 +253,7 @@ func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) {
h := ht.rw.Header()
h["Date"] = nil // suppress Date to make tests happy; TODO: restore
h.Set("Content-Type", "application/grpc")
h.Set("Content-Type", ht.contentType)
// Predeclare trailers we'll set later in WriteStatus (after the body).
// This is a SHOULD in the HTTP RFC, and the way you add (known)
@ -263,7 +281,7 @@ func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data []byte, opts
}
func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
return ht.do(func() {
err := ht.do(func() {
ht.writeCommonHeaders(s)
h := ht.rw.Header()
for k, vv := range md {
@ -279,6 +297,13 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
ht.rw.WriteHeader(200)
ht.rw.(http.Flusher).Flush()
})
if err == nil {
if ht.stats != nil {
ht.stats.HandleRPC(s.Context(), &stats.OutHeader{})
}
}
return err
}
func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), traceCtx func(context.Context, string) context.Context) {
@ -313,13 +338,14 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
req := ht.req
s := &Stream{
id: 0, // irrelevant
requestRead: func(int) {},
cancel: cancel,
buf: newRecvBuffer(),
st: ht,
method: req.URL.Path,
recvCompress: req.Header.Get("grpc-encoding"),
id: 0, // irrelevant
requestRead: func(int) {},
cancel: cancel,
buf: newRecvBuffer(),
st: ht,
method: req.URL.Path,
recvCompress: req.Header.Get("grpc-encoding"),
contentSubtype: ht.contentSubtype,
}
pr := &peer.Peer{
Addr: ht.RemoteAddr(),
@ -330,6 +356,15 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
ctx = peer.NewContext(ctx, pr)
s.ctx = newContextWithStream(ctx, s)
if ht.stats != nil {
s.ctx = ht.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
inHeader := &stats.InHeader{
FullMethod: s.method,
RemoteAddr: ht.RemoteAddr(),
Compression: s.recvCompress,
}
ht.stats.HandleRPC(s.ctx, inHeader)
}
s.trReader = &transportReader{
reader: &recvBufferReader{ctx: s.ctx, recv: s.buf},
windowHandler: func(int) {},

View File

@ -199,9 +199,10 @@ func TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
},
check: func(ht *serverHandlerTransport, tt *testCase) error {
want := metadata.MD{
"meta-bar": {"bar-val1", "bar-val2"},
"user-agent": {"x/y a/b"},
"meta-foo": {"foo-val"},
"meta-bar": {"bar-val1", "bar-val2"},
"user-agent": {"x/y a/b"},
"meta-foo": {"foo-val"},
"content-type": {"application/grpc"},
}
if !reflect.DeepEqual(ht.headerMD, want) {
@ -217,7 +218,7 @@ func TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
if tt.modrw != nil {
rw = tt.modrw(rw)
}
got, gotErr := NewServerHandlerTransport(rw, tt.req)
got, gotErr := NewServerHandlerTransport(rw, tt.req, nil)
if (gotErr != nil) != (tt.wantErr != "") || (gotErr != nil && gotErr.Error() != tt.wantErr) {
t.Errorf("%s: error = %v; want %q", tt.name, gotErr, tt.wantErr)
continue
@ -271,7 +272,7 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest {
Body: bodyr,
}
rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
ht, err := NewServerHandlerTransport(rw, req)
ht, err := NewServerHandlerTransport(rw, req, nil)
if err != nil {
t.Fatal(err)
}
@ -356,7 +357,7 @@ func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
Body: bodyr,
}
rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
ht, err := NewServerHandlerTransport(rw, req)
ht, err := NewServerHandlerTransport(rw, req, nil)
if err != nil {
t.Fatal(err)
}

View File

@ -314,15 +314,16 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
// TODO(zhaoq): Handle uint32 overflow of Stream.id.
s := &Stream{
id: t.nextID,
done: make(chan struct{}),
goAway: make(chan struct{}),
method: callHdr.Method,
sendCompress: callHdr.SendCompress,
buf: newRecvBuffer(),
fc: &inFlow{limit: uint32(t.initialWindowSize)},
sendQuotaPool: newQuotaPool(int(t.streamSendQuota)),
headerChan: make(chan struct{}),
id: t.nextID,
done: make(chan struct{}),
goAway: make(chan struct{}),
method: callHdr.Method,
sendCompress: callHdr.SendCompress,
buf: newRecvBuffer(),
fc: &inFlow{limit: uint32(t.initialWindowSize)},
sendQuotaPool: newQuotaPool(int(t.streamSendQuota)),
headerChan: make(chan struct{}),
contentSubtype: callHdr.ContentSubtype,
}
t.nextID += 2
s.requestRead = func(n int) {
@ -380,7 +381,11 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
for _, c := range t.creds {
data, err := c.GetRequestMetadata(ctx, audience)
if err != nil {
if _, ok := status.FromError(err); ok {
// The credentials implementation already produced a gRPC status error; return it as-is.
return nil, err
}
// Otherwise, surface the failure as an authentication error.
return nil, streamErrorf(codes.Unauthenticated, "transport: %v", err)
}
for k, v := range data {
// Capital header names are illegal in HTTP/2.
@ -434,7 +439,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
headerFields = append(headerFields, hpack.HeaderField{Name: ":scheme", Value: t.scheme})
headerFields = append(headerFields, hpack.HeaderField{Name: ":path", Value: callHdr.Method})
headerFields = append(headerFields, hpack.HeaderField{Name: ":authority", Value: callHdr.Host})
headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(callHdr.ContentSubtype)})
headerFields = append(headerFields, hpack.HeaderField{Name: "user-agent", Value: t.userAgent})
headerFields = append(headerFields, hpack.HeaderField{Name: "te", Value: "trailers"})
@ -459,7 +464,22 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
if b := stats.OutgoingTrace(ctx); b != nil {
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-trace-bin", Value: encodeBinHeader(b)})
}
if md, added, ok := metadata.FromOutgoingContextRaw(ctx); ok {
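// added holds metadata from metadata.AppendToOutgoingContext as alternating
// key/value pairs (e.g. []string{"k1", "v1", "k2", "v2"}), which is why even
// indices are treated as keys below.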
var k string
for _, vv := range added {
for i, v := range vv {
if i%2 == 0 {
k = v
continue
}
// HTTP doesn't allow you to set pseudo-headers after non-pseudo-headers have been set.
if isReservedHeader(k) {
continue
}
headerFields = append(headerFields, hpack.HeaderField{Name: strings.ToLower(k), Value: encodeMetadataHeader(k, v)})
}
}
for k, vv := range md {
// HTTP doesn't allow you to set pseudo-headers after non-pseudo-headers have been set.
if isReservedHeader(k) {
@ -576,7 +596,7 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
}
s.state = streamDone
s.mu.Unlock()
if err != nil && !rstStream {
rstStream = true
rstError = http2.ErrCodeCancel
}
@ -645,6 +665,8 @@ func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) e
select {
case <-s.ctx.Done():
return ContextErr(s.ctx.Err())
case <-s.done:
return io.EOF
case <-t.ctx.Done():
return ErrConnClosing
default:
@ -694,6 +716,8 @@ func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) e
}
ltq, _, err := t.localSendQuota.get(size, s.waiters)
if err != nil {
// Add the acquired quota back to transport.
t.sendQuotaPool.add(tq)
return err
}
// even if ltq is smaller than size we don't adjust size since
@ -1110,16 +1134,17 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
}()
s.mu.Lock()
if !s.headerDone {
if !endStream {
// Headers frame is not actually a trailers-only frame.
isHeader = true
s.recvCompress = state.encoding
if len(state.mdata) > 0 {
s.header = state.mdata
}
}
close(s.headerChan)
s.headerDone = true
}
if !endStream || s.state == streamDone {
s.mu.Unlock()

View File

@ -281,12 +281,13 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
buf := newRecvBuffer()
s := &Stream{
id: streamID,
st: t,
buf: buf,
fc: &inFlow{limit: uint32(t.initialWindowSize)},
recvCompress: state.encoding,
method: state.method,
contentSubtype: state.contentSubtype,
}
if frame.StreamEnded() {
@ -730,7 +731,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
// first and create a slice of that exact size.
headerFields := make([]hpack.HeaderField, 0, 2) // at least :status and content-type will be present, even if nothing else is.
headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"})
headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(s.contentSubtype)})
if s.sendCompress != "" {
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress})
}
@ -749,9 +750,9 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
endStream: false,
})
if t.stats != nil {
outHeader := &stats.OutHeader{
//WireLength: // TODO(mmukhi): Revisit this later, if needed.
}
// Note: WireLength is not set in outHeader.
// TODO(mmukhi): Revisit this later, if needed.
outHeader := &stats.OutHeader{}
t.stats.HandleRPC(s.Context(), outHeader)
}
return nil
@ -792,7 +793,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
headerFields := make([]hpack.HeaderField, 0, 2) // grpc-status and grpc-message will be present, even if nothing else is.
if !headersSent {
headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"})
headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(s.contentSubtype)})
}
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status", Value: strconv.Itoa(int(st.Code()))})
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())})
@ -842,10 +843,6 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) e
var writeHeaderFrame bool
s.mu.Lock()
if !s.headerOk {
writeHeaderFrame = true
}
@ -891,6 +888,8 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) e
}
ltq, _, err := t.localSendQuota.get(size, s.waiters)
if err != nil {
// Add the acquired quota back to transport.
t.sendQuotaPool.add(tq)
return err
}
// even if ltq is smaller than size we don't adjust size since,

View File

@ -46,6 +46,12 @@ const (
// http2IOBufSize specifies the buffer size for sending frames.
defaultWriteBufSize = 32 * 1024
defaultReadBufSize = 32 * 1024
// baseContentType is the base content-type for gRPC. This is a valid
// content-type on its own, but can also include a content-subtype such as
// "proto" as a suffix after "+" or ";". See
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests
// for more details.
baseContentType = "application/grpc"
)
var (
@ -111,9 +117,10 @@ type decodeState struct {
timeout time.Duration
method string
// key-value metadata map from the peer.
mdata map[string][]string
statsTags []byte
statsTrace []byte
contentSubtype string
}
// isReservedHeader checks whether hdr belongs to HTTP2 headers
@ -149,17 +156,44 @@ func isWhitelistedPseudoHeader(hdr string) bool {
}
}
// contentSubtype returns the content-subtype for the given content-type. The
// given content-type must be a valid content-type that starts with
// "application/grpc". A content-subtype will follow "application/grpc" after a
// "+" or ";". See
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details.
//
// If contentType is not a valid content-type for gRPC, the boolean
// will be false, otherwise true. If content-type == "application/grpc",
// "application/grpc+", or "application/grpc;", the boolean will be true,
// but no content-subtype will be returned.
//
// contentType is assumed to be lowercase already.
func contentSubtype(contentType string) (string, bool) {
if contentType == baseContentType {
return "", true
}
if !strings.HasPrefix(contentType, baseContentType) {
return "", false
}
// guaranteed since != baseContentType and has baseContentType prefix
switch contentType[len(baseContentType)] {
case '+', ';':
// this will return true for "application/grpc+" or "application/grpc;"
// which the previous validContentType function tested to be valid, so we
// just say that no content-subtype is specified in this case
return contentType[len(baseContentType)+1:], true
default:
return "", false
}
}
// contentSubtype is assumed to be lowercase
func contentType(contentSubtype string) string {
if contentSubtype == "" {
return baseContentType
}
return baseContentType + "+" + contentSubtype
}
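// Illustrative only, not part of this file: contentSubtype and contentType
// invert each other for lowercase subtypes:
//
//	contentSubtype("application/grpc")       // ("", true)
//	contentSubtype("application/grpc+proto") // ("proto", true)
//	contentSubtype("application/json")       // ("", false)
//	contentType("proto")                     // "application/grpc+proto"
//	contentType("")                          // "application/grpc"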
func (d *decodeState) status() *status.Status {
@ -247,9 +281,16 @@ func (d *decodeState) addMetadata(k, v string) {
func (d *decodeState) processHeaderField(f hpack.HeaderField) error {
switch f.Name {
case "content-type":
contentSubtype, validContentType := contentSubtype(f.Value)
if !validContentType {
return streamErrorf(codes.FailedPrecondition, "transport: received the unexpected content-type %q", f.Value)
}
d.contentSubtype = contentSubtype
// TODO: do we want to propagate the whole content-type in the metadata,
// or come up with a way to just propagate the content-subtype if it was set?
// ie {"content-type": "application/grpc+proto"} or {"content-subtype": "proto"}
// in the metadata?
d.addMetadata(f.Name, f.Value)
case "grpc-encoding":
d.encoding = f.Value
case "grpc-status":

View File

@ -72,24 +72,25 @@ func TestTimeoutDecode(t *testing.T) {
}
}
func TestContentSubtype(t *testing.T) {
tests := []struct {
contentType string
want string
wantValid bool
}{
{"application/grpc", true},
{"application/grpc+", true},
{"application/grpc+blah", true},
{"application/grpc;", true},
{"application/grpc;blah", true},
{"application/grpcd", false},
{"application/grpd", false},
{"application/grp", false},
{"application/grpc", "", true},
{"application/grpc+", "", true},
{"application/grpc+blah", "blah", true},
{"application/grpc;", "", true},
{"application/grpc;blah", "blah", true},
{"application/grpcd", "", false},
{"application/grpd", "", false},
{"application/grp", "", false},
}
for _, tt := range tests {
got, gotValid := contentSubtype(tt.contentType)
if got != tt.want || gotValid != tt.wantValid {
t.Errorf("contentSubtype(%q) = (%v, %v); want (%v, %v)", tt.contentType, got, gotValid, tt.want, tt.wantValid)
}
}
}

View File

@ -246,6 +246,10 @@ type Stream struct {
bytesReceived bool // indicates whether any bytes have been received on this stream
unprocessed bool // set if the server sends a refused stream or GOAWAY including this stream
// contentSubtype is the content-subtype for requests.
// This must be lowercase, or the behavior is undefined.
contentSubtype string
}
func (s *Stream) waitOnHeader() error {
@ -321,6 +325,15 @@ func (s *Stream) ServerTransport() ServerTransport {
return s.st
}
// ContentSubtype returns the content-subtype for a request. For example, a
// content-subtype of "proto" will result in a content-type of
// "application/grpc+proto". This will always be lowercase. See
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details.
func (s *Stream) ContentSubtype() string {
return s.contentSubtype
}
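// Illustrative sketch, not part of this file: a server can use this accessor
// to look up a registered codec for the stream's subtype via the encoding
// package:
//
//	codec := encoding.GetCodec(stream.ContentSubtype())
//	if codec == nil {
//		// no codec is registered for this content-subtype
//	}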
// Context returns the context of the stream.
func (s *Stream) Context() context.Context {
return s.ctx
@ -553,6 +566,14 @@ type CallHdr struct {
// for performance purposes.
// If it's false, new stream will never be flushed.
Flush bool
// ContentSubtype specifies the content-subtype for a request. For example, a
// content-subtype of "proto" will result in a content-type of
// "application/grpc+proto". The value of ContentSubtype must be all
// lowercase, otherwise the behavior is undefined. See
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests
// for more details.
ContentSubtype string
}
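// Illustrative sketch, not part of this file (hypothetical values): a client
// selecting a custom codec would populate CallHdr roughly as:
//
//	hdr := &CallHdr{
//		Host:           "example.com",
//		Method:         "/my.Service/MyMethod",
//		ContentSubtype: "mycodec", // sent as content-type "application/grpc+mycodec"
//	}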
// ClientTransport is the common interface for all gRPC client-side transport
@ -676,13 +697,13 @@ func (e ConnectionError) Origin() error {
var (
// ErrConnClosing indicates that the transport is closing.
ErrConnClosing = connectionErrorf(true, nil, "transport is closing")
// errStreamDrain indicates that the stream is rejected because the
// connection is draining. This could be caused by goaway or balancer
// removing the address.
errStreamDrain = streamErrorf(codes.Unavailable, "the connection is draining")
// StatusGoAway indicates that the server sent a GOAWAY that included this
// stream's ID in unprocessed RPCs.
statusGoAway = status.New(codes.Unavailable, "the server stopped accepting new RPCs")
statusGoAway = status.New(codes.Unavailable, "the stream is rejected because server is draining the connection")
)
// TODO: See if we can replace StreamError with status package errors.