rebase: bump google.golang.org/grpc from 1.65.0 to 1.66.0

Bumps [google.golang.org/grpc](https://github.com/grpc/grpc-go) from 1.65.0 to 1.66.0.
- [Release notes](https://github.com/grpc/grpc-go/releases)
- [Commits](https://github.com/grpc/grpc-go/compare/v1.65.0...v1.66.0)

---
updated-dependencies:
- dependency-name: google.golang.org/grpc
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
This commit is contained in:
dependabot[bot] 2024-09-02 20:06:42 +00:00 committed by mergify[bot]
parent 89da94cfd0
commit 56cf915dff
59 changed files with 2807 additions and 1294 deletions

4
go.mod
View File

@ -28,7 +28,7 @@ require (
golang.org/x/crypto v0.26.0 golang.org/x/crypto v0.26.0
golang.org/x/net v0.28.0 golang.org/x/net v0.28.0
golang.org/x/sys v0.24.0 golang.org/x/sys v0.24.0
google.golang.org/grpc v1.65.0 google.golang.org/grpc v1.66.0
google.golang.org/protobuf v1.34.2 google.golang.org/protobuf v1.34.2
// //
// when updating k8s.io/kubernetes, make sure to update the replace section too // when updating k8s.io/kubernetes, make sure to update the replace section too
@ -166,7 +166,7 @@ require (
golang.org/x/time v0.5.0 // indirect golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.24.0 // indirect golang.org/x/tools v0.24.0 // indirect
gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240528184218-531527333157 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240604185151-ef581f913117 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 // indirect
gopkg.in/evanphx/json-patch.v4 v4.12.0 // indirect gopkg.in/evanphx/json-patch.v4 v4.12.0 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/inf.v0 v0.9.1 // indirect

6
go.sum
View File

@ -3251,8 +3251,9 @@ google.golang.org/genproto/googleapis/api v0.0.0-20240311132316-a219d84964c2/go.
google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237/go.mod h1:Z5Iiy3jtmioajWHDGFk7CeugTyHtPvMHA4UTmUkyalE= google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237/go.mod h1:Z5Iiy3jtmioajWHDGFk7CeugTyHtPvMHA4UTmUkyalE=
google.golang.org/genproto/googleapis/api v0.0.0-20240513163218-0867130af1f8/go.mod h1:vPrPUTsDCYxXWjP7clS81mZ6/803D8K4iM9Ma27VKas= google.golang.org/genproto/googleapis/api v0.0.0-20240513163218-0867130af1f8/go.mod h1:vPrPUTsDCYxXWjP7clS81mZ6/803D8K4iM9Ma27VKas=
google.golang.org/genproto/googleapis/api v0.0.0-20240520151616-dc85e6b867a5/go.mod h1:RGnPtTG7r4i8sPlNyDeikXF99hMM+hN6QMm4ooG9g2g= google.golang.org/genproto/googleapis/api v0.0.0-20240520151616-dc85e6b867a5/go.mod h1:RGnPtTG7r4i8sPlNyDeikXF99hMM+hN6QMm4ooG9g2g=
google.golang.org/genproto/googleapis/api v0.0.0-20240528184218-531527333157 h1:7whR9kGa5LUwFtpLm2ArCEejtnxlGeLbAyjFY8sGNFw=
google.golang.org/genproto/googleapis/api v0.0.0-20240528184218-531527333157/go.mod h1:99sLkeliLXfdj2J75X3Ho+rrVCaJze0uwN7zDDkjPVU= google.golang.org/genproto/googleapis/api v0.0.0-20240528184218-531527333157/go.mod h1:99sLkeliLXfdj2J75X3Ho+rrVCaJze0uwN7zDDkjPVU=
google.golang.org/genproto/googleapis/api v0.0.0-20240604185151-ef581f913117 h1:+rdxYoE3E5htTEWIe15GlN6IfvbURM//Jt0mmkmm6ZU=
google.golang.org/genproto/googleapis/api v0.0.0-20240604185151-ef581f913117/go.mod h1:OimBR/bc1wPO9iV4NC2bpyjy3VnAwZh5EBPQdtaE5oo=
google.golang.org/genproto/googleapis/bytestream v0.0.0-20230530153820-e85fd2cbaebc/go.mod h1:ylj+BE99M198VPbBh6A8d9n3w8fChvyLK3wwBOjXBFA= google.golang.org/genproto/googleapis/bytestream v0.0.0-20230530153820-e85fd2cbaebc/go.mod h1:ylj+BE99M198VPbBh6A8d9n3w8fChvyLK3wwBOjXBFA=
google.golang.org/genproto/googleapis/bytestream v0.0.0-20230807174057-1744710a1577/go.mod h1:NjCQG/D8JandXxM57PZbAJL1DCNL6EypA0vPPwfsc7c= google.golang.org/genproto/googleapis/bytestream v0.0.0-20230807174057-1744710a1577/go.mod h1:NjCQG/D8JandXxM57PZbAJL1DCNL6EypA0vPPwfsc7c=
google.golang.org/genproto/googleapis/bytestream v0.0.0-20231030173426-d783a09b4405/go.mod h1:GRUCuLdzVqZte8+Dl/D4N25yLzcGqqWaYkeVOwulFqw= google.golang.org/genproto/googleapis/bytestream v0.0.0-20231030173426-d783a09b4405/go.mod h1:GRUCuLdzVqZte8+Dl/D4N25yLzcGqqWaYkeVOwulFqw=
@ -3361,8 +3362,9 @@ google.golang.org/grpc v1.62.1/go.mod h1:IWTG0VlJLCh1SkC58F7np9ka9mx/WNkjl4PGJai
google.golang.org/grpc v1.63.0/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA= google.golang.org/grpc v1.63.0/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA=
google.golang.org/grpc v1.63.2/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA= google.golang.org/grpc v1.63.2/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA=
google.golang.org/grpc v1.64.0/go.mod h1:oxjF8E3FBnjp+/gVFYdWacaLDx9na1aqy9oovLpxQYg= google.golang.org/grpc v1.64.0/go.mod h1:oxjF8E3FBnjp+/gVFYdWacaLDx9na1aqy9oovLpxQYg=
google.golang.org/grpc v1.65.0 h1:bs/cUb4lp1G5iImFFd3u5ixQzweKizoZJAwBNLR42lc=
google.golang.org/grpc v1.65.0/go.mod h1:WgYC2ypjlB0EiQi6wdKixMqukr6lBc0Vo+oOgjrM5ZQ= google.golang.org/grpc v1.65.0/go.mod h1:WgYC2ypjlB0EiQi6wdKixMqukr6lBc0Vo+oOgjrM5ZQ=
google.golang.org/grpc v1.66.0 h1:DibZuoBznOxbDQxRINckZcUvnCEvrW9pcWIE2yF9r1c=
google.golang.org/grpc v1.66.0/go.mod h1:s3/l6xSSCURdVfAnL+TqCNMyTDAGN6+lZeVxnZR128Y=
google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw= google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=

View File

@ -9,21 +9,28 @@ for general contribution guidelines.
## Maintainers (in alphabetical order) ## Maintainers (in alphabetical order)
- [aranjans](https://github.com/aranjans), Google LLC
- [arjan-bal](https://github.com/arjan-bal), Google LLC
- [arvindbr8](https://github.com/arvindbr8), Google LLC
- [atollena](https://github.com/atollena), Datadog, Inc. - [atollena](https://github.com/atollena), Datadog, Inc.
- [cesarghali](https://github.com/cesarghali), Google LLC
- [dfawley](https://github.com/dfawley), Google LLC - [dfawley](https://github.com/dfawley), Google LLC
- [easwars](https://github.com/easwars), Google LLC - [easwars](https://github.com/easwars), Google LLC
- [menghanl](https://github.com/menghanl), Google LLC - [erm-g](https://github.com/erm-g), Google LLC
- [srini100](https://github.com/srini100), Google LLC - [gtcooke94](https://github.com/gtcooke94), Google LLC
- [purnesh42h](https://github.com/purnesh42h), Google LLC
- [zasweq](https://github.com/zasweq), Google LLC
## Emeritus Maintainers (in alphabetical order) ## Emeritus Maintainers (in alphabetical order)
- [adelez](https://github.com/adelez), Google LLC - [adelez](https://github.com/adelez)
- [canguler](https://github.com/canguler), Google LLC - [canguler](https://github.com/canguler)
- [iamqizhao](https://github.com/iamqizhao), Google LLC - [cesarghali](https://github.com/cesarghali)
- [jadekler](https://github.com/jadekler), Google LLC - [iamqizhao](https://github.com/iamqizhao)
- [jtattermusch](https://github.com/jtattermusch), Google LLC - [jeanbza](https://github.com/jeanbza)
- [lyuxuan](https://github.com/lyuxuan), Google LLC - [jtattermusch](https://github.com/jtattermusch)
- [makmukhi](https://github.com/makmukhi), Google LLC - [lyuxuan](https://github.com/lyuxuan)
- [matt-kwong](https://github.com/matt-kwong), Google LLC - [makmukhi](https://github.com/makmukhi)
- [nicolasnoble](https://github.com/nicolasnoble), Google LLC - [matt-kwong](https://github.com/matt-kwong)
- [yongni](https://github.com/yongni), Google LLC - [menghanl](https://github.com/menghanl)
- [nicolasnoble](https://github.com/nicolasnoble)
- [srini100](https://github.com/srini100)
- [yongni](https://github.com/yongni)

View File

@ -1,3 +1,3 @@
# Security Policy # Security Policy
For information on gRPC Security Policy and reporting potentional security issues, please see [gRPC CVE Process](https://github.com/grpc/proposal/blob/master/P4-grpc-cve-process.md). For information on gRPC Security Policy and reporting potential security issues, please see [gRPC CVE Process](https://github.com/grpc/proposal/blob/master/P4-grpc-cve-process.md).

View File

@ -39,7 +39,7 @@ type Config struct {
MaxDelay time.Duration MaxDelay time.Duration
} }
// DefaultConfig is a backoff configuration with the default values specfied // DefaultConfig is a backoff configuration with the default values specified
// at https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md. // at https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md.
// //
// This should be useful for callers who want to configure backoff with // This should be useful for callers who want to configure backoff with

View File

@ -30,6 +30,7 @@ import (
"google.golang.org/grpc/channelz" "google.golang.org/grpc/channelz"
"google.golang.org/grpc/connectivity" "google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
estats "google.golang.org/grpc/experimental/stats"
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal" "google.golang.org/grpc/internal"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
@ -72,8 +73,21 @@ func unregisterForTesting(name string) {
delete(m, name) delete(m, name)
} }
// connectedAddress returns the connected address for a SubConnState. The
// address is only valid if the state is READY.
func connectedAddress(scs SubConnState) resolver.Address {
return scs.connectedAddress
}
// setConnectedAddress sets the connected address for a SubConnState.
func setConnectedAddress(scs *SubConnState, addr resolver.Address) {
scs.connectedAddress = addr
}
func init() { func init() {
internal.BalancerUnregister = unregisterForTesting internal.BalancerUnregister = unregisterForTesting
internal.ConnectedAddress = connectedAddress
internal.SetConnectedAddress = setConnectedAddress
} }
// Get returns the resolver builder registered with the given name. // Get returns the resolver builder registered with the given name.
@ -243,6 +257,10 @@ type BuildOptions struct {
// same resolver.Target as passed to the resolver. See the documentation for // same resolver.Target as passed to the resolver. See the documentation for
// the resolver.Target type for details about what it contains. // the resolver.Target type for details about what it contains.
Target resolver.Target Target resolver.Target
// MetricsRecorder is the metrics recorder that balancers can use to record
// metrics. Balancer implementations which do not register metrics on
// metrics registry and record on them can ignore this field.
MetricsRecorder estats.MetricsRecorder
} }
// Builder creates a balancer. // Builder creates a balancer.
@ -410,6 +428,9 @@ type SubConnState struct {
// ConnectionError is set if the ConnectivityState is TransientFailure, // ConnectionError is set if the ConnectivityState is TransientFailure,
// describing the reason the SubConn failed. Otherwise, it is nil. // describing the reason the SubConn failed. Otherwise, it is nil.
ConnectionError error ConnectionError error
// connectedAddr contains the connected address when ConnectivityState is
// Ready. Otherwise, it is indeterminate.
connectedAddress resolver.Address
} }
// ClientConnState describes the state of a ClientConn relevant to the // ClientConnState describes the state of a ClientConn relevant to the

View File

@ -155,7 +155,7 @@ func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState
// Endpoints not set, process addresses until we migrate resolver // Endpoints not set, process addresses until we migrate resolver
// emissions fully to Endpoints. The top channel does wrap emitted // emissions fully to Endpoints. The top channel does wrap emitted
// addresses with endpoints, however some balancers such as weighted // addresses with endpoints, however some balancers such as weighted
// target do not forwarrd the corresponding correct endpoints down/split // target do not forward the corresponding correct endpoints down/split
// endpoints properly. Once all balancers correctly forward endpoints // endpoints properly. Once all balancers correctly forward endpoints
// down, can delete this else conditional. // down, can delete this else conditional.
addrs = state.ResolverState.Addresses addrs = state.ResolverState.Addresses

View File

@ -25,12 +25,15 @@ import (
"google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity" "google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/balancer/gracefulswitch" "google.golang.org/grpc/internal/balancer/gracefulswitch"
"google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
) )
var setConnectedAddress = internal.SetConnectedAddress.(func(*balancer.SubConnState, resolver.Address))
// ccBalancerWrapper sits between the ClientConn and the Balancer. // ccBalancerWrapper sits between the ClientConn and the Balancer.
// //
// ccBalancerWrapper implements methods corresponding to the ones on the // ccBalancerWrapper implements methods corresponding to the ones on the
@ -79,6 +82,7 @@ func newCCBalancerWrapper(cc *ClientConn) *ccBalancerWrapper {
CustomUserAgent: cc.dopts.copts.UserAgent, CustomUserAgent: cc.dopts.copts.UserAgent,
ChannelzParent: cc.channelz, ChannelzParent: cc.channelz,
Target: cc.parsedTarget, Target: cc.parsedTarget,
MetricsRecorder: cc.metricsRecorderList,
}, },
serializer: grpcsync.NewCallbackSerializer(ctx), serializer: grpcsync.NewCallbackSerializer(ctx),
serializerCancel: cancel, serializerCancel: cancel,
@ -92,7 +96,7 @@ func newCCBalancerWrapper(cc *ClientConn) *ccBalancerWrapper {
// it is safe to call into the balancer here. // it is safe to call into the balancer here.
func (ccb *ccBalancerWrapper) updateClientConnState(ccs *balancer.ClientConnState) error { func (ccb *ccBalancerWrapper) updateClientConnState(ccs *balancer.ClientConnState) error {
errCh := make(chan error) errCh := make(chan error)
ok := ccb.serializer.Schedule(func(ctx context.Context) { uccs := func(ctx context.Context) {
defer close(errCh) defer close(errCh)
if ctx.Err() != nil || ccb.balancer == nil { if ctx.Err() != nil || ccb.balancer == nil {
return return
@ -107,17 +111,23 @@ func (ccb *ccBalancerWrapper) updateClientConnState(ccs *balancer.ClientConnStat
logger.Infof("error from balancer.UpdateClientConnState: %v", err) logger.Infof("error from balancer.UpdateClientConnState: %v", err)
} }
errCh <- err errCh <- err
})
if !ok {
return nil
} }
onFailure := func() { close(errCh) }
// UpdateClientConnState can race with Close, and when the latter wins, the
// serializer is closed, and the attempt to schedule the callback will fail.
// It is acceptable to ignore this failure. But since we want to handle the
// state update in a blocking fashion (when we successfully schedule the
// callback), we have to use the ScheduleOr method and not the MaybeSchedule
// method on the serializer.
ccb.serializer.ScheduleOr(uccs, onFailure)
return <-errCh return <-errCh
} }
// resolverError is invoked by grpc to push a resolver error to the underlying // resolverError is invoked by grpc to push a resolver error to the underlying
// balancer. The call to the balancer is executed from the serializer. // balancer. The call to the balancer is executed from the serializer.
func (ccb *ccBalancerWrapper) resolverError(err error) { func (ccb *ccBalancerWrapper) resolverError(err error) {
ccb.serializer.Schedule(func(ctx context.Context) { ccb.serializer.TrySchedule(func(ctx context.Context) {
if ctx.Err() != nil || ccb.balancer == nil { if ctx.Err() != nil || ccb.balancer == nil {
return return
} }
@ -133,7 +143,7 @@ func (ccb *ccBalancerWrapper) close() {
ccb.closed = true ccb.closed = true
ccb.mu.Unlock() ccb.mu.Unlock()
channelz.Info(logger, ccb.cc.channelz, "ccBalancerWrapper: closing") channelz.Info(logger, ccb.cc.channelz, "ccBalancerWrapper: closing")
ccb.serializer.Schedule(func(context.Context) { ccb.serializer.TrySchedule(func(context.Context) {
if ccb.balancer == nil { if ccb.balancer == nil {
return return
} }
@ -145,7 +155,7 @@ func (ccb *ccBalancerWrapper) close() {
// exitIdle invokes the balancer's exitIdle method in the serializer. // exitIdle invokes the balancer's exitIdle method in the serializer.
func (ccb *ccBalancerWrapper) exitIdle() { func (ccb *ccBalancerWrapper) exitIdle() {
ccb.serializer.Schedule(func(ctx context.Context) { ccb.serializer.TrySchedule(func(ctx context.Context) {
if ctx.Err() != nil || ccb.balancer == nil { if ctx.Err() != nil || ccb.balancer == nil {
return return
} }
@ -252,15 +262,29 @@ type acBalancerWrapper struct {
// updateState is invoked by grpc to push a subConn state update to the // updateState is invoked by grpc to push a subConn state update to the
// underlying balancer. // underlying balancer.
func (acbw *acBalancerWrapper) updateState(s connectivity.State, err error) { func (acbw *acBalancerWrapper) updateState(s connectivity.State, curAddr resolver.Address, err error) {
acbw.ccb.serializer.Schedule(func(ctx context.Context) { acbw.ccb.serializer.TrySchedule(func(ctx context.Context) {
if ctx.Err() != nil || acbw.ccb.balancer == nil { if ctx.Err() != nil || acbw.ccb.balancer == nil {
return return
} }
// Even though it is optional for balancers, gracefulswitch ensures // Even though it is optional for balancers, gracefulswitch ensures
// opts.StateListener is set, so this cannot ever be nil. // opts.StateListener is set, so this cannot ever be nil.
// TODO: delete this comment when UpdateSubConnState is removed. // TODO: delete this comment when UpdateSubConnState is removed.
acbw.stateListener(balancer.SubConnState{ConnectivityState: s, ConnectionError: err}) scs := balancer.SubConnState{ConnectivityState: s, ConnectionError: err}
if s == connectivity.Ready {
setConnectedAddress(&scs, curAddr)
}
acbw.stateListener(scs)
acbw.ac.mu.Lock()
defer acbw.ac.mu.Unlock()
if s == connectivity.Ready {
// When changing states to READY, reset stateReadyChan. Wait until
// after we notify the LB policy's listener(s) in order to prevent
// ac.getTransport() from unblocking before the LB policy starts
// tracking the subchannel as READY.
close(acbw.ac.stateReadyChan)
acbw.ac.stateReadyChan = make(chan struct{})
}
}) })
} }

View File

@ -19,7 +19,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.34.1 // protoc-gen-go v1.34.1
// protoc v4.25.2 // protoc v5.27.1
// source: grpc/binlog/v1/binarylog.proto // source: grpc/binlog/v1/binarylog.proto
package grpc_binarylog_v1 package grpc_binarylog_v1

View File

@ -24,6 +24,7 @@ import (
"fmt" "fmt"
"math" "math"
"net/url" "net/url"
"slices"
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -39,6 +40,7 @@ import (
"google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/idle" "google.golang.org/grpc/internal/idle"
iresolver "google.golang.org/grpc/internal/resolver" iresolver "google.golang.org/grpc/internal/resolver"
"google.golang.org/grpc/internal/stats"
"google.golang.org/grpc/internal/transport" "google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
@ -194,8 +196,11 @@ func NewClient(target string, opts ...DialOption) (conn *ClientConn, err error)
cc.csMgr = newConnectivityStateManager(cc.ctx, cc.channelz) cc.csMgr = newConnectivityStateManager(cc.ctx, cc.channelz)
cc.pickerWrapper = newPickerWrapper(cc.dopts.copts.StatsHandlers) cc.pickerWrapper = newPickerWrapper(cc.dopts.copts.StatsHandlers)
cc.metricsRecorderList = stats.NewMetricsRecorderList(cc.dopts.copts.StatsHandlers)
cc.initIdleStateLocked() // Safe to call without the lock, since nothing else has a reference to cc. cc.initIdleStateLocked() // Safe to call without the lock, since nothing else has a reference to cc.
cc.idlenessMgr = idle.NewManager((*idler)(cc), cc.dopts.idleTimeout) cc.idlenessMgr = idle.NewManager((*idler)(cc), cc.dopts.idleTimeout)
return cc, nil return cc, nil
} }
@ -590,13 +595,14 @@ type ClientConn struct {
cancel context.CancelFunc // Cancelled on close. cancel context.CancelFunc // Cancelled on close.
// The following are initialized at dial time, and are read-only after that. // The following are initialized at dial time, and are read-only after that.
target string // User's dial target. target string // User's dial target.
parsedTarget resolver.Target // See initParsedTargetAndResolverBuilder(). parsedTarget resolver.Target // See initParsedTargetAndResolverBuilder().
authority string // See initAuthority(). authority string // See initAuthority().
dopts dialOptions // Default and user specified dial options. dopts dialOptions // Default and user specified dial options.
channelz *channelz.Channel // Channelz object. channelz *channelz.Channel // Channelz object.
resolverBuilder resolver.Builder // See initParsedTargetAndResolverBuilder(). resolverBuilder resolver.Builder // See initParsedTargetAndResolverBuilder().
idlenessMgr *idle.Manager idlenessMgr *idle.Manager
metricsRecorderList *stats.MetricsRecorderList
// The following provide their own synchronization, and therefore don't // The following provide their own synchronization, and therefore don't
// require cc.mu to be held to access them. // require cc.mu to be held to access them.
@ -626,11 +632,6 @@ type ClientConn struct {
// WaitForStateChange waits until the connectivity.State of ClientConn changes from sourceState or // WaitForStateChange waits until the connectivity.State of ClientConn changes from sourceState or
// ctx expires. A true value is returned in former case and false in latter. // ctx expires. A true value is returned in former case and false in latter.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func (cc *ClientConn) WaitForStateChange(ctx context.Context, sourceState connectivity.State) bool { func (cc *ClientConn) WaitForStateChange(ctx context.Context, sourceState connectivity.State) bool {
ch := cc.csMgr.getNotifyChan() ch := cc.csMgr.getNotifyChan()
if cc.csMgr.getState() != sourceState { if cc.csMgr.getState() != sourceState {
@ -645,11 +646,6 @@ func (cc *ClientConn) WaitForStateChange(ctx context.Context, sourceState connec
} }
// GetState returns the connectivity.State of ClientConn. // GetState returns the connectivity.State of ClientConn.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a later
// release.
func (cc *ClientConn) GetState() connectivity.State { func (cc *ClientConn) GetState() connectivity.State {
return cc.csMgr.getState() return cc.csMgr.getState()
} }
@ -812,17 +808,11 @@ func (cc *ClientConn) applyFailingLBLocked(sc *serviceconfig.ParseResult) {
cc.csMgr.updateState(connectivity.TransientFailure) cc.csMgr.updateState(connectivity.TransientFailure)
} }
// Makes a copy of the input addresses slice and clears out the balancer // Makes a copy of the input addresses slice. Addresses are passed during
// attributes field. Addresses are passed during subconn creation and address // subconn creation and address update operations.
// update operations. In both cases, we will clear the balancer attributes by func copyAddresses(in []resolver.Address) []resolver.Address {
// calling this function, and therefore we will be able to use the Equal method
// provided by the resolver.Address type for comparison.
func copyAddressesWithoutBalancerAttributes(in []resolver.Address) []resolver.Address {
out := make([]resolver.Address, len(in)) out := make([]resolver.Address, len(in))
for i := range in { copy(out, in)
out[i] = in[i]
out[i].BalancerAttributes = nil
}
return out return out
} }
@ -835,14 +825,14 @@ func (cc *ClientConn) newAddrConnLocked(addrs []resolver.Address, opts balancer.
} }
ac := &addrConn{ ac := &addrConn{
state: connectivity.Idle, state: connectivity.Idle,
cc: cc, cc: cc,
addrs: copyAddressesWithoutBalancerAttributes(addrs), addrs: copyAddresses(addrs),
scopts: opts, scopts: opts,
dopts: cc.dopts, dopts: cc.dopts,
channelz: channelz.RegisterSubChannel(cc.channelz, ""), channelz: channelz.RegisterSubChannel(cc.channelz, ""),
resetBackoff: make(chan struct{}), resetBackoff: make(chan struct{}),
stateChan: make(chan struct{}), stateReadyChan: make(chan struct{}),
} }
ac.ctx, ac.cancel = context.WithCancel(cc.ctx) ac.ctx, ac.cancel = context.WithCancel(cc.ctx)
// Start with our address set to the first address; this may be updated if // Start with our address set to the first address; this may be updated if
@ -918,28 +908,29 @@ func (ac *addrConn) connect() error {
ac.mu.Unlock() ac.mu.Unlock()
return nil return nil
} }
ac.mu.Unlock()
ac.resetTransport() ac.resetTransportAndUnlock()
return nil return nil
} }
func equalAddresses(a, b []resolver.Address) bool { // equalAddressIgnoringBalAttributes returns true is a and b are considered equal.
if len(a) != len(b) { // This is different from the Equal method on the resolver.Address type which
return false // considers all fields to determine equality. Here, we only consider fields
} // that are meaningful to the subConn.
for i, v := range a { func equalAddressIgnoringBalAttributes(a, b *resolver.Address) bool {
if !v.Equal(b[i]) { return a.Addr == b.Addr && a.ServerName == b.ServerName &&
return false a.Attributes.Equal(b.Attributes) &&
} a.Metadata == b.Metadata
} }
return true
func equalAddressesIgnoringBalAttributes(a, b []resolver.Address) bool {
return slices.EqualFunc(a, b, func(a, b resolver.Address) bool { return equalAddressIgnoringBalAttributes(&a, &b) })
} }
// updateAddrs updates ac.addrs with the new addresses list and handles active // updateAddrs updates ac.addrs with the new addresses list and handles active
// connections or connection attempts. // connections or connection attempts.
func (ac *addrConn) updateAddrs(addrs []resolver.Address) { func (ac *addrConn) updateAddrs(addrs []resolver.Address) {
addrs = copyAddressesWithoutBalancerAttributes(addrs) addrs = copyAddresses(addrs)
limit := len(addrs) limit := len(addrs)
if limit > 5 { if limit > 5 {
limit = 5 limit = 5
@ -947,7 +938,7 @@ func (ac *addrConn) updateAddrs(addrs []resolver.Address) {
channelz.Infof(logger, ac.channelz, "addrConn: updateAddrs addrs (%d of %d): %v", limit, len(addrs), addrs[:limit]) channelz.Infof(logger, ac.channelz, "addrConn: updateAddrs addrs (%d of %d): %v", limit, len(addrs), addrs[:limit])
ac.mu.Lock() ac.mu.Lock()
if equalAddresses(ac.addrs, addrs) { if equalAddressesIgnoringBalAttributes(ac.addrs, addrs) {
ac.mu.Unlock() ac.mu.Unlock()
return return
} }
@ -966,7 +957,7 @@ func (ac *addrConn) updateAddrs(addrs []resolver.Address) {
// Try to find the connected address. // Try to find the connected address.
for _, a := range addrs { for _, a := range addrs {
a.ServerName = ac.cc.getServerName(a) a.ServerName = ac.cc.getServerName(a)
if a.Equal(ac.curAddr) { if equalAddressIgnoringBalAttributes(&a, &ac.curAddr) {
// We are connected to a valid address, so do nothing but // We are connected to a valid address, so do nothing but
// update the addresses. // update the addresses.
ac.mu.Unlock() ac.mu.Unlock()
@ -992,11 +983,9 @@ func (ac *addrConn) updateAddrs(addrs []resolver.Address) {
ac.updateConnectivityState(connectivity.Idle, nil) ac.updateConnectivityState(connectivity.Idle, nil)
} }
ac.mu.Unlock()
// Since we were connecting/connected, we should start a new connection // Since we were connecting/connected, we should start a new connection
// attempt. // attempt.
go ac.resetTransport() go ac.resetTransportAndUnlock()
} }
// getServerName determines the serverName to be used in the connection // getServerName determines the serverName to be used in the connection
@ -1190,8 +1179,8 @@ type addrConn struct {
addrs []resolver.Address // All addresses that the resolver resolved to. addrs []resolver.Address // All addresses that the resolver resolved to.
// Use updateConnectivityState for updating addrConn's connectivity state. // Use updateConnectivityState for updating addrConn's connectivity state.
state connectivity.State state connectivity.State
stateChan chan struct{} // closed and recreated on every state change. stateReadyChan chan struct{} // closed and recreated on every READY state change.
backoffIdx int // Needs to be stateful for resetConnectBackoff. backoffIdx int // Needs to be stateful for resetConnectBackoff.
resetBackoff chan struct{} resetBackoff chan struct{}
@ -1204,9 +1193,6 @@ func (ac *addrConn) updateConnectivityState(s connectivity.State, lastErr error)
if ac.state == s { if ac.state == s {
return return
} }
// When changing states, reset the state change channel.
close(ac.stateChan)
ac.stateChan = make(chan struct{})
ac.state = s ac.state = s
ac.channelz.ChannelMetrics.State.Store(&s) ac.channelz.ChannelMetrics.State.Store(&s)
if lastErr == nil { if lastErr == nil {
@ -1214,7 +1200,7 @@ func (ac *addrConn) updateConnectivityState(s connectivity.State, lastErr error)
} else { } else {
channelz.Infof(logger, ac.channelz, "Subchannel Connectivity change to %v, last error: %s", s, lastErr) channelz.Infof(logger, ac.channelz, "Subchannel Connectivity change to %v, last error: %s", s, lastErr)
} }
ac.acbw.updateState(s, lastErr) ac.acbw.updateState(s, ac.curAddr, lastErr)
} }
// adjustParams updates parameters used to create transports upon // adjustParams updates parameters used to create transports upon
@ -1231,8 +1217,10 @@ func (ac *addrConn) adjustParams(r transport.GoAwayReason) {
} }
} }
func (ac *addrConn) resetTransport() { // resetTransportAndUnlock unconditionally connects the addrConn.
ac.mu.Lock() //
// ac.mu must be held by the caller, and this function will guarantee it is released.
func (ac *addrConn) resetTransportAndUnlock() {
acCtx := ac.ctx acCtx := ac.ctx
if acCtx.Err() != nil { if acCtx.Err() != nil {
ac.mu.Unlock() ac.mu.Unlock()
@ -1522,7 +1510,7 @@ func (ac *addrConn) getReadyTransport() transport.ClientTransport {
func (ac *addrConn) getTransport(ctx context.Context) (transport.ClientTransport, error) { func (ac *addrConn) getTransport(ctx context.Context) (transport.ClientTransport, error) {
for ctx.Err() == nil { for ctx.Err() == nil {
ac.mu.Lock() ac.mu.Lock()
t, state, sc := ac.transport, ac.state, ac.stateChan t, state, sc := ac.transport, ac.state, ac.stateReadyChan
ac.mu.Unlock() ac.mu.Unlock()
if state == connectivity.Ready { if state == connectivity.Ready {
return t, nil return t, nil
@ -1585,7 +1573,7 @@ func (ac *addrConn) tearDown(err error) {
} else { } else {
// Hard close the transport when the channel is entering idle or is // Hard close the transport when the channel is entering idle or is
// being shutdown. In the case where the channel is being shutdown, // being shutdown. In the case where the channel is being shutdown,
// closing of transports is also taken care of by cancelation of cc.ctx. // closing of transports is also taken care of by cancellation of cc.ctx.
// But in the case where the channel is entering idle, we need to // But in the case where the channel is entering idle, we need to
// explicitly close the transports here. Instead of distinguishing // explicitly close the transports here. Instead of distinguishing
// between these two cases, it is simpler to close the transport // between these two cases, it is simpler to close the transport

View File

@ -21,18 +21,73 @@ package grpc
import ( import (
"google.golang.org/grpc/encoding" "google.golang.org/grpc/encoding"
_ "google.golang.org/grpc/encoding/proto" // to register the Codec for "proto" _ "google.golang.org/grpc/encoding/proto" // to register the Codec for "proto"
"google.golang.org/grpc/mem"
) )
// baseCodec contains the functionality of both Codec and encoding.Codec, but // baseCodec captures the new encoding.CodecV2 interface without the Name
// omits the name/string, which vary between the two and are not needed for // function, allowing it to be implemented by older Codec and encoding.Codec
// anything besides the registry in the encoding package. // implementations. The omitted Name function is only needed for the register in
// the encoding package and is not part of the core functionality.
type baseCodec interface { type baseCodec interface {
Marshal(v any) ([]byte, error) Marshal(v any) (mem.BufferSlice, error)
Unmarshal(data []byte, v any) error Unmarshal(data mem.BufferSlice, v any) error
} }
var _ baseCodec = Codec(nil) // getCodec returns an encoding.CodecV2 for the codec of the given name (if
var _ baseCodec = encoding.Codec(nil) // registered). Initially checks the V2 registry with encoding.GetCodecV2 and
// returns the V2 codec if it is registered. Otherwise, it checks the V1 registry
// with encoding.GetCodec and if it is registered wraps it with newCodecV1Bridge
// to turn it into an encoding.CodecV2. Returns nil otherwise.
func getCodec(name string) encoding.CodecV2 {
if codecV1 := encoding.GetCodec(name); codecV1 != nil {
return newCodecV1Bridge(codecV1)
}
return encoding.GetCodecV2(name)
}
func newCodecV0Bridge(c Codec) baseCodec {
return codecV0Bridge{codec: c}
}
func newCodecV1Bridge(c encoding.Codec) encoding.CodecV2 {
return codecV1Bridge{
codecV0Bridge: codecV0Bridge{codec: c},
name: c.Name(),
}
}
var _ baseCodec = codecV0Bridge{}
type codecV0Bridge struct {
codec interface {
Marshal(v any) ([]byte, error)
Unmarshal(data []byte, v any) error
}
}
func (c codecV0Bridge) Marshal(v any) (mem.BufferSlice, error) {
data, err := c.codec.Marshal(v)
if err != nil {
return nil, err
}
return mem.BufferSlice{mem.NewBuffer(&data, nil)}, nil
}
func (c codecV0Bridge) Unmarshal(data mem.BufferSlice, v any) (err error) {
return c.codec.Unmarshal(data.Materialize(), v)
}
var _ encoding.CodecV2 = codecV1Bridge{}
type codecV1Bridge struct {
codecV0Bridge
name string
}
func (c codecV1Bridge) Name() string {
return c.name
}
// Codec defines the interface gRPC uses to encode and decode messages. // Codec defines the interface gRPC uses to encode and decode messages.
// Note that implementations of this interface must be thread safe; // Note that implementations of this interface must be thread safe;

View File

@ -33,6 +33,7 @@ import (
"google.golang.org/grpc/internal/binarylog" "google.golang.org/grpc/internal/binarylog"
"google.golang.org/grpc/internal/transport" "google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
) )
@ -60,7 +61,7 @@ func init() {
internal.WithBinaryLogger = withBinaryLogger internal.WithBinaryLogger = withBinaryLogger
internal.JoinDialOptions = newJoinDialOption internal.JoinDialOptions = newJoinDialOption
internal.DisableGlobalDialOptions = newDisableGlobalDialOptions internal.DisableGlobalDialOptions = newDisableGlobalDialOptions
internal.WithRecvBufferPool = withRecvBufferPool internal.WithBufferPool = withBufferPool
} }
// dialOptions configure a Dial call. dialOptions are set by the DialOption // dialOptions configure a Dial call. dialOptions are set by the DialOption
@ -92,7 +93,6 @@ type dialOptions struct {
defaultServiceConfigRawJSON *string defaultServiceConfigRawJSON *string
resolvers []resolver.Builder resolvers []resolver.Builder
idleTimeout time.Duration idleTimeout time.Duration
recvBufferPool SharedBufferPool
defaultScheme string defaultScheme string
maxCallAttempts int maxCallAttempts int
} }
@ -677,11 +677,11 @@ func defaultDialOptions() dialOptions {
WriteBufferSize: defaultWriteBufSize, WriteBufferSize: defaultWriteBufSize,
UseProxy: true, UseProxy: true,
UserAgent: grpcUA, UserAgent: grpcUA,
BufferPool: mem.DefaultBufferPool(),
}, },
bs: internalbackoff.DefaultExponential, bs: internalbackoff.DefaultExponential,
healthCheckFunc: internal.HealthCheckFunc, healthCheckFunc: internal.HealthCheckFunc,
idleTimeout: 30 * time.Minute, idleTimeout: 30 * time.Minute,
recvBufferPool: nopBufferPool{},
defaultScheme: "dns", defaultScheme: "dns",
maxCallAttempts: defaultMaxCallAttempts, maxCallAttempts: defaultMaxCallAttempts,
} }
@ -758,25 +758,8 @@ func WithMaxCallAttempts(n int) DialOption {
}) })
} }
// WithRecvBufferPool returns a DialOption that configures the ClientConn func withBufferPool(bufferPool mem.BufferPool) DialOption {
// to use the provided shared buffer pool for parsing incoming messages. Depending
// on the application's workload, this could result in reduced memory allocation.
//
// If you are unsure about how to implement a memory pool but want to utilize one,
// begin with grpc.NewSharedBufferPool.
//
// Note: The shared buffer pool feature will not be active if any of the following
// options are used: WithStatsHandler, EnableTracing, or binary logging. In such
// cases, the shared buffer pool will be ignored.
//
// Deprecated: use experimental.WithRecvBufferPool instead. Will be deleted in
// v1.60.0 or later.
func WithRecvBufferPool(bufferPool SharedBufferPool) DialOption {
return withRecvBufferPool(bufferPool)
}
func withRecvBufferPool(bufferPool SharedBufferPool) DialOption {
return newFuncDialOption(func(o *dialOptions) { return newFuncDialOption(func(o *dialOptions) {
o.recvBufferPool = bufferPool o.copts.BufferPool = bufferPool
}) })
} }

View File

@ -16,7 +16,7 @@
* *
*/ */
//go:generate ./regenerate.sh //go:generate ./scripts/regenerate.sh
/* /*
Package grpc implements an RPC system called gRPC. Package grpc implements an RPC system called gRPC.

View File

@ -94,7 +94,7 @@ type Codec interface {
Name() string Name() string
} }
var registeredCodecs = make(map[string]Codec) var registeredCodecs = make(map[string]any)
// RegisterCodec registers the provided Codec for use with all gRPC clients and // RegisterCodec registers the provided Codec for use with all gRPC clients and
// servers. // servers.
@ -126,5 +126,6 @@ func RegisterCodec(codec Codec) {
// //
// The content-subtype is expected to be lowercase. // The content-subtype is expected to be lowercase.
func GetCodec(contentSubtype string) Codec { func GetCodec(contentSubtype string) Codec {
return registeredCodecs[contentSubtype] c, _ := registeredCodecs[contentSubtype].(Codec)
return c
} }

81
vendor/google.golang.org/grpc/encoding/encoding_v2.go generated vendored Normal file
View File

@ -0,0 +1,81 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package encoding
import (
"strings"
"google.golang.org/grpc/mem"
)
// CodecV2 defines the interface gRPC uses to encode and decode messages. Note
// that implementations of this interface must be thread safe; a CodecV2's
// methods can be called from concurrent goroutines.
type CodecV2 interface {
// Marshal returns the wire format of v. The buffers in the returned
// [mem.BufferSlice] must have at least one reference each, which will be freed
// by gRPC when they are no longer needed.
Marshal(v any) (out mem.BufferSlice, err error)
// Unmarshal parses the wire format into v. Note that data will be freed as soon
// as this function returns. If the codec wishes to guarantee access to the data
// after this function, it must take its own reference that it frees when it is
// no longer needed.
Unmarshal(data mem.BufferSlice, v any) error
// Name returns the name of the Codec implementation. The returned string
// will be used as part of content type in transmission. The result must be
// static; the result cannot change between calls.
Name() string
}
// RegisterCodecV2 registers the provided CodecV2 for use with all gRPC clients and
// servers.
//
// The CodecV2 will be stored and looked up by result of its Name() method, which
// should match the content-subtype of the encoding handled by the CodecV2. This
// is case-insensitive, and is stored and looked up as lowercase. If the
// result of calling Name() is an empty string, RegisterCodecV2 will panic. See
// Content-Type on
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details.
//
// If both a Codec and CodecV2 are registered with the same name, the CodecV2
// will be used.
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple Codecs are
// registered with the same name, the one registered last will take effect.
func RegisterCodecV2(codec CodecV2) {
if codec == nil {
panic("cannot register a nil CodecV2")
}
if codec.Name() == "" {
panic("cannot register CodecV2 with empty string result for Name()")
}
contentSubtype := strings.ToLower(codec.Name())
registeredCodecs[contentSubtype] = codec
}
// GetCodecV2 gets a registered CodecV2 by content-subtype, or nil if no CodecV2 is
// registered for the content-subtype.
//
// The content-subtype is expected to be lowercase.
func GetCodecV2(contentSubtype string) CodecV2 {
c, _ := registeredCodecs[contentSubtype].(CodecV2)
return c
}

View File

@ -1,6 +1,6 @@
/* /*
* *
* Copyright 2018 gRPC authors. * Copyright 2024 gRPC authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -24,6 +24,7 @@ import (
"fmt" "fmt"
"google.golang.org/grpc/encoding" "google.golang.org/grpc/encoding"
"google.golang.org/grpc/mem"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/protoadapt" "google.golang.org/protobuf/protoadapt"
) )
@ -32,28 +33,51 @@ import (
const Name = "proto" const Name = "proto"
func init() { func init() {
encoding.RegisterCodec(codec{}) encoding.RegisterCodecV2(&codecV2{})
} }
// codec is a Codec implementation with protobuf. It is the default codec for gRPC. // codec is a CodecV2 implementation with protobuf. It is the default codec for
type codec struct{} // gRPC.
type codecV2 struct{}
func (codec) Marshal(v any) ([]byte, error) { func (c *codecV2) Marshal(v any) (data mem.BufferSlice, err error) {
vv := messageV2Of(v) vv := messageV2Of(v)
if vv == nil { if vv == nil {
return nil, fmt.Errorf("failed to marshal, message is %T, want proto.Message", v) return nil, fmt.Errorf("proto: failed to marshal, message is %T, want proto.Message", v)
} }
return proto.Marshal(vv) size := proto.Size(vv)
if mem.IsBelowBufferPoolingThreshold(size) {
buf, err := proto.Marshal(vv)
if err != nil {
return nil, err
}
data = append(data, mem.SliceBuffer(buf))
} else {
pool := mem.DefaultBufferPool()
buf := pool.Get(size)
if _, err := (proto.MarshalOptions{}).MarshalAppend((*buf)[:0], vv); err != nil {
pool.Put(buf)
return nil, err
}
data = append(data, mem.NewBuffer(buf, pool))
}
return data, nil
} }
func (codec) Unmarshal(data []byte, v any) error { func (c *codecV2) Unmarshal(data mem.BufferSlice, v any) (err error) {
vv := messageV2Of(v) vv := messageV2Of(v)
if vv == nil { if vv == nil {
return fmt.Errorf("failed to unmarshal, message is %T, want proto.Message", v) return fmt.Errorf("failed to unmarshal, message is %T, want proto.Message", v)
} }
return proto.Unmarshal(data, vv) buf := data.MaterializeToBuffer(mem.DefaultBufferPool())
defer buf.Free()
// TODO: Upgrade proto.Unmarshal to support mem.BufferSlice. Right now, it's not
// really possible without a major overhaul of the proto package, but the
// vtprotobuf library may be able to support this.
return proto.Unmarshal(buf.ReadOnlyData(), vv)
} }
func messageV2Of(v any) proto.Message { func messageV2Of(v any) proto.Message {
@ -67,6 +91,6 @@ func messageV2Of(v any) proto.Message {
return nil return nil
} }
func (codec) Name() string { func (c *codecV2) Name() string {
return Name return Name
} }

View File

@ -0,0 +1,270 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package stats
import (
"maps"
"testing"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
)
func init() {
internal.SnapshotMetricRegistryForTesting = snapshotMetricsRegistryForTesting
}
var logger = grpclog.Component("metrics-registry")
// DefaultMetrics are the default metrics registered through global metrics
// registry. This is written to at initialization time only, and is read only
// after initialization.
var DefaultMetrics = NewMetrics()
// MetricDescriptor is the data for a registered metric.
type MetricDescriptor struct {
// The name of this metric. This name must be unique across the whole binary
// (including any per call metrics). See
// https://github.com/grpc/proposal/blob/master/A79-non-per-call-metrics-architecture.md#metric-instrument-naming-conventions
// for metric naming conventions.
Name Metric
// The description of this metric.
Description string
// The unit (e.g. entries, seconds) of this metric.
Unit string
// The required label keys for this metric. These are intended to
// metrics emitted from a stats handler.
Labels []string
// The optional label keys for this metric. These are intended to attached
// to metrics emitted from a stats handler if configured.
OptionalLabels []string
// Whether this metric is on by default.
Default bool
// The type of metric. This is set by the metric registry, and not intended
// to be set by a component registering a metric.
Type MetricType
// Bounds are the bounds of this metric. This only applies to histogram
// metrics. If unset or set with length 0, stats handlers will fall back to
// default bounds.
Bounds []float64
}
// MetricType is the type of metric.
type MetricType int
// Type of metric supported by this instrument registry.
const (
MetricTypeIntCount MetricType = iota
MetricTypeFloatCount
MetricTypeIntHisto
MetricTypeFloatHisto
MetricTypeIntGauge
)
// Int64CountHandle is a typed handle for a int count metric. This handle
// is passed at the recording point in order to know which metric to record
// on.
type Int64CountHandle MetricDescriptor
// Descriptor returns the int64 count handle typecast to a pointer to a
// MetricDescriptor.
func (h *Int64CountHandle) Descriptor() *MetricDescriptor {
return (*MetricDescriptor)(h)
}
// Record records the int64 count value on the metrics recorder provided.
func (h *Int64CountHandle) Record(recorder MetricsRecorder, incr int64, labels ...string) {
recorder.RecordInt64Count(h, incr, labels...)
}
// Float64CountHandle is a typed handle for a float count metric. This handle is
// passed at the recording point in order to know which metric to record on.
type Float64CountHandle MetricDescriptor
// Descriptor returns the float64 count handle typecast to a pointer to a
// MetricDescriptor.
func (h *Float64CountHandle) Descriptor() *MetricDescriptor {
return (*MetricDescriptor)(h)
}
// Record records the float64 count value on the metrics recorder provided.
func (h *Float64CountHandle) Record(recorder MetricsRecorder, incr float64, labels ...string) {
recorder.RecordFloat64Count(h, incr, labels...)
}
// Int64HistoHandle is a typed handle for an int histogram metric. This handle
// is passed at the recording point in order to know which metric to record on.
type Int64HistoHandle MetricDescriptor
// Descriptor returns the int64 histo handle typecast to a pointer to a
// MetricDescriptor.
func (h *Int64HistoHandle) Descriptor() *MetricDescriptor {
return (*MetricDescriptor)(h)
}
// Record records the int64 histo value on the metrics recorder provided.
func (h *Int64HistoHandle) Record(recorder MetricsRecorder, incr int64, labels ...string) {
recorder.RecordInt64Histo(h, incr, labels...)
}
// Float64HistoHandle is a typed handle for a float histogram metric. This
// handle is passed at the recording point in order to know which metric to
// record on.
type Float64HistoHandle MetricDescriptor
// Descriptor returns the float64 histo handle typecast to a pointer to a
// MetricDescriptor.
func (h *Float64HistoHandle) Descriptor() *MetricDescriptor {
return (*MetricDescriptor)(h)
}
// Record records the float64 histo value on the metrics recorder provided.
func (h *Float64HistoHandle) Record(recorder MetricsRecorder, incr float64, labels ...string) {
recorder.RecordFloat64Histo(h, incr, labels...)
}
// Int64GaugeHandle is a typed handle for an int gauge metric. This handle is
// passed at the recording point in order to know which metric to record on.
type Int64GaugeHandle MetricDescriptor
// Descriptor returns the int64 gauge handle typecast to a pointer to a
// MetricDescriptor.
func (h *Int64GaugeHandle) Descriptor() *MetricDescriptor {
return (*MetricDescriptor)(h)
}
// Record records the int64 histo value on the metrics recorder provided.
func (h *Int64GaugeHandle) Record(recorder MetricsRecorder, incr int64, labels ...string) {
recorder.RecordInt64Gauge(h, incr, labels...)
}
// registeredMetrics are the registered metric descriptor names.
var registeredMetrics = make(map[Metric]bool)
// metricsRegistry contains all of the registered metrics.
//
// This is written to only at init time, and read only after that.
var metricsRegistry = make(map[Metric]*MetricDescriptor)
// DescriptorForMetric returns the MetricDescriptor from the global registry.
//
// Returns nil if MetricDescriptor not present.
func DescriptorForMetric(metric Metric) *MetricDescriptor {
return metricsRegistry[metric]
}
func registerMetric(name Metric, def bool) {
if registeredMetrics[name] {
logger.Fatalf("metric %v already registered", name)
}
registeredMetrics[name] = true
if def {
DefaultMetrics = DefaultMetrics.Add(name)
}
}
// RegisterInt64Count registers the metric description onto the global registry.
// It returns a typed handle to use to recording data.
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple metrics are
// registered with the same name, this function will panic.
func RegisterInt64Count(descriptor MetricDescriptor) *Int64CountHandle {
registerMetric(descriptor.Name, descriptor.Default)
descriptor.Type = MetricTypeIntCount
descPtr := &descriptor
metricsRegistry[descriptor.Name] = descPtr
return (*Int64CountHandle)(descPtr)
}
// RegisterFloat64Count registers the metric description onto the global
// registry. It returns a typed handle to use to recording data.
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple metrics are
// registered with the same name, this function will panic.
func RegisterFloat64Count(descriptor MetricDescriptor) *Float64CountHandle {
registerMetric(descriptor.Name, descriptor.Default)
descriptor.Type = MetricTypeFloatCount
descPtr := &descriptor
metricsRegistry[descriptor.Name] = descPtr
return (*Float64CountHandle)(descPtr)
}
// RegisterInt64Histo registers the metric description onto the global registry.
// It returns a typed handle to use to recording data.
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple metrics are
// registered with the same name, this function will panic.
func RegisterInt64Histo(descriptor MetricDescriptor) *Int64HistoHandle {
registerMetric(descriptor.Name, descriptor.Default)
descriptor.Type = MetricTypeIntHisto
descPtr := &descriptor
metricsRegistry[descriptor.Name] = descPtr
return (*Int64HistoHandle)(descPtr)
}
// RegisterFloat64Histo registers the metric description onto the global
// registry. It returns a typed handle to use to recording data.
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple metrics are
// registered with the same name, this function will panic.
func RegisterFloat64Histo(descriptor MetricDescriptor) *Float64HistoHandle {
registerMetric(descriptor.Name, descriptor.Default)
descriptor.Type = MetricTypeFloatHisto
descPtr := &descriptor
metricsRegistry[descriptor.Name] = descPtr
return (*Float64HistoHandle)(descPtr)
}
// RegisterInt64Gauge registers the metric description onto the global registry.
// It returns a typed handle to use to recording data.
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple metrics are
// registered with the same name, this function will panic.
func RegisterInt64Gauge(descriptor MetricDescriptor) *Int64GaugeHandle {
registerMetric(descriptor.Name, descriptor.Default)
descriptor.Type = MetricTypeIntGauge
descPtr := &descriptor
metricsRegistry[descriptor.Name] = descPtr
return (*Int64GaugeHandle)(descPtr)
}
// snapshotMetricsRegistryForTesting snapshots the global data of the metrics
// registry. Registers a cleanup function on the provided testing.T that sets
// the metrics registry to its original state. Only called in testing functions.
func snapshotMetricsRegistryForTesting(t *testing.T) {
oldDefaultMetrics := DefaultMetrics
oldRegisteredMetrics := registeredMetrics
oldMetricsRegistry := metricsRegistry
registeredMetrics = make(map[Metric]bool)
metricsRegistry = make(map[Metric]*MetricDescriptor)
maps.Copy(registeredMetrics, registeredMetrics)
maps.Copy(metricsRegistry, metricsRegistry)
t.Cleanup(func() {
DefaultMetrics = oldDefaultMetrics
registeredMetrics = oldRegisteredMetrics
metricsRegistry = oldMetricsRegistry
})
}

View File

@ -0,0 +1,114 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package stats contains experimental metrics/stats API's.
package stats
import "maps"
// MetricsRecorder records on metrics derived from metric registry.
type MetricsRecorder interface {
// RecordInt64Count records the measurement alongside labels on the int
// count associated with the provided handle.
RecordInt64Count(handle *Int64CountHandle, incr int64, labels ...string)
// RecordFloat64Count records the measurement alongside labels on the float
// count associated with the provided handle.
RecordFloat64Count(handle *Float64CountHandle, incr float64, labels ...string)
// RecordInt64Histo records the measurement alongside labels on the int
// histo associated with the provided handle.
RecordInt64Histo(handle *Int64HistoHandle, incr int64, labels ...string)
// RecordFloat64Histo records the measurement alongside labels on the float
// histo associated with the provided handle.
RecordFloat64Histo(handle *Float64HistoHandle, incr float64, labels ...string)
// RecordInt64Gauge records the measurement alongside labels on the int
// gauge associated with the provided handle.
RecordInt64Gauge(handle *Int64GaugeHandle, incr int64, labels ...string)
}
// Metric is an identifier for a metric.
type Metric string
// Metrics is a set of metrics to record. Once created, Metrics is immutable,
// however Add and Remove can make copies with specific metrics added or
// removed, respectively.
//
// Do not construct directly; use NewMetrics instead.
type Metrics struct {
// metrics are the set of metrics to initialize.
metrics map[Metric]bool
}
// NewMetrics returns a Metrics containing Metrics.
func NewMetrics(metrics ...Metric) *Metrics {
newMetrics := make(map[Metric]bool)
for _, metric := range metrics {
newMetrics[metric] = true
}
return &Metrics{
metrics: newMetrics,
}
}
// Metrics returns the metrics set. The returned map is read-only and must not
// be modified.
func (m *Metrics) Metrics() map[Metric]bool {
return m.metrics
}
// Add adds the metrics to the metrics set and returns a new copy with the
// additional metrics.
func (m *Metrics) Add(metrics ...Metric) *Metrics {
newMetrics := make(map[Metric]bool)
for metric := range m.metrics {
newMetrics[metric] = true
}
for _, metric := range metrics {
newMetrics[metric] = true
}
return &Metrics{
metrics: newMetrics,
}
}
// Join joins the metrics passed in with the metrics set, and returns a new copy
// with the merged metrics.
func (m *Metrics) Join(metrics *Metrics) *Metrics {
newMetrics := make(map[Metric]bool)
maps.Copy(newMetrics, m.metrics)
maps.Copy(newMetrics, metrics.metrics)
return &Metrics{
metrics: newMetrics,
}
}
// Remove removes the metrics from the metrics set and returns a new copy with
// the metrics removed.
func (m *Metrics) Remove(metrics ...Metric) *Metrics {
newMetrics := make(map[Metric]bool)
for metric := range m.metrics {
newMetrics[metric] = true
}
for _, metric := range metrics {
delete(newMetrics, metric)
}
return &Metrics{
metrics: newMetrics,
}
}

View File

@ -20,8 +20,6 @@ package grpclog
import ( import (
"fmt" "fmt"
"google.golang.org/grpc/internal/grpclog"
) )
// componentData records the settings for a component. // componentData records the settings for a component.
@ -33,22 +31,22 @@ var cache = map[string]*componentData{}
func (c *componentData) InfoDepth(depth int, args ...any) { func (c *componentData) InfoDepth(depth int, args ...any) {
args = append([]any{"[" + string(c.name) + "]"}, args...) args = append([]any{"[" + string(c.name) + "]"}, args...)
grpclog.InfoDepth(depth+1, args...) InfoDepth(depth+1, args...)
} }
func (c *componentData) WarningDepth(depth int, args ...any) { func (c *componentData) WarningDepth(depth int, args ...any) {
args = append([]any{"[" + string(c.name) + "]"}, args...) args = append([]any{"[" + string(c.name) + "]"}, args...)
grpclog.WarningDepth(depth+1, args...) WarningDepth(depth+1, args...)
} }
func (c *componentData) ErrorDepth(depth int, args ...any) { func (c *componentData) ErrorDepth(depth int, args ...any) {
args = append([]any{"[" + string(c.name) + "]"}, args...) args = append([]any{"[" + string(c.name) + "]"}, args...)
grpclog.ErrorDepth(depth+1, args...) ErrorDepth(depth+1, args...)
} }
func (c *componentData) FatalDepth(depth int, args ...any) { func (c *componentData) FatalDepth(depth int, args ...any) {
args = append([]any{"[" + string(c.name) + "]"}, args...) args = append([]any{"[" + string(c.name) + "]"}, args...)
grpclog.FatalDepth(depth+1, args...) FatalDepth(depth+1, args...)
} }
func (c *componentData) Info(args ...any) { func (c *componentData) Info(args ...any) {

View File

@ -18,18 +18,15 @@
// Package grpclog defines logging for grpc. // Package grpclog defines logging for grpc.
// //
// All logs in transport and grpclb packages only go to verbose level 2. // In the default logger, severity level can be set by environment variable
// All logs in other packages in grpc are logged in spite of the verbosity level. // GRPC_GO_LOG_SEVERITY_LEVEL, verbosity level can be set by
// // GRPC_GO_LOG_VERBOSITY_LEVEL.
// In the default logger, package grpclog
// severity level can be set by environment variable GRPC_GO_LOG_SEVERITY_LEVEL,
// verbosity level can be set by GRPC_GO_LOG_VERBOSITY_LEVEL.
package grpclog // import "google.golang.org/grpc/grpclog"
import ( import (
"os" "os"
"google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/grpclog/internal"
) )
func init() { func init() {
@ -38,58 +35,58 @@ func init() {
// V reports whether verbosity level l is at least the requested verbose level. // V reports whether verbosity level l is at least the requested verbose level.
func V(l int) bool { func V(l int) bool {
return grpclog.Logger.V(l) return internal.LoggerV2Impl.V(l)
} }
// Info logs to the INFO log. // Info logs to the INFO log.
func Info(args ...any) { func Info(args ...any) {
grpclog.Logger.Info(args...) internal.LoggerV2Impl.Info(args...)
} }
// Infof logs to the INFO log. Arguments are handled in the manner of fmt.Printf. // Infof logs to the INFO log. Arguments are handled in the manner of fmt.Printf.
func Infof(format string, args ...any) { func Infof(format string, args ...any) {
grpclog.Logger.Infof(format, args...) internal.LoggerV2Impl.Infof(format, args...)
} }
// Infoln logs to the INFO log. Arguments are handled in the manner of fmt.Println. // Infoln logs to the INFO log. Arguments are handled in the manner of fmt.Println.
func Infoln(args ...any) { func Infoln(args ...any) {
grpclog.Logger.Infoln(args...) internal.LoggerV2Impl.Infoln(args...)
} }
// Warning logs to the WARNING log. // Warning logs to the WARNING log.
func Warning(args ...any) { func Warning(args ...any) {
grpclog.Logger.Warning(args...) internal.LoggerV2Impl.Warning(args...)
} }
// Warningf logs to the WARNING log. Arguments are handled in the manner of fmt.Printf. // Warningf logs to the WARNING log. Arguments are handled in the manner of fmt.Printf.
func Warningf(format string, args ...any) { func Warningf(format string, args ...any) {
grpclog.Logger.Warningf(format, args...) internal.LoggerV2Impl.Warningf(format, args...)
} }
// Warningln logs to the WARNING log. Arguments are handled in the manner of fmt.Println. // Warningln logs to the WARNING log. Arguments are handled in the manner of fmt.Println.
func Warningln(args ...any) { func Warningln(args ...any) {
grpclog.Logger.Warningln(args...) internal.LoggerV2Impl.Warningln(args...)
} }
// Error logs to the ERROR log. // Error logs to the ERROR log.
func Error(args ...any) { func Error(args ...any) {
grpclog.Logger.Error(args...) internal.LoggerV2Impl.Error(args...)
} }
// Errorf logs to the ERROR log. Arguments are handled in the manner of fmt.Printf. // Errorf logs to the ERROR log. Arguments are handled in the manner of fmt.Printf.
func Errorf(format string, args ...any) { func Errorf(format string, args ...any) {
grpclog.Logger.Errorf(format, args...) internal.LoggerV2Impl.Errorf(format, args...)
} }
// Errorln logs to the ERROR log. Arguments are handled in the manner of fmt.Println. // Errorln logs to the ERROR log. Arguments are handled in the manner of fmt.Println.
func Errorln(args ...any) { func Errorln(args ...any) {
grpclog.Logger.Errorln(args...) internal.LoggerV2Impl.Errorln(args...)
} }
// Fatal logs to the FATAL log. Arguments are handled in the manner of fmt.Print. // Fatal logs to the FATAL log. Arguments are handled in the manner of fmt.Print.
// It calls os.Exit() with exit code 1. // It calls os.Exit() with exit code 1.
func Fatal(args ...any) { func Fatal(args ...any) {
grpclog.Logger.Fatal(args...) internal.LoggerV2Impl.Fatal(args...)
// Make sure fatal logs will exit. // Make sure fatal logs will exit.
os.Exit(1) os.Exit(1)
} }
@ -97,15 +94,15 @@ func Fatal(args ...any) {
// Fatalf logs to the FATAL log. Arguments are handled in the manner of fmt.Printf. // Fatalf logs to the FATAL log. Arguments are handled in the manner of fmt.Printf.
// It calls os.Exit() with exit code 1. // It calls os.Exit() with exit code 1.
func Fatalf(format string, args ...any) { func Fatalf(format string, args ...any) {
grpclog.Logger.Fatalf(format, args...) internal.LoggerV2Impl.Fatalf(format, args...)
// Make sure fatal logs will exit. // Make sure fatal logs will exit.
os.Exit(1) os.Exit(1)
} }
// Fatalln logs to the FATAL log. Arguments are handled in the manner of fmt.Println. // Fatalln logs to the FATAL log. Arguments are handled in the manner of fmt.Println.
// It calle os.Exit()) with exit code 1. // It calls os.Exit() with exit code 1.
func Fatalln(args ...any) { func Fatalln(args ...any) {
grpclog.Logger.Fatalln(args...) internal.LoggerV2Impl.Fatalln(args...)
// Make sure fatal logs will exit. // Make sure fatal logs will exit.
os.Exit(1) os.Exit(1)
} }
@ -114,19 +111,76 @@ func Fatalln(args ...any) {
// //
// Deprecated: use Info. // Deprecated: use Info.
func Print(args ...any) { func Print(args ...any) {
grpclog.Logger.Info(args...) internal.LoggerV2Impl.Info(args...)
} }
// Printf prints to the logger. Arguments are handled in the manner of fmt.Printf. // Printf prints to the logger. Arguments are handled in the manner of fmt.Printf.
// //
// Deprecated: use Infof. // Deprecated: use Infof.
func Printf(format string, args ...any) { func Printf(format string, args ...any) {
grpclog.Logger.Infof(format, args...) internal.LoggerV2Impl.Infof(format, args...)
} }
// Println prints to the logger. Arguments are handled in the manner of fmt.Println. // Println prints to the logger. Arguments are handled in the manner of fmt.Println.
// //
// Deprecated: use Infoln. // Deprecated: use Infoln.
func Println(args ...any) { func Println(args ...any) {
grpclog.Logger.Infoln(args...) internal.LoggerV2Impl.Infoln(args...)
}
// InfoDepth logs to the INFO log at the specified depth.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func InfoDepth(depth int, args ...any) {
if internal.DepthLoggerV2Impl != nil {
internal.DepthLoggerV2Impl.InfoDepth(depth, args...)
} else {
internal.LoggerV2Impl.Infoln(args...)
}
}
// WarningDepth logs to the WARNING log at the specified depth.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func WarningDepth(depth int, args ...any) {
if internal.DepthLoggerV2Impl != nil {
internal.DepthLoggerV2Impl.WarningDepth(depth, args...)
} else {
internal.LoggerV2Impl.Warningln(args...)
}
}
// ErrorDepth logs to the ERROR log at the specified depth.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func ErrorDepth(depth int, args ...any) {
if internal.DepthLoggerV2Impl != nil {
internal.DepthLoggerV2Impl.ErrorDepth(depth, args...)
} else {
internal.LoggerV2Impl.Errorln(args...)
}
}
// FatalDepth logs to the FATAL log at the specified depth.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func FatalDepth(depth int, args ...any) {
if internal.DepthLoggerV2Impl != nil {
internal.DepthLoggerV2Impl.FatalDepth(depth, args...)
} else {
internal.LoggerV2Impl.Fatalln(args...)
}
os.Exit(1)
} }

View File

@ -0,0 +1,26 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package internal contains functionality internal to the grpclog package.
package internal
// LoggerV2Impl is the logger used for the non-depth log functions.
var LoggerV2Impl LoggerV2
// DepthLoggerV2Impl is the logger used for the depth log functions.
var DepthLoggerV2Impl DepthLoggerV2

View File

@ -0,0 +1,87 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package internal
// Logger mimics golang's standard Logger as an interface.
//
// Deprecated: use LoggerV2.
type Logger interface {
Fatal(args ...any)
Fatalf(format string, args ...any)
Fatalln(args ...any)
Print(args ...any)
Printf(format string, args ...any)
Println(args ...any)
}
// LoggerWrapper wraps Logger into a LoggerV2.
type LoggerWrapper struct {
Logger
}
// Info logs to INFO log. Arguments are handled in the manner of fmt.Print.
func (l *LoggerWrapper) Info(args ...any) {
l.Logger.Print(args...)
}
// Infoln logs to INFO log. Arguments are handled in the manner of fmt.Println.
func (l *LoggerWrapper) Infoln(args ...any) {
l.Logger.Println(args...)
}
// Infof logs to INFO log. Arguments are handled in the manner of fmt.Printf.
func (l *LoggerWrapper) Infof(format string, args ...any) {
l.Logger.Printf(format, args...)
}
// Warning logs to WARNING log. Arguments are handled in the manner of fmt.Print.
func (l *LoggerWrapper) Warning(args ...any) {
l.Logger.Print(args...)
}
// Warningln logs to WARNING log. Arguments are handled in the manner of fmt.Println.
func (l *LoggerWrapper) Warningln(args ...any) {
l.Logger.Println(args...)
}
// Warningf logs to WARNING log. Arguments are handled in the manner of fmt.Printf.
func (l *LoggerWrapper) Warningf(format string, args ...any) {
l.Logger.Printf(format, args...)
}
// Error logs to ERROR log. Arguments are handled in the manner of fmt.Print.
func (l *LoggerWrapper) Error(args ...any) {
l.Logger.Print(args...)
}
// Errorln logs to ERROR log. Arguments are handled in the manner of fmt.Println.
func (l *LoggerWrapper) Errorln(args ...any) {
l.Logger.Println(args...)
}
// Errorf logs to ERROR log. Arguments are handled in the manner of fmt.Printf.
func (l *LoggerWrapper) Errorf(format string, args ...any) {
l.Logger.Printf(format, args...)
}
// V reports whether verbosity level l is at least the requested verbose level.
func (*LoggerWrapper) V(l int) bool {
// Returns true for all verbose level.
return true
}

View File

@ -1,6 +1,6 @@
/* /*
* *
* Copyright 2020 gRPC authors. * Copyright 2024 gRPC authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,59 +16,17 @@
* *
*/ */
// Package grpclog (internal) defines depth logging for grpc. package internal
package grpclog
import ( import (
"encoding/json"
"fmt"
"io"
"log"
"os" "os"
) )
// Logger is the logger used for the non-depth log functions.
var Logger LoggerV2
// DepthLogger is the logger used for the depth log functions.
var DepthLogger DepthLoggerV2
// InfoDepth logs to the INFO log at the specified depth.
func InfoDepth(depth int, args ...any) {
if DepthLogger != nil {
DepthLogger.InfoDepth(depth, args...)
} else {
Logger.Infoln(args...)
}
}
// WarningDepth logs to the WARNING log at the specified depth.
func WarningDepth(depth int, args ...any) {
if DepthLogger != nil {
DepthLogger.WarningDepth(depth, args...)
} else {
Logger.Warningln(args...)
}
}
// ErrorDepth logs to the ERROR log at the specified depth.
func ErrorDepth(depth int, args ...any) {
if DepthLogger != nil {
DepthLogger.ErrorDepth(depth, args...)
} else {
Logger.Errorln(args...)
}
}
// FatalDepth logs to the FATAL log at the specified depth.
func FatalDepth(depth int, args ...any) {
if DepthLogger != nil {
DepthLogger.FatalDepth(depth, args...)
} else {
Logger.Fatalln(args...)
}
os.Exit(1)
}
// LoggerV2 does underlying logging work for grpclog. // LoggerV2 does underlying logging work for grpclog.
// This is a copy of the LoggerV2 defined in the external grpclog package. It
// is defined here to avoid a circular dependency.
type LoggerV2 interface { type LoggerV2 interface {
// Info logs to INFO log. Arguments are handled in the manner of fmt.Print. // Info logs to INFO log. Arguments are handled in the manner of fmt.Print.
Info(args ...any) Info(args ...any)
@ -107,14 +65,13 @@ type LoggerV2 interface {
// DepthLoggerV2 logs at a specified call frame. If a LoggerV2 also implements // DepthLoggerV2 logs at a specified call frame. If a LoggerV2 also implements
// DepthLoggerV2, the below functions will be called with the appropriate stack // DepthLoggerV2, the below functions will be called with the appropriate stack
// depth set for trivial functions the logger may ignore. // depth set for trivial functions the logger may ignore.
// This is a copy of the DepthLoggerV2 defined in the external grpclog package.
// It is defined here to avoid a circular dependency.
// //
// # Experimental // # Experimental
// //
// Notice: This type is EXPERIMENTAL and may be changed or removed in a // Notice: This type is EXPERIMENTAL and may be changed or removed in a
// later release. // later release.
type DepthLoggerV2 interface { type DepthLoggerV2 interface {
LoggerV2
// InfoDepth logs to INFO log at the specified depth. Arguments are handled in the manner of fmt.Println. // InfoDepth logs to INFO log at the specified depth. Arguments are handled in the manner of fmt.Println.
InfoDepth(depth int, args ...any) InfoDepth(depth int, args ...any)
// WarningDepth logs to WARNING log at the specified depth. Arguments are handled in the manner of fmt.Println. // WarningDepth logs to WARNING log at the specified depth. Arguments are handled in the manner of fmt.Println.
@ -124,3 +81,124 @@ type DepthLoggerV2 interface {
// FatalDepth logs to FATAL log at the specified depth. Arguments are handled in the manner of fmt.Println. // FatalDepth logs to FATAL log at the specified depth. Arguments are handled in the manner of fmt.Println.
FatalDepth(depth int, args ...any) FatalDepth(depth int, args ...any)
} }
const (
// infoLog indicates Info severity.
infoLog int = iota
// warningLog indicates Warning severity.
warningLog
// errorLog indicates Error severity.
errorLog
// fatalLog indicates Fatal severity.
fatalLog
)
// severityName contains the string representation of each severity.
var severityName = []string{
infoLog: "INFO",
warningLog: "WARNING",
errorLog: "ERROR",
fatalLog: "FATAL",
}
// loggerT is the default logger used by grpclog.
type loggerT struct {
m []*log.Logger
v int
jsonFormat bool
}
func (g *loggerT) output(severity int, s string) {
sevStr := severityName[severity]
if !g.jsonFormat {
g.m[severity].Output(2, fmt.Sprintf("%v: %v", sevStr, s))
return
}
// TODO: we can also include the logging component, but that needs more
// (API) changes.
b, _ := json.Marshal(map[string]string{
"severity": sevStr,
"message": s,
})
g.m[severity].Output(2, string(b))
}
func (g *loggerT) Info(args ...any) {
g.output(infoLog, fmt.Sprint(args...))
}
func (g *loggerT) Infoln(args ...any) {
g.output(infoLog, fmt.Sprintln(args...))
}
func (g *loggerT) Infof(format string, args ...any) {
g.output(infoLog, fmt.Sprintf(format, args...))
}
func (g *loggerT) Warning(args ...any) {
g.output(warningLog, fmt.Sprint(args...))
}
func (g *loggerT) Warningln(args ...any) {
g.output(warningLog, fmt.Sprintln(args...))
}
func (g *loggerT) Warningf(format string, args ...any) {
g.output(warningLog, fmt.Sprintf(format, args...))
}
func (g *loggerT) Error(args ...any) {
g.output(errorLog, fmt.Sprint(args...))
}
func (g *loggerT) Errorln(args ...any) {
g.output(errorLog, fmt.Sprintln(args...))
}
func (g *loggerT) Errorf(format string, args ...any) {
g.output(errorLog, fmt.Sprintf(format, args...))
}
func (g *loggerT) Fatal(args ...any) {
g.output(fatalLog, fmt.Sprint(args...))
os.Exit(1)
}
func (g *loggerT) Fatalln(args ...any) {
g.output(fatalLog, fmt.Sprintln(args...))
os.Exit(1)
}
func (g *loggerT) Fatalf(format string, args ...any) {
g.output(fatalLog, fmt.Sprintf(format, args...))
os.Exit(1)
}
func (g *loggerT) V(l int) bool {
return l <= g.v
}
// LoggerV2Config configures the LoggerV2 implementation.
type LoggerV2Config struct {
// Verbosity sets the verbosity level of the logger.
Verbosity int
// FormatJSON controls whether the logger should output logs in JSON format.
FormatJSON bool
}
// NewLoggerV2 creates a new LoggerV2 instance with the provided configuration.
// The infoW, warningW, and errorW writers are used to write log messages of
// different severity levels.
func NewLoggerV2(infoW, warningW, errorW io.Writer, c LoggerV2Config) LoggerV2 {
var m []*log.Logger
flag := log.LstdFlags
if c.FormatJSON {
flag = 0
}
m = append(m, log.New(infoW, "", flag))
m = append(m, log.New(io.MultiWriter(infoW, warningW), "", flag))
ew := io.MultiWriter(infoW, warningW, errorW) // ew will be used for error and fatal.
m = append(m, log.New(ew, "", flag))
m = append(m, log.New(ew, "", flag))
return &loggerT{m: m, v: c.Verbosity, jsonFormat: c.FormatJSON}
}

View File

@ -18,70 +18,17 @@
package grpclog package grpclog
import "google.golang.org/grpc/internal/grpclog" import "google.golang.org/grpc/grpclog/internal"
// Logger mimics golang's standard Logger as an interface. // Logger mimics golang's standard Logger as an interface.
// //
// Deprecated: use LoggerV2. // Deprecated: use LoggerV2.
type Logger interface { type Logger internal.Logger
Fatal(args ...any)
Fatalf(format string, args ...any)
Fatalln(args ...any)
Print(args ...any)
Printf(format string, args ...any)
Println(args ...any)
}
// SetLogger sets the logger that is used in grpc. Call only from // SetLogger sets the logger that is used in grpc. Call only from
// init() functions. // init() functions.
// //
// Deprecated: use SetLoggerV2. // Deprecated: use SetLoggerV2.
func SetLogger(l Logger) { func SetLogger(l Logger) {
grpclog.Logger = &loggerWrapper{Logger: l} internal.LoggerV2Impl = &internal.LoggerWrapper{Logger: l}
}
// loggerWrapper wraps Logger into a LoggerV2.
type loggerWrapper struct {
Logger
}
func (g *loggerWrapper) Info(args ...any) {
g.Logger.Print(args...)
}
func (g *loggerWrapper) Infoln(args ...any) {
g.Logger.Println(args...)
}
func (g *loggerWrapper) Infof(format string, args ...any) {
g.Logger.Printf(format, args...)
}
func (g *loggerWrapper) Warning(args ...any) {
g.Logger.Print(args...)
}
func (g *loggerWrapper) Warningln(args ...any) {
g.Logger.Println(args...)
}
func (g *loggerWrapper) Warningf(format string, args ...any) {
g.Logger.Printf(format, args...)
}
func (g *loggerWrapper) Error(args ...any) {
g.Logger.Print(args...)
}
func (g *loggerWrapper) Errorln(args ...any) {
g.Logger.Println(args...)
}
func (g *loggerWrapper) Errorf(format string, args ...any) {
g.Logger.Printf(format, args...)
}
func (g *loggerWrapper) V(l int) bool {
// Returns true for all verbose level.
return true
} }

View File

@ -19,52 +19,16 @@
package grpclog package grpclog
import ( import (
"encoding/json"
"fmt"
"io" "io"
"log"
"os" "os"
"strconv" "strconv"
"strings" "strings"
"google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/grpclog/internal"
) )
// LoggerV2 does underlying logging work for grpclog. // LoggerV2 does underlying logging work for grpclog.
type LoggerV2 interface { type LoggerV2 internal.LoggerV2
// Info logs to INFO log. Arguments are handled in the manner of fmt.Print.
Info(args ...any)
// Infoln logs to INFO log. Arguments are handled in the manner of fmt.Println.
Infoln(args ...any)
// Infof logs to INFO log. Arguments are handled in the manner of fmt.Printf.
Infof(format string, args ...any)
// Warning logs to WARNING log. Arguments are handled in the manner of fmt.Print.
Warning(args ...any)
// Warningln logs to WARNING log. Arguments are handled in the manner of fmt.Println.
Warningln(args ...any)
// Warningf logs to WARNING log. Arguments are handled in the manner of fmt.Printf.
Warningf(format string, args ...any)
// Error logs to ERROR log. Arguments are handled in the manner of fmt.Print.
Error(args ...any)
// Errorln logs to ERROR log. Arguments are handled in the manner of fmt.Println.
Errorln(args ...any)
// Errorf logs to ERROR log. Arguments are handled in the manner of fmt.Printf.
Errorf(format string, args ...any)
// Fatal logs to ERROR log. Arguments are handled in the manner of fmt.Print.
// gRPC ensures that all Fatal logs will exit with os.Exit(1).
// Implementations may also call os.Exit() with a non-zero exit code.
Fatal(args ...any)
// Fatalln logs to ERROR log. Arguments are handled in the manner of fmt.Println.
// gRPC ensures that all Fatal logs will exit with os.Exit(1).
// Implementations may also call os.Exit() with a non-zero exit code.
Fatalln(args ...any)
// Fatalf logs to ERROR log. Arguments are handled in the manner of fmt.Printf.
// gRPC ensures that all Fatal logs will exit with os.Exit(1).
// Implementations may also call os.Exit() with a non-zero exit code.
Fatalf(format string, args ...any)
// V reports whether verbosity level l is at least the requested verbose level.
V(l int) bool
}
// SetLoggerV2 sets logger that is used in grpc to a V2 logger. // SetLoggerV2 sets logger that is used in grpc to a V2 logger.
// Not mutex-protected, should be called before any gRPC functions. // Not mutex-protected, should be called before any gRPC functions.
@ -72,34 +36,8 @@ func SetLoggerV2(l LoggerV2) {
if _, ok := l.(*componentData); ok { if _, ok := l.(*componentData); ok {
panic("cannot use component logger as grpclog logger") panic("cannot use component logger as grpclog logger")
} }
grpclog.Logger = l internal.LoggerV2Impl = l
grpclog.DepthLogger, _ = l.(grpclog.DepthLoggerV2) internal.DepthLoggerV2Impl, _ = l.(internal.DepthLoggerV2)
}
const (
// infoLog indicates Info severity.
infoLog int = iota
// warningLog indicates Warning severity.
warningLog
// errorLog indicates Error severity.
errorLog
// fatalLog indicates Fatal severity.
fatalLog
)
// severityName contains the string representation of each severity.
var severityName = []string{
infoLog: "INFO",
warningLog: "WARNING",
errorLog: "ERROR",
fatalLog: "FATAL",
}
// loggerT is the default logger used by grpclog.
type loggerT struct {
m []*log.Logger
v int
jsonFormat bool
} }
// NewLoggerV2 creates a loggerV2 with the provided writers. // NewLoggerV2 creates a loggerV2 with the provided writers.
@ -108,32 +46,13 @@ type loggerT struct {
// Warning logs will be written to warningW and infoW. // Warning logs will be written to warningW and infoW.
// Info logs will be written to infoW. // Info logs will be written to infoW.
func NewLoggerV2(infoW, warningW, errorW io.Writer) LoggerV2 { func NewLoggerV2(infoW, warningW, errorW io.Writer) LoggerV2 {
return newLoggerV2WithConfig(infoW, warningW, errorW, loggerV2Config{}) return internal.NewLoggerV2(infoW, warningW, errorW, internal.LoggerV2Config{})
} }
// NewLoggerV2WithVerbosity creates a loggerV2 with the provided writers and // NewLoggerV2WithVerbosity creates a loggerV2 with the provided writers and
// verbosity level. // verbosity level.
func NewLoggerV2WithVerbosity(infoW, warningW, errorW io.Writer, v int) LoggerV2 { func NewLoggerV2WithVerbosity(infoW, warningW, errorW io.Writer, v int) LoggerV2 {
return newLoggerV2WithConfig(infoW, warningW, errorW, loggerV2Config{verbose: v}) return internal.NewLoggerV2(infoW, warningW, errorW, internal.LoggerV2Config{Verbosity: v})
}
type loggerV2Config struct {
verbose int
jsonFormat bool
}
func newLoggerV2WithConfig(infoW, warningW, errorW io.Writer, c loggerV2Config) LoggerV2 {
var m []*log.Logger
flag := log.LstdFlags
if c.jsonFormat {
flag = 0
}
m = append(m, log.New(infoW, "", flag))
m = append(m, log.New(io.MultiWriter(infoW, warningW), "", flag))
ew := io.MultiWriter(infoW, warningW, errorW) // ew will be used for error and fatal.
m = append(m, log.New(ew, "", flag))
m = append(m, log.New(ew, "", flag))
return &loggerT{m: m, v: c.verbose, jsonFormat: c.jsonFormat}
} }
// newLoggerV2 creates a loggerV2 to be used as default logger. // newLoggerV2 creates a loggerV2 to be used as default logger.
@ -161,82 +80,12 @@ func newLoggerV2() LoggerV2 {
jsonFormat := strings.EqualFold(os.Getenv("GRPC_GO_LOG_FORMATTER"), "json") jsonFormat := strings.EqualFold(os.Getenv("GRPC_GO_LOG_FORMATTER"), "json")
return newLoggerV2WithConfig(infoW, warningW, errorW, loggerV2Config{ return internal.NewLoggerV2(infoW, warningW, errorW, internal.LoggerV2Config{
verbose: v, Verbosity: v,
jsonFormat: jsonFormat, FormatJSON: jsonFormat,
}) })
} }
func (g *loggerT) output(severity int, s string) {
sevStr := severityName[severity]
if !g.jsonFormat {
g.m[severity].Output(2, fmt.Sprintf("%v: %v", sevStr, s))
return
}
// TODO: we can also include the logging component, but that needs more
// (API) changes.
b, _ := json.Marshal(map[string]string{
"severity": sevStr,
"message": s,
})
g.m[severity].Output(2, string(b))
}
func (g *loggerT) Info(args ...any) {
g.output(infoLog, fmt.Sprint(args...))
}
func (g *loggerT) Infoln(args ...any) {
g.output(infoLog, fmt.Sprintln(args...))
}
func (g *loggerT) Infof(format string, args ...any) {
g.output(infoLog, fmt.Sprintf(format, args...))
}
func (g *loggerT) Warning(args ...any) {
g.output(warningLog, fmt.Sprint(args...))
}
func (g *loggerT) Warningln(args ...any) {
g.output(warningLog, fmt.Sprintln(args...))
}
func (g *loggerT) Warningf(format string, args ...any) {
g.output(warningLog, fmt.Sprintf(format, args...))
}
func (g *loggerT) Error(args ...any) {
g.output(errorLog, fmt.Sprint(args...))
}
func (g *loggerT) Errorln(args ...any) {
g.output(errorLog, fmt.Sprintln(args...))
}
func (g *loggerT) Errorf(format string, args ...any) {
g.output(errorLog, fmt.Sprintf(format, args...))
}
func (g *loggerT) Fatal(args ...any) {
g.output(fatalLog, fmt.Sprint(args...))
os.Exit(1)
}
func (g *loggerT) Fatalln(args ...any) {
g.output(fatalLog, fmt.Sprintln(args...))
os.Exit(1)
}
func (g *loggerT) Fatalf(format string, args ...any) {
g.output(fatalLog, fmt.Sprintf(format, args...))
os.Exit(1)
}
func (g *loggerT) V(l int) bool {
return l <= g.v
}
// DepthLoggerV2 logs at a specified call frame. If a LoggerV2 also implements // DepthLoggerV2 logs at a specified call frame. If a LoggerV2 also implements
// DepthLoggerV2, the below functions will be called with the appropriate stack // DepthLoggerV2, the below functions will be called with the appropriate stack
// depth set for trivial functions the logger may ignore. // depth set for trivial functions the logger may ignore.
@ -245,14 +94,4 @@ func (g *loggerT) V(l int) bool {
// //
// Notice: This type is EXPERIMENTAL and may be changed or removed in a // Notice: This type is EXPERIMENTAL and may be changed or removed in a
// later release. // later release.
type DepthLoggerV2 interface { type DepthLoggerV2 internal.DepthLoggerV2
LoggerV2
// InfoDepth logs to INFO log at the specified depth. Arguments are handled in the manner of fmt.Println.
InfoDepth(depth int, args ...any)
// WarningDepth logs to WARNING log at the specified depth. Arguments are handled in the manner of fmt.Println.
WarningDepth(depth int, args ...any)
// ErrorDepth logs to ERROR log at the specified depth. Arguments are handled in the manner of fmt.Println.
ErrorDepth(depth int, args ...any)
// FatalDepth logs to FATAL log at the specified depth. Arguments are handled in the manner of fmt.Println.
FatalDepth(depth int, args ...any)
}

View File

@ -18,7 +18,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.34.1 // protoc-gen-go v1.34.1
// protoc v4.25.2 // protoc v5.27.1
// source: grpc/health/v1/health.proto // source: grpc/health/v1/health.proto
package grpc_health_v1 package grpc_health_v1

View File

@ -17,8 +17,8 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT. // Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions: // versions:
// - protoc-gen-go-grpc v1.4.0 // - protoc-gen-go-grpc v1.5.1
// - protoc v4.25.2 // - protoc v5.27.1
// source: grpc/health/v1/health.proto // source: grpc/health/v1/health.proto
package grpc_health_v1 package grpc_health_v1
@ -32,8 +32,8 @@ import (
// This is a compile-time assertion to ensure that this generated file // This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against. // is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.62.0 or later. // Requires gRPC-Go v1.64.0 or later.
const _ = grpc.SupportPackageIsVersion8 const _ = grpc.SupportPackageIsVersion9
const ( const (
Health_Check_FullMethodName = "/grpc.health.v1.Health/Check" Health_Check_FullMethodName = "/grpc.health.v1.Health/Check"
@ -73,7 +73,7 @@ type HealthClient interface {
// should assume this method is not supported and should not retry the // should assume this method is not supported and should not retry the
// call. If the call terminates with any other status (including OK), // call. If the call terminates with any other status (including OK),
// clients should retry the call with appropriate exponential backoff. // clients should retry the call with appropriate exponential backoff.
Watch(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (Health_WatchClient, error) Watch(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[HealthCheckResponse], error)
} }
type healthClient struct { type healthClient struct {
@ -94,13 +94,13 @@ func (c *healthClient) Check(ctx context.Context, in *HealthCheckRequest, opts .
return out, nil return out, nil
} }
func (c *healthClient) Watch(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (Health_WatchClient, error) { func (c *healthClient) Watch(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[HealthCheckResponse], error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
stream, err := c.cc.NewStream(ctx, &Health_ServiceDesc.Streams[0], Health_Watch_FullMethodName, cOpts...) stream, err := c.cc.NewStream(ctx, &Health_ServiceDesc.Streams[0], Health_Watch_FullMethodName, cOpts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
x := &healthWatchClient{ClientStream: stream} x := &grpc.GenericClientStream[HealthCheckRequest, HealthCheckResponse]{ClientStream: stream}
if err := x.ClientStream.SendMsg(in); err != nil { if err := x.ClientStream.SendMsg(in); err != nil {
return nil, err return nil, err
} }
@ -110,26 +110,12 @@ func (c *healthClient) Watch(ctx context.Context, in *HealthCheckRequest, opts .
return x, nil return x, nil
} }
type Health_WatchClient interface { // This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
Recv() (*HealthCheckResponse, error) type Health_WatchClient = grpc.ServerStreamingClient[HealthCheckResponse]
grpc.ClientStream
}
type healthWatchClient struct {
grpc.ClientStream
}
func (x *healthWatchClient) Recv() (*HealthCheckResponse, error) {
m := new(HealthCheckResponse)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// HealthServer is the server API for Health service. // HealthServer is the server API for Health service.
// All implementations should embed UnimplementedHealthServer // All implementations should embed UnimplementedHealthServer
// for forward compatibility // for forward compatibility.
// //
// Health is gRPC's mechanism for checking whether a server is able to handle // Health is gRPC's mechanism for checking whether a server is able to handle
// RPCs. Its semantics are documented in // RPCs. Its semantics are documented in
@ -160,19 +146,23 @@ type HealthServer interface {
// should assume this method is not supported and should not retry the // should assume this method is not supported and should not retry the
// call. If the call terminates with any other status (including OK), // call. If the call terminates with any other status (including OK),
// clients should retry the call with appropriate exponential backoff. // clients should retry the call with appropriate exponential backoff.
Watch(*HealthCheckRequest, Health_WatchServer) error Watch(*HealthCheckRequest, grpc.ServerStreamingServer[HealthCheckResponse]) error
} }
// UnimplementedHealthServer should be embedded to have forward compatible implementations. // UnimplementedHealthServer should be embedded to have
type UnimplementedHealthServer struct { // forward compatible implementations.
} //
// NOTE: this should be embedded by value instead of pointer to avoid a nil
// pointer dereference when methods are called.
type UnimplementedHealthServer struct{}
func (UnimplementedHealthServer) Check(context.Context, *HealthCheckRequest) (*HealthCheckResponse, error) { func (UnimplementedHealthServer) Check(context.Context, *HealthCheckRequest) (*HealthCheckResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method Check not implemented") return nil, status.Errorf(codes.Unimplemented, "method Check not implemented")
} }
func (UnimplementedHealthServer) Watch(*HealthCheckRequest, Health_WatchServer) error { func (UnimplementedHealthServer) Watch(*HealthCheckRequest, grpc.ServerStreamingServer[HealthCheckResponse]) error {
return status.Errorf(codes.Unimplemented, "method Watch not implemented") return status.Errorf(codes.Unimplemented, "method Watch not implemented")
} }
func (UnimplementedHealthServer) testEmbeddedByValue() {}
// UnsafeHealthServer may be embedded to opt out of forward compatibility for this service. // UnsafeHealthServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to HealthServer will // Use of this interface is not recommended, as added methods to HealthServer will
@ -182,6 +172,13 @@ type UnsafeHealthServer interface {
} }
func RegisterHealthServer(s grpc.ServiceRegistrar, srv HealthServer) { func RegisterHealthServer(s grpc.ServiceRegistrar, srv HealthServer) {
// If the following call panics, it indicates UnimplementedHealthServer was
// embedded by pointer and is nil. This will cause panics if an
// unimplemented method is ever invoked, so we test this at initialization
// time to prevent it from happening at runtime later due to I/O.
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
t.testEmbeddedByValue()
}
s.RegisterService(&Health_ServiceDesc, srv) s.RegisterService(&Health_ServiceDesc, srv)
} }
@ -208,21 +205,11 @@ func _Health_Watch_Handler(srv interface{}, stream grpc.ServerStream) error {
if err := stream.RecvMsg(m); err != nil { if err := stream.RecvMsg(m); err != nil {
return err return err
} }
return srv.(HealthServer).Watch(m, &healthWatchServer{ServerStream: stream}) return srv.(HealthServer).Watch(m, &grpc.GenericServerStream[HealthCheckRequest, HealthCheckResponse]{ServerStream: stream})
} }
type Health_WatchServer interface { // This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
Send(*HealthCheckResponse) error type Health_WatchServer = grpc.ServerStreamingServer[HealthCheckResponse]
grpc.ServerStream
}
type healthWatchServer struct {
grpc.ServerStream
}
func (x *healthWatchServer) Send(m *HealthCheckResponse) error {
return x.ServerStream.SendMsg(m)
}
// Health_ServiceDesc is the grpc.ServiceDesc for Health service. // Health_ServiceDesc is the grpc.ServiceDesc for Health service.
// It's only intended for direct use with grpc.RegisterService, // It's only intended for direct use with grpc.RegisterService,

View File

@ -46,7 +46,7 @@ type entry interface {
// channelMap is the storage data structure for channelz. // channelMap is the storage data structure for channelz.
// //
// Methods of channelMap can be divided in two two categories with respect to // Methods of channelMap can be divided into two categories with respect to
// locking. // locking.
// //
// 1. Methods acquire the global lock. // 1. Methods acquire the global lock.

View File

@ -46,6 +46,10 @@ var (
// by setting the environment variable "GRPC_ENFORCE_ALPN_ENABLED" to "true" // by setting the environment variable "GRPC_ENFORCE_ALPN_ENABLED" to "true"
// or "false". // or "false".
EnforceALPNEnabled = boolFromEnv("GRPC_ENFORCE_ALPN_ENABLED", false) EnforceALPNEnabled = boolFromEnv("GRPC_ENFORCE_ALPN_ENABLED", false)
// XDSFallbackSupport is the env variable that controls whether support for
// xDS fallback is turned on. If this is unset or is false, only the first
// xDS server in the list of server configs will be used.
XDSFallbackSupport = boolFromEnv("GRPC_EXPERIMENTAL_XDS_FALLBACK", false)
) )
func boolFromEnv(envVar string, def bool) bool { func boolFromEnv(envVar string, def bool) bool {

View File

@ -18,11 +18,11 @@
package internal package internal
var ( var (
// WithRecvBufferPool is implemented by the grpc package and returns a dial // WithBufferPool is implemented by the grpc package and returns a dial
// option to configure a shared buffer pool for a grpc.ClientConn. // option to configure a shared buffer pool for a grpc.ClientConn.
WithRecvBufferPool any // func (grpc.SharedBufferPool) grpc.DialOption WithBufferPool any // func (grpc.SharedBufferPool) grpc.DialOption
// RecvBufferPool is implemented by the grpc package and returns a server // BufferPool is implemented by the grpc package and returns a server
// option to configure a shared buffer pool for a grpc.Server. // option to configure a shared buffer pool for a grpc.Server.
RecvBufferPool any // func (grpc.SharedBufferPool) grpc.ServerOption BufferPool any // func (grpc.SharedBufferPool) grpc.ServerOption
) )

View File

@ -16,17 +16,21 @@
* *
*/ */
// Package grpclog provides logging functionality for internal gRPC packages,
// outside of the functionality provided by the external `grpclog` package.
package grpclog package grpclog
import ( import (
"fmt" "fmt"
"google.golang.org/grpc/grpclog"
) )
// PrefixLogger does logging with a prefix. // PrefixLogger does logging with a prefix.
// //
// Logging method on a nil logs without any prefix. // Logging method on a nil logs without any prefix.
type PrefixLogger struct { type PrefixLogger struct {
logger DepthLoggerV2 logger grpclog.DepthLoggerV2
prefix string prefix string
} }
@ -38,7 +42,7 @@ func (pl *PrefixLogger) Infof(format string, args ...any) {
pl.logger.InfoDepth(1, fmt.Sprintf(format, args...)) pl.logger.InfoDepth(1, fmt.Sprintf(format, args...))
return return
} }
InfoDepth(1, fmt.Sprintf(format, args...)) grpclog.InfoDepth(1, fmt.Sprintf(format, args...))
} }
// Warningf does warning logging. // Warningf does warning logging.
@ -48,7 +52,7 @@ func (pl *PrefixLogger) Warningf(format string, args ...any) {
pl.logger.WarningDepth(1, fmt.Sprintf(format, args...)) pl.logger.WarningDepth(1, fmt.Sprintf(format, args...))
return return
} }
WarningDepth(1, fmt.Sprintf(format, args...)) grpclog.WarningDepth(1, fmt.Sprintf(format, args...))
} }
// Errorf does error logging. // Errorf does error logging.
@ -58,36 +62,18 @@ func (pl *PrefixLogger) Errorf(format string, args ...any) {
pl.logger.ErrorDepth(1, fmt.Sprintf(format, args...)) pl.logger.ErrorDepth(1, fmt.Sprintf(format, args...))
return return
} }
ErrorDepth(1, fmt.Sprintf(format, args...)) grpclog.ErrorDepth(1, fmt.Sprintf(format, args...))
}
// Debugf does info logging at verbose level 2.
func (pl *PrefixLogger) Debugf(format string, args ...any) {
// TODO(6044): Refactor interfaces LoggerV2 and DepthLogger, and maybe
// rewrite PrefixLogger a little to ensure that we don't use the global
// `Logger` here, and instead use the `logger` field.
if !Logger.V(2) {
return
}
if pl != nil {
// Handle nil, so the tests can pass in a nil logger.
format = pl.prefix + format
pl.logger.InfoDepth(1, fmt.Sprintf(format, args...))
return
}
InfoDepth(1, fmt.Sprintf(format, args...))
} }
// V reports whether verbosity level l is at least the requested verbose level. // V reports whether verbosity level l is at least the requested verbose level.
func (pl *PrefixLogger) V(l int) bool { func (pl *PrefixLogger) V(l int) bool {
// TODO(6044): Refactor interfaces LoggerV2 and DepthLogger, and maybe if pl != nil {
// rewrite PrefixLogger a little to ensure that we don't use the global return pl.logger.V(l)
// `Logger` here, and instead use the `logger` field. }
return Logger.V(l) return true
} }
// NewPrefixLogger creates a prefix logger with the given prefix. // NewPrefixLogger creates a prefix logger with the given prefix.
func NewPrefixLogger(logger DepthLoggerV2, prefix string) *PrefixLogger { func NewPrefixLogger(logger grpclog.DepthLoggerV2, prefix string) *PrefixLogger {
return &PrefixLogger{logger: logger, prefix: prefix} return &PrefixLogger{logger: logger, prefix: prefix}
} }

View File

@ -53,16 +53,28 @@ func NewCallbackSerializer(ctx context.Context) *CallbackSerializer {
return cs return cs
} }
// Schedule adds a callback to be scheduled after existing callbacks are run. // TrySchedule tries to schedules the provided callback function f to be
// executed in the order it was added. This is a best-effort operation. If the
// context passed to NewCallbackSerializer was canceled before this method is
// called, the callback will not be scheduled.
// //
// Callbacks are expected to honor the context when performing any blocking // Callbacks are expected to honor the context when performing any blocking
// operations, and should return early when the context is canceled. // operations, and should return early when the context is canceled.
func (cs *CallbackSerializer) TrySchedule(f func(ctx context.Context)) {
cs.callbacks.Put(f)
}
// ScheduleOr schedules the provided callback function f to be executed in the
// order it was added. If the context passed to NewCallbackSerializer has been
// canceled before this method is called, the onFailure callback will be
// executed inline instead.
// //
// Return value indicates if the callback was successfully added to the list of // Callbacks are expected to honor the context when performing any blocking
// callbacks to be executed by the serializer. It is not possible to add // operations, and should return early when the context is canceled.
// callbacks once the context passed to NewCallbackSerializer is cancelled. func (cs *CallbackSerializer) ScheduleOr(f func(ctx context.Context), onFailure func()) {
func (cs *CallbackSerializer) Schedule(f func(ctx context.Context)) bool { if cs.callbacks.Put(f) != nil {
return cs.callbacks.Put(f) == nil onFailure()
}
} }
func (cs *CallbackSerializer) run(ctx context.Context) { func (cs *CallbackSerializer) run(ctx context.Context) {

View File

@ -77,7 +77,7 @@ func (ps *PubSub) Subscribe(sub Subscriber) (cancel func()) {
if ps.msg != nil { if ps.msg != nil {
msg := ps.msg msg := ps.msg
ps.cs.Schedule(func(context.Context) { ps.cs.TrySchedule(func(context.Context) {
ps.mu.Lock() ps.mu.Lock()
defer ps.mu.Unlock() defer ps.mu.Unlock()
if !ps.subscribers[sub] { if !ps.subscribers[sub] {
@ -103,7 +103,7 @@ func (ps *PubSub) Publish(msg any) {
ps.msg = msg ps.msg = msg
for sub := range ps.subscribers { for sub := range ps.subscribers {
s := sub s := sub
ps.cs.Schedule(func(context.Context) { ps.cs.TrySchedule(func(context.Context) {
ps.mu.Lock() ps.mu.Lock()
defer ps.mu.Unlock() defer ps.mu.Unlock()
if !ps.subscribers[s] { if !ps.subscribers[s] {

View File

@ -208,6 +208,27 @@ var (
// ShuffleAddressListForTesting pseudo-randomizes the order of addresses. n // ShuffleAddressListForTesting pseudo-randomizes the order of addresses. n
// is the number of elements. swap swaps the elements with indexes i and j. // is the number of elements. swap swaps the elements with indexes i and j.
ShuffleAddressListForTesting any // func(n int, swap func(i, j int)) ShuffleAddressListForTesting any // func(n int, swap func(i, j int))
// ConnectedAddress returns the connected address for a SubConnState. The
// address is only valid if the state is READY.
ConnectedAddress any // func (scs SubConnState) resolver.Address
// SetConnectedAddress sets the connected address for a SubConnState.
SetConnectedAddress any // func(scs *SubConnState, addr resolver.Address)
// SnapshotMetricRegistryForTesting snapshots the global data of the metric
// registry. Registers a cleanup function on the provided testing.T that
// sets the metric registry to its original state. Only called in testing
// functions.
SnapshotMetricRegistryForTesting any // func(t *testing.T)
// SetDefaultBufferPoolForTesting updates the default buffer pool, for
// testing purposes.
SetDefaultBufferPoolForTesting any // func(mem.BufferPool)
// SetBufferPoolingThresholdForTesting updates the buffer pooling threshold, for
// testing purposes.
SetBufferPoolingThresholdForTesting any // func(int)
) )
// HealthChecker defines the signature of the client-side LB channel health // HealthChecker defines the signature of the client-side LB channel health

42
vendor/google.golang.org/grpc/internal/stats/labels.go generated vendored Normal file
View File

@ -0,0 +1,42 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package stats provides internal stats related functionality.
package stats
import "context"
// Labels are the labels for metrics.
type Labels struct {
// TelemetryLabels are the telemetry labels to record.
TelemetryLabels map[string]string
}
type labelsKey struct{}
// GetLabels returns the Labels stored in the context, or nil if there is one.
func GetLabels(ctx context.Context) *Labels {
labels, _ := ctx.Value(labelsKey{}).(*Labels)
return labels
}
// SetLabels sets the Labels in the context.
func SetLabels(ctx context.Context, labels *Labels) context.Context {
// could also append
return context.WithValue(ctx, labelsKey{}, labels)
}

View File

@ -0,0 +1,95 @@
/*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package stats
import (
"fmt"
estats "google.golang.org/grpc/experimental/stats"
"google.golang.org/grpc/stats"
)
// MetricsRecorderList forwards Record calls to all of its metricsRecorders.
//
// It eats any record calls where the label values provided do not match the
// number of label keys.
type MetricsRecorderList struct {
// metricsRecorders are the metrics recorders this list will forward to.
metricsRecorders []estats.MetricsRecorder
}
// NewMetricsRecorderList creates a new metric recorder list with all the stats
// handlers provided which implement the MetricsRecorder interface.
// If no stats handlers provided implement the MetricsRecorder interface,
// the MetricsRecorder list returned is a no-op.
func NewMetricsRecorderList(shs []stats.Handler) *MetricsRecorderList {
var mrs []estats.MetricsRecorder
for _, sh := range shs {
if mr, ok := sh.(estats.MetricsRecorder); ok {
mrs = append(mrs, mr)
}
}
return &MetricsRecorderList{
metricsRecorders: mrs,
}
}
func verifyLabels(desc *estats.MetricDescriptor, labelsRecv ...string) {
if got, want := len(labelsRecv), len(desc.Labels)+len(desc.OptionalLabels); got != want {
panic(fmt.Sprintf("Received %d labels in call to record metric %q, but expected %d.", got, desc.Name, want))
}
}
func (l *MetricsRecorderList) RecordInt64Count(handle *estats.Int64CountHandle, incr int64, labels ...string) {
verifyLabels(handle.Descriptor(), labels...)
for _, metricRecorder := range l.metricsRecorders {
metricRecorder.RecordInt64Count(handle, incr, labels...)
}
}
func (l *MetricsRecorderList) RecordFloat64Count(handle *estats.Float64CountHandle, incr float64, labels ...string) {
verifyLabels(handle.Descriptor(), labels...)
for _, metricRecorder := range l.metricsRecorders {
metricRecorder.RecordFloat64Count(handle, incr, labels...)
}
}
func (l *MetricsRecorderList) RecordInt64Histo(handle *estats.Int64HistoHandle, incr int64, labels ...string) {
verifyLabels(handle.Descriptor(), labels...)
for _, metricRecorder := range l.metricsRecorders {
metricRecorder.RecordInt64Histo(handle, incr, labels...)
}
}
func (l *MetricsRecorderList) RecordFloat64Histo(handle *estats.Float64HistoHandle, incr float64, labels ...string) {
verifyLabels(handle.Descriptor(), labels...)
for _, metricRecorder := range l.metricsRecorders {
metricRecorder.RecordFloat64Histo(handle, incr, labels...)
}
}
func (l *MetricsRecorderList) RecordInt64Gauge(handle *estats.Int64GaugeHandle, incr int64, labels ...string) {
verifyLabels(handle.Descriptor(), labels...)
for _, metricRecorder := range l.metricsRecorders {
metricRecorder.RecordInt64Gauge(handle, incr, labels...)
}
}

View File

@ -44,7 +44,7 @@ func NetDialerWithTCPKeepalive() *net.Dialer {
// combination of unconditionally enabling TCP keepalives here, and // combination of unconditionally enabling TCP keepalives here, and
// disabling the overriding of TCP keepalive parameters by setting the // disabling the overriding of TCP keepalive parameters by setting the
// KeepAlive field to a negative value above, results in OS defaults for // KeepAlive field to a negative value above, results in OS defaults for
// the TCP keealive interval and time parameters. // the TCP keepalive interval and time parameters.
Control: func(_, _ string, c syscall.RawConn) error { Control: func(_, _ string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) { return c.Control(func(fd uintptr) {
unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_KEEPALIVE, 1) unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_KEEPALIVE, 1)

View File

@ -44,7 +44,7 @@ func NetDialerWithTCPKeepalive() *net.Dialer {
// combination of unconditionally enabling TCP keepalives here, and // combination of unconditionally enabling TCP keepalives here, and
// disabling the overriding of TCP keepalive parameters by setting the // disabling the overriding of TCP keepalive parameters by setting the
// KeepAlive field to a negative value above, results in OS defaults for // KeepAlive field to a negative value above, results in OS defaults for
// the TCP keealive interval and time parameters. // the TCP keepalive interval and time parameters.
Control: func(_, _ string, c syscall.RawConn) error { Control: func(_, _ string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) { return c.Control(func(fd uintptr) {
windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_KEEPALIVE, 1) windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_KEEPALIVE, 1)

View File

@ -32,6 +32,7 @@ import (
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
"google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )
@ -148,9 +149,9 @@ type dataFrame struct {
streamID uint32 streamID uint32
endStream bool endStream bool
h []byte h []byte
d []byte reader mem.Reader
// onEachWrite is called every time // onEachWrite is called every time
// a part of d is written out. // a part of data is written out.
onEachWrite func() onEachWrite func()
} }
@ -289,18 +290,22 @@ func (l *outStreamList) dequeue() *outStream {
} }
// controlBuffer is a way to pass information to loopy. // controlBuffer is a way to pass information to loopy.
// Information is passed as specific struct types called control frames. //
// A control frame not only represents data, messages or headers to be sent out // Information is passed as specific struct types called control frames. A
// but can also be used to instruct loopy to update its internal state. // control frame not only represents data, messages or headers to be sent out
// It shouldn't be confused with an HTTP2 frame, although some of the control frames // but can also be used to instruct loopy to update its internal state. It
// like dataFrame and headerFrame do go out on wire as HTTP2 frames. // shouldn't be confused with an HTTP2 frame, although some of the control
// frames like dataFrame and headerFrame do go out on wire as HTTP2 frames.
type controlBuffer struct { type controlBuffer struct {
ch chan struct{} wakeupCh chan struct{} // Unblocks readers waiting for something to read.
done <-chan struct{} done <-chan struct{} // Closed when the transport is done.
// Mutex guards all the fields below, except trfChan which can be read
// atomically without holding mu.
mu sync.Mutex mu sync.Mutex
consumerWaiting bool consumerWaiting bool // True when readers are blocked waiting for new data.
list *itemList closed bool // True when the controlbuf is finished.
err error list *itemList // List of queued control frames.
// transportResponseFrames counts the number of queued items that represent // transportResponseFrames counts the number of queued items that represent
// the response of an action initiated by the peer. trfChan is created // the response of an action initiated by the peer. trfChan is created
@ -308,47 +313,59 @@ type controlBuffer struct {
// closed and nilled when transportResponseFrames drops below the // closed and nilled when transportResponseFrames drops below the
// threshold. Both fields are protected by mu. // threshold. Both fields are protected by mu.
transportResponseFrames int transportResponseFrames int
trfChan atomic.Value // chan struct{} trfChan atomic.Pointer[chan struct{}]
} }
func newControlBuffer(done <-chan struct{}) *controlBuffer { func newControlBuffer(done <-chan struct{}) *controlBuffer {
return &controlBuffer{ return &controlBuffer{
ch: make(chan struct{}, 1), wakeupCh: make(chan struct{}, 1),
list: &itemList{}, list: &itemList{},
done: done, done: done,
} }
} }
// throttle blocks if there are too many incomingSettings/cleanupStreams in the // throttle blocks if there are too many frames in the control buf that
// controlbuf. // represent the response of an action initiated by the peer, like
// incomingSettings cleanupStreams etc.
func (c *controlBuffer) throttle() { func (c *controlBuffer) throttle() {
ch, _ := c.trfChan.Load().(chan struct{}) if ch := c.trfChan.Load(); ch != nil {
if ch != nil {
select { select {
case <-ch: case <-(*ch):
case <-c.done: case <-c.done:
} }
} }
} }
// put adds an item to the controlbuf.
func (c *controlBuffer) put(it cbItem) error { func (c *controlBuffer) put(it cbItem) error {
_, err := c.executeAndPut(nil, it) _, err := c.executeAndPut(nil, it)
return err return err
} }
// executeAndPut runs f, and if the return value is true, adds the given item to
// the controlbuf. The item could be nil, in which case, this method simply
// executes f and does not add the item to the controlbuf.
//
// The first return value indicates whether the item was successfully added to
// the control buffer. A non-nil error, specifically ErrConnClosing, is returned
// if the control buffer is already closed.
func (c *controlBuffer) executeAndPut(f func() bool, it cbItem) (bool, error) { func (c *controlBuffer) executeAndPut(f func() bool, it cbItem) (bool, error) {
var wakeUp bool
c.mu.Lock() c.mu.Lock()
if c.err != nil { defer c.mu.Unlock()
c.mu.Unlock()
return false, c.err if c.closed {
return false, ErrConnClosing
} }
if f != nil { if f != nil {
if !f() { // f wasn't successful if !f() { // f wasn't successful
c.mu.Unlock()
return false, nil return false, nil
} }
} }
if it == nil {
return true, nil
}
var wakeUp bool
if c.consumerWaiting { if c.consumerWaiting {
wakeUp = true wakeUp = true
c.consumerWaiting = false c.consumerWaiting = false
@ -359,98 +376,102 @@ func (c *controlBuffer) executeAndPut(f func() bool, it cbItem) (bool, error) {
if c.transportResponseFrames == maxQueuedTransportResponseFrames { if c.transportResponseFrames == maxQueuedTransportResponseFrames {
// We are adding the frame that puts us over the threshold; create // We are adding the frame that puts us over the threshold; create
// a throttling channel. // a throttling channel.
c.trfChan.Store(make(chan struct{})) ch := make(chan struct{})
c.trfChan.Store(&ch)
} }
} }
c.mu.Unlock()
if wakeUp { if wakeUp {
select { select {
case c.ch <- struct{}{}: case c.wakeupCh <- struct{}{}:
default: default:
} }
} }
return true, nil return true, nil
} }
// Note argument f should never be nil. // get returns the next control frame from the control buffer. If block is true
func (c *controlBuffer) execute(f func(it any) bool, it any) (bool, error) { // **and** there are no control frames in the control buffer, the call blocks
c.mu.Lock() // until one of the conditions is met: there is a frame to return or the
if c.err != nil { // transport is closed.
c.mu.Unlock()
return false, c.err
}
if !f(it) { // f wasn't successful
c.mu.Unlock()
return false, nil
}
c.mu.Unlock()
return true, nil
}
func (c *controlBuffer) get(block bool) (any, error) { func (c *controlBuffer) get(block bool) (any, error) {
for { for {
c.mu.Lock() c.mu.Lock()
if c.err != nil { frame, err := c.getOnceLocked()
if frame != nil || err != nil || !block {
// If we read a frame or an error, we can return to the caller. The
// call to getOnceLocked() returns a nil frame and a nil error if
// there is nothing to read, and in that case, if the caller asked
// us not to block, we can return now as well.
c.mu.Unlock() c.mu.Unlock()
return nil, c.err return frame, err
}
if !c.list.isEmpty() {
h := c.list.dequeue().(cbItem)
if h.isTransportResponseFrame() {
if c.transportResponseFrames == maxQueuedTransportResponseFrames {
// We are removing the frame that put us over the
// threshold; close and clear the throttling channel.
ch := c.trfChan.Load().(chan struct{})
close(ch)
c.trfChan.Store((chan struct{})(nil))
}
c.transportResponseFrames--
}
c.mu.Unlock()
return h, nil
}
if !block {
c.mu.Unlock()
return nil, nil
} }
c.consumerWaiting = true c.consumerWaiting = true
c.mu.Unlock() c.mu.Unlock()
// Release the lock above and wait to be woken up.
select { select {
case <-c.ch: case <-c.wakeupCh:
case <-c.done: case <-c.done:
return nil, errors.New("transport closed by client") return nil, errors.New("transport closed by client")
} }
} }
} }
// Callers must not use this method, but should instead use get().
//
// Caller must hold c.mu.
func (c *controlBuffer) getOnceLocked() (any, error) {
if c.closed {
return false, ErrConnClosing
}
if c.list.isEmpty() {
return nil, nil
}
h := c.list.dequeue().(cbItem)
if h.isTransportResponseFrame() {
if c.transportResponseFrames == maxQueuedTransportResponseFrames {
// We are removing the frame that put us over the
// threshold; close and clear the throttling channel.
ch := c.trfChan.Swap(nil)
close(*ch)
}
c.transportResponseFrames--
}
return h, nil
}
// finish closes the control buffer, cleaning up any streams that have queued
// header frames. Once this method returns, no more frames can be added to the
// control buffer, and attempts to do so will return ErrConnClosing.
func (c *controlBuffer) finish() { func (c *controlBuffer) finish() {
c.mu.Lock() c.mu.Lock()
if c.err != nil { defer c.mu.Unlock()
c.mu.Unlock()
if c.closed {
return return
} }
c.err = ErrConnClosing c.closed = true
// There may be headers for streams in the control buffer. // There may be headers for streams in the control buffer.
// These streams need to be cleaned out since the transport // These streams need to be cleaned out since the transport
// is still not aware of these yet. // is still not aware of these yet.
for head := c.list.dequeueAll(); head != nil; head = head.next { for head := c.list.dequeueAll(); head != nil; head = head.next {
hdr, ok := head.it.(*headerFrame) switch v := head.it.(type) {
if !ok { case *headerFrame:
continue if v.onOrphaned != nil { // It will be nil on the server-side.
} v.onOrphaned(ErrConnClosing)
if hdr.onOrphaned != nil { // It will be nil on the server-side. }
hdr.onOrphaned(ErrConnClosing) case *dataFrame:
_ = v.reader.Close()
} }
} }
// In case throttle() is currently in flight, it needs to be unblocked. // In case throttle() is currently in flight, it needs to be unblocked.
// Otherwise, the transport may not close, since the transport is closed by // Otherwise, the transport may not close, since the transport is closed by
// the reader encountering the connection error. // the reader encountering the connection error.
ch, _ := c.trfChan.Load().(chan struct{}) ch := c.trfChan.Swap(nil)
if ch != nil { if ch != nil {
close(ch) close(*ch)
} }
c.trfChan.Store((chan struct{})(nil))
c.mu.Unlock()
} }
type side int type side int
@ -466,7 +487,7 @@ const (
// stream maintains a queue of data frames; as loopy receives data frames // stream maintains a queue of data frames; as loopy receives data frames
// it gets added to the queue of the relevant stream. // it gets added to the queue of the relevant stream.
// Loopy goes over this list of active streams by processing one node every iteration, // Loopy goes over this list of active streams by processing one node every iteration,
// thereby closely resemebling to a round-robin scheduling over all streams. While // thereby closely resembling a round-robin scheduling over all streams. While
// processing a stream, loopy writes out data bytes from this stream capped by the min // processing a stream, loopy writes out data bytes from this stream capped by the min
// of http2MaxFrameLen, connection-level flow control and stream-level flow control. // of http2MaxFrameLen, connection-level flow control and stream-level flow control.
type loopyWriter struct { type loopyWriter struct {
@ -490,12 +511,13 @@ type loopyWriter struct {
draining bool draining bool
conn net.Conn conn net.Conn
logger *grpclog.PrefixLogger logger *grpclog.PrefixLogger
bufferPool mem.BufferPool
// Side-specific handlers // Side-specific handlers
ssGoAwayHandler func(*goAway) (bool, error) ssGoAwayHandler func(*goAway) (bool, error)
} }
func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimator, conn net.Conn, logger *grpclog.PrefixLogger, goAwayHandler func(*goAway) (bool, error)) *loopyWriter { func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimator, conn net.Conn, logger *grpclog.PrefixLogger, goAwayHandler func(*goAway) (bool, error), bufferPool mem.BufferPool) *loopyWriter {
var buf bytes.Buffer var buf bytes.Buffer
l := &loopyWriter{ l := &loopyWriter{
side: s, side: s,
@ -511,6 +533,7 @@ func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimato
conn: conn, conn: conn,
logger: logger, logger: logger,
ssGoAwayHandler: goAwayHandler, ssGoAwayHandler: goAwayHandler,
bufferPool: bufferPool,
} }
return l return l
} }
@ -768,6 +791,11 @@ func (l *loopyWriter) cleanupStreamHandler(c *cleanupStream) error {
// not be established yet. // not be established yet.
delete(l.estdStreams, c.streamID) delete(l.estdStreams, c.streamID)
str.deleteSelf() str.deleteSelf()
for head := str.itl.dequeueAll(); head != nil; head = head.next {
if df, ok := head.it.(*dataFrame); ok {
_ = df.reader.Close()
}
}
} }
if c.rst { // If RST_STREAM needs to be sent. if c.rst { // If RST_STREAM needs to be sent.
if err := l.framer.fr.WriteRSTStream(c.streamID, c.rstCode); err != nil { if err := l.framer.fr.WriteRSTStream(c.streamID, c.rstCode); err != nil {
@ -903,16 +931,18 @@ func (l *loopyWriter) processData() (bool, error) {
dataItem := str.itl.peek().(*dataFrame) // Peek at the first data item this stream. dataItem := str.itl.peek().(*dataFrame) // Peek at the first data item this stream.
// A data item is represented by a dataFrame, since it later translates into // A data item is represented by a dataFrame, since it later translates into
// multiple HTTP2 data frames. // multiple HTTP2 data frames.
// Every dataFrame has two buffers; h that keeps grpc-message header and d that is actual data. // Every dataFrame has two buffers; h that keeps grpc-message header and data
// As an optimization to keep wire traffic low, data from d is copied to h to make as big as the // that is the actual message. As an optimization to keep wire traffic low, data
// maximum possible HTTP2 frame size. // from data is copied to h to make as big as the maximum possible HTTP2 frame
// size.
if len(dataItem.h) == 0 && len(dataItem.d) == 0 { // Empty data frame if len(dataItem.h) == 0 && dataItem.reader.Remaining() == 0 { // Empty data frame
// Client sends out empty data frame with endStream = true // Client sends out empty data frame with endStream = true
if err := l.framer.fr.WriteData(dataItem.streamID, dataItem.endStream, nil); err != nil { if err := l.framer.fr.WriteData(dataItem.streamID, dataItem.endStream, nil); err != nil {
return false, err return false, err
} }
str.itl.dequeue() // remove the empty data item from stream str.itl.dequeue() // remove the empty data item from stream
_ = dataItem.reader.Close()
if str.itl.isEmpty() { if str.itl.isEmpty() {
str.state = empty str.state = empty
} else if trailer, ok := str.itl.peek().(*headerFrame); ok { // the next item is trailers. } else if trailer, ok := str.itl.peek().(*headerFrame); ok { // the next item is trailers.
@ -927,9 +957,7 @@ func (l *loopyWriter) processData() (bool, error) {
} }
return false, nil return false, nil
} }
var (
buf []byte
)
// Figure out the maximum size we can send // Figure out the maximum size we can send
maxSize := http2MaxFrameLen maxSize := http2MaxFrameLen
if strQuota := int(l.oiws) - str.bytesOutStanding; strQuota <= 0 { // stream-level flow control. if strQuota := int(l.oiws) - str.bytesOutStanding; strQuota <= 0 { // stream-level flow control.
@ -943,43 +971,50 @@ func (l *loopyWriter) processData() (bool, error) {
} }
// Compute how much of the header and data we can send within quota and max frame length // Compute how much of the header and data we can send within quota and max frame length
hSize := min(maxSize, len(dataItem.h)) hSize := min(maxSize, len(dataItem.h))
dSize := min(maxSize-hSize, len(dataItem.d)) dSize := min(maxSize-hSize, dataItem.reader.Remaining())
if hSize != 0 { remainingBytes := len(dataItem.h) + dataItem.reader.Remaining() - hSize - dSize
if dSize == 0 {
buf = dataItem.h
} else {
// We can add some data to grpc message header to distribute bytes more equally across frames.
// Copy on the stack to avoid generating garbage
var localBuf [http2MaxFrameLen]byte
copy(localBuf[:hSize], dataItem.h)
copy(localBuf[hSize:], dataItem.d[:dSize])
buf = localBuf[:hSize+dSize]
}
} else {
buf = dataItem.d
}
size := hSize + dSize size := hSize + dSize
var buf *[]byte
if hSize != 0 && dSize == 0 {
buf = &dataItem.h
} else {
// Note: this is only necessary because the http2.Framer does not support
// partially writing a frame, so the sequence must be materialized into a buffer.
// TODO: Revisit once https://github.com/golang/go/issues/66655 is addressed.
pool := l.bufferPool
if pool == nil {
// Note that this is only supposed to be nil in tests. Otherwise, stream is
// always initialized with a BufferPool.
pool = mem.DefaultBufferPool()
}
buf = pool.Get(size)
defer pool.Put(buf)
copy((*buf)[:hSize], dataItem.h)
_, _ = dataItem.reader.Read((*buf)[hSize:])
}
// Now that outgoing flow controls are checked we can replenish str's write quota // Now that outgoing flow controls are checked we can replenish str's write quota
str.wq.replenish(size) str.wq.replenish(size)
var endStream bool var endStream bool
// If this is the last data message on this stream and all of it can be written in this iteration. // If this is the last data message on this stream and all of it can be written in this iteration.
if dataItem.endStream && len(dataItem.h)+len(dataItem.d) <= size { if dataItem.endStream && remainingBytes == 0 {
endStream = true endStream = true
} }
if dataItem.onEachWrite != nil { if dataItem.onEachWrite != nil {
dataItem.onEachWrite() dataItem.onEachWrite()
} }
if err := l.framer.fr.WriteData(dataItem.streamID, endStream, buf[:size]); err != nil { if err := l.framer.fr.WriteData(dataItem.streamID, endStream, (*buf)[:size]); err != nil {
return false, err return false, err
} }
str.bytesOutStanding += size str.bytesOutStanding += size
l.sendQuota -= uint32(size) l.sendQuota -= uint32(size)
dataItem.h = dataItem.h[hSize:] dataItem.h = dataItem.h[hSize:]
dataItem.d = dataItem.d[dSize:]
if len(dataItem.h) == 0 && len(dataItem.d) == 0 { // All the data from that message was written out. if remainingBytes == 0 { // All the data from that message was written out.
_ = dataItem.reader.Close()
str.itl.dequeue() str.itl.dequeue()
} }
if str.itl.isEmpty() { if str.itl.isEmpty() {

View File

@ -24,7 +24,6 @@
package transport package transport
import ( import (
"bytes"
"context" "context"
"errors" "errors"
"fmt" "fmt"
@ -40,6 +39,7 @@ import (
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
@ -50,7 +50,7 @@ import (
// NewServerHandlerTransport returns a ServerTransport handling gRPC from // NewServerHandlerTransport returns a ServerTransport handling gRPC from
// inside an http.Handler, or writes an HTTP error to w and returns an error. // inside an http.Handler, or writes an HTTP error to w and returns an error.
// It requires that the http Server supports HTTP/2. // It requires that the http Server supports HTTP/2.
func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []stats.Handler) (ServerTransport, error) { func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []stats.Handler, bufferPool mem.BufferPool) (ServerTransport, error) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
w.Header().Set("Allow", http.MethodPost) w.Header().Set("Allow", http.MethodPost)
msg := fmt.Sprintf("invalid gRPC request method %q", r.Method) msg := fmt.Sprintf("invalid gRPC request method %q", r.Method)
@ -98,6 +98,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []s
contentType: contentType, contentType: contentType,
contentSubtype: contentSubtype, contentSubtype: contentSubtype,
stats: stats, stats: stats,
bufferPool: bufferPool,
} }
st.logger = prefixLoggerForServerHandlerTransport(st) st.logger = prefixLoggerForServerHandlerTransport(st)
@ -171,6 +172,8 @@ type serverHandlerTransport struct {
stats []stats.Handler stats []stats.Handler
logger *grpclog.PrefixLogger logger *grpclog.PrefixLogger
bufferPool mem.BufferPool
} }
func (ht *serverHandlerTransport) Close(err error) { func (ht *serverHandlerTransport) Close(err error) {
@ -244,6 +247,7 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro
} }
s.hdrMu.Lock() s.hdrMu.Lock()
defer s.hdrMu.Unlock()
if p := st.Proto(); p != nil && len(p.Details) > 0 { if p := st.Proto(); p != nil && len(p.Details) > 0 {
delete(s.trailer, grpcStatusDetailsBinHeader) delete(s.trailer, grpcStatusDetailsBinHeader)
stBytes, err := proto.Marshal(p) stBytes, err := proto.Marshal(p)
@ -268,7 +272,6 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro
} }
} }
} }
s.hdrMu.Unlock()
}) })
if err == nil { // transport has not been closed if err == nil { // transport has not been closed
@ -330,16 +333,28 @@ func (ht *serverHandlerTransport) writeCustomHeaders(s *Stream) {
s.hdrMu.Unlock() s.hdrMu.Unlock()
} }
func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data []byte, opts *Options) error { func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data mem.BufferSlice, opts *Options) error {
// Always take a reference because otherwise there is no guarantee the data will
// be available after this function returns. This is what callers to Write
// expect.
data.Ref()
headersWritten := s.updateHeaderSent() headersWritten := s.updateHeaderSent()
return ht.do(func() { err := ht.do(func() {
defer data.Free()
if !headersWritten { if !headersWritten {
ht.writePendingHeaders(s) ht.writePendingHeaders(s)
} }
ht.rw.Write(hdr) ht.rw.Write(hdr)
ht.rw.Write(data) for _, b := range data {
_, _ = ht.rw.Write(b.ReadOnlyData())
}
ht.rw.(http.Flusher).Flush() ht.rw.(http.Flusher).Flush()
}) })
if err != nil {
data.Free()
return err
}
return nil
} }
func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error { func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
@ -406,7 +421,7 @@ func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream
headerWireLength: 0, // won't have access to header wire length until golang/go#18997. headerWireLength: 0, // won't have access to header wire length until golang/go#18997.
} }
s.trReader = &transportReader{ s.trReader = &transportReader{
reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf, freeBuffer: func(*bytes.Buffer) {}}, reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf},
windowHandler: func(int) {}, windowHandler: func(int) {},
} }
@ -415,21 +430,19 @@ func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream
go func() { go func() {
defer close(readerDone) defer close(readerDone)
// TODO: minimize garbage, optimize recvBuffer code/ownership for {
const readSize = 8196 buf := ht.bufferPool.Get(http2MaxFrameLen)
for buf := make([]byte, readSize); ; { n, err := req.Body.Read(*buf)
n, err := req.Body.Read(buf)
if n > 0 { if n > 0 {
s.buf.put(recvMsg{buffer: bytes.NewBuffer(buf[:n:n])}) *buf = (*buf)[:n]
buf = buf[n:] s.buf.put(recvMsg{buffer: mem.NewBuffer(buf, ht.bufferPool)})
} else {
ht.bufferPool.Put(buf)
} }
if err != nil { if err != nil {
s.buf.put(recvMsg{err: mapRecvMsgError(err)}) s.buf.put(recvMsg{err: mapRecvMsgError(err)})
return return
} }
if len(buf) == 0 {
buf = make([]byte, readSize)
}
} }
}() }()

View File

@ -47,6 +47,7 @@ import (
isyscall "google.golang.org/grpc/internal/syscall" isyscall "google.golang.org/grpc/internal/syscall"
"google.golang.org/grpc/internal/transport/networktype" "google.golang.org/grpc/internal/transport/networktype"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
@ -59,6 +60,8 @@ import (
// atomically. // atomically.
var clientConnectionCounter uint64 var clientConnectionCounter uint64
var goAwayLoopyWriterTimeout = 5 * time.Second
var metadataFromOutgoingContextRaw = internal.FromOutgoingContextRaw.(func(context.Context) (metadata.MD, [][]string, bool)) var metadataFromOutgoingContextRaw = internal.FromOutgoingContextRaw.(func(context.Context) (metadata.MD, [][]string, bool))
// http2Client implements the ClientTransport interface with HTTP2. // http2Client implements the ClientTransport interface with HTTP2.
@ -144,7 +147,7 @@ type http2Client struct {
onClose func(GoAwayReason) onClose func(GoAwayReason)
bufferPool *bufferPool bufferPool mem.BufferPool
connectionID uint64 connectionID uint64
logger *grpclog.PrefixLogger logger *grpclog.PrefixLogger
@ -229,7 +232,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
} }
}(conn) }(conn)
// The following defer and goroutine monitor the connectCtx for cancelation // The following defer and goroutine monitor the connectCtx for cancellation
// and deadline. On context expiration, the connection is hard closed and // and deadline. On context expiration, the connection is hard closed and
// this function will naturally fail as a result. Otherwise, the defer // this function will naturally fail as a result. Otherwise, the defer
// waits for the goroutine to exit to prevent the context from being // waits for the goroutine to exit to prevent the context from being
@ -346,7 +349,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
streamQuota: defaultMaxStreamsClient, streamQuota: defaultMaxStreamsClient,
streamsQuotaAvailable: make(chan struct{}, 1), streamsQuotaAvailable: make(chan struct{}, 1),
keepaliveEnabled: keepaliveEnabled, keepaliveEnabled: keepaliveEnabled,
bufferPool: newBufferPool(), bufferPool: opts.BufferPool,
onClose: onClose, onClose: onClose,
} }
var czSecurity credentials.ChannelzSecurityValue var czSecurity credentials.ChannelzSecurityValue
@ -463,7 +466,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
return nil, err return nil, err
} }
go func() { go func() {
t.loopy = newLoopyWriter(clientSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger, t.outgoingGoAwayHandler) t.loopy = newLoopyWriter(clientSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger, t.outgoingGoAwayHandler, t.bufferPool)
if err := t.loopy.run(); !isIOError(err) { if err := t.loopy.run(); !isIOError(err) {
// Immediately close the connection, as the loopy writer returns // Immediately close the connection, as the loopy writer returns
// when there are no more active streams and we were draining (the // when there are no more active streams and we were draining (the
@ -504,7 +507,6 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
closeStream: func(err error) { closeStream: func(err error) {
t.CloseStream(s, err) t.CloseStream(s, err)
}, },
freeBuffer: t.bufferPool.put,
}, },
windowHandler: func(n int) { windowHandler: func(n int) {
t.updateWindow(s, uint32(n)) t.updateWindow(s, uint32(n))
@ -983,6 +985,7 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.
// only once on a transport. Once it is called, the transport should not be // only once on a transport. Once it is called, the transport should not be
// accessed anymore. // accessed anymore.
func (t *http2Client) Close(err error) { func (t *http2Client) Close(err error) {
t.conn.SetWriteDeadline(time.Now().Add(time.Second * 10))
t.mu.Lock() t.mu.Lock()
// Make sure we only close once. // Make sure we only close once.
if t.state == closing { if t.state == closing {
@ -1006,10 +1009,20 @@ func (t *http2Client) Close(err error) {
t.kpDormancyCond.Signal() t.kpDormancyCond.Signal()
} }
t.mu.Unlock() t.mu.Unlock()
// Per HTTP/2 spec, a GOAWAY frame must be sent before closing the // Per HTTP/2 spec, a GOAWAY frame must be sent before closing the
// connection. See https://httpwg.org/specs/rfc7540.html#GOAWAY. // connection. See https://httpwg.org/specs/rfc7540.html#GOAWAY. It
// also waits for loopyWriter to be closed with a timer to avoid the
// long blocking in case the connection is blackholed, i.e. TCP is
// just stuck.
t.controlBuf.put(&goAway{code: http2.ErrCodeNo, debugData: []byte("client transport shutdown"), closeConn: err}) t.controlBuf.put(&goAway{code: http2.ErrCodeNo, debugData: []byte("client transport shutdown"), closeConn: err})
<-t.writerDone timer := time.NewTimer(goAwayLoopyWriterTimeout)
defer timer.Stop()
select {
case <-t.writerDone: // success
case <-timer.C:
t.logger.Infof("Failed to write a GOAWAY frame as part of connection close after %s. Giving up and closing the transport.", goAwayLoopyWriterTimeout)
}
t.cancel() t.cancel()
t.conn.Close() t.conn.Close()
channelz.RemoveEntry(t.channelz.ID) channelz.RemoveEntry(t.channelz.ID)
@ -1065,27 +1078,36 @@ func (t *http2Client) GracefulClose() {
// Write formats the data into HTTP2 data frame(s) and sends it out. The caller // Write formats the data into HTTP2 data frame(s) and sends it out. The caller
// should proceed only if Write returns nil. // should proceed only if Write returns nil.
func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) error { func (t *http2Client) Write(s *Stream, hdr []byte, data mem.BufferSlice, opts *Options) error {
reader := data.Reader()
if opts.Last { if opts.Last {
// If it's the last message, update stream state. // If it's the last message, update stream state.
if !s.compareAndSwapState(streamActive, streamWriteDone) { if !s.compareAndSwapState(streamActive, streamWriteDone) {
_ = reader.Close()
return errStreamDone return errStreamDone
} }
} else if s.getState() != streamActive { } else if s.getState() != streamActive {
_ = reader.Close()
return errStreamDone return errStreamDone
} }
df := &dataFrame{ df := &dataFrame{
streamID: s.id, streamID: s.id,
endStream: opts.Last, endStream: opts.Last,
h: hdr, h: hdr,
d: data, reader: reader,
} }
if hdr != nil || data != nil { // If it's not an empty data frame, check quota. if hdr != nil || df.reader.Remaining() != 0 { // If it's not an empty data frame, check quota.
if err := s.wq.get(int32(len(hdr) + len(data))); err != nil { if err := s.wq.get(int32(len(hdr) + df.reader.Remaining())); err != nil {
_ = reader.Close()
return err return err
} }
} }
return t.controlBuf.put(df) if err := t.controlBuf.put(df); err != nil {
_ = reader.Close()
return err
}
return nil
} }
func (t *http2Client) getStream(f http2.Frame) *Stream { func (t *http2Client) getStream(f http2.Frame) *Stream {
@ -1190,10 +1212,13 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
// guarantee f.Data() is consumed before the arrival of next frame. // guarantee f.Data() is consumed before the arrival of next frame.
// Can this copy be eliminated? // Can this copy be eliminated?
if len(f.Data()) > 0 { if len(f.Data()) > 0 {
buffer := t.bufferPool.get() pool := t.bufferPool
buffer.Reset() if pool == nil {
buffer.Write(f.Data()) // Note that this is only supposed to be nil in tests. Otherwise, stream is
s.write(recvMsg{buffer: buffer}) // always initialized with a BufferPool.
pool = mem.DefaultBufferPool()
}
s.write(recvMsg{buffer: mem.Copy(f.Data(), pool)})
} }
} }
// The server has closed the stream without sending trailers. Record that // The server has closed the stream without sending trailers. Record that
@ -1222,7 +1247,7 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) {
if statusCode == codes.Canceled { if statusCode == codes.Canceled {
if d, ok := s.ctx.Deadline(); ok && !d.After(time.Now()) { if d, ok := s.ctx.Deadline(); ok && !d.After(time.Now()) {
// Our deadline was already exceeded, and that was likely the cause // Our deadline was already exceeded, and that was likely the cause
// of this cancelation. Alter the status code accordingly. // of this cancellation. Alter the status code accordingly.
statusCode = codes.DeadlineExceeded statusCode = codes.DeadlineExceeded
} }
} }
@ -1307,7 +1332,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
id := f.LastStreamID id := f.LastStreamID
if id > 0 && id%2 == 0 { if id > 0 && id%2 == 0 {
t.mu.Unlock() t.mu.Unlock()
t.Close(connectionErrorf(true, nil, "received goaway with non-zero even-numbered numbered stream id: %v", id)) t.Close(connectionErrorf(true, nil, "received goaway with non-zero even-numbered stream id: %v", id))
return return
} }
// A client can receive multiple GoAways from the server (see // A client can receive multiple GoAways from the server (see

View File

@ -39,6 +39,7 @@ import (
"google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/internal/pretty" "google.golang.org/grpc/internal/pretty"
"google.golang.org/grpc/internal/syscall" "google.golang.org/grpc/internal/syscall"
"google.golang.org/grpc/mem"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
@ -119,7 +120,7 @@ type http2Server struct {
// Fields below are for channelz metric collection. // Fields below are for channelz metric collection.
channelz *channelz.Socket channelz *channelz.Socket
bufferPool *bufferPool bufferPool mem.BufferPool
connectionID uint64 connectionID uint64
@ -261,7 +262,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
idle: time.Now(), idle: time.Now(),
kep: kep, kep: kep,
initialWindowSize: iwz, initialWindowSize: iwz,
bufferPool: newBufferPool(), bufferPool: config.BufferPool,
} }
var czSecurity credentials.ChannelzSecurityValue var czSecurity credentials.ChannelzSecurityValue
if au, ok := authInfo.(credentials.ChannelzSecurityInfo); ok { if au, ok := authInfo.(credentials.ChannelzSecurityInfo); ok {
@ -330,7 +331,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
t.handleSettings(sf) t.handleSettings(sf)
go func() { go func() {
t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger, t.outgoingGoAwayHandler) t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger, t.outgoingGoAwayHandler, t.bufferPool)
err := t.loopy.run() err := t.loopy.run()
close(t.loopyWriterDone) close(t.loopyWriterDone)
if !isIOError(err) { if !isIOError(err) {
@ -613,10 +614,9 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone) s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone)
s.trReader = &transportReader{ s.trReader = &transportReader{
reader: &recvBufferReader{ reader: &recvBufferReader{
ctx: s.ctx, ctx: s.ctx,
ctxDone: s.ctxDone, ctxDone: s.ctxDone,
recv: s.buf, recv: s.buf,
freeBuffer: t.bufferPool.put,
}, },
windowHandler: func(n int) { windowHandler: func(n int) {
t.updateWindow(s, uint32(n)) t.updateWindow(s, uint32(n))
@ -813,10 +813,13 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
// guarantee f.Data() is consumed before the arrival of next frame. // guarantee f.Data() is consumed before the arrival of next frame.
// Can this copy be eliminated? // Can this copy be eliminated?
if len(f.Data()) > 0 { if len(f.Data()) > 0 {
buffer := t.bufferPool.get() pool := t.bufferPool
buffer.Reset() if pool == nil {
buffer.Write(f.Data()) // Note that this is only supposed to be nil in tests. Otherwise, stream is
s.write(recvMsg{buffer: buffer}) // always initialized with a BufferPool.
pool = mem.DefaultBufferPool()
}
s.write(recvMsg{buffer: mem.Copy(f.Data(), pool)})
} }
} }
if f.StreamEnded() { if f.StreamEnded() {
@ -1089,7 +1092,9 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
onWrite: t.setResetPingStrikes, onWrite: t.setResetPingStrikes,
} }
success, err := t.controlBuf.execute(t.checkForHeaderListSize, trailingHeader) success, err := t.controlBuf.executeAndPut(func() bool {
return t.checkForHeaderListSize(trailingHeader)
}, nil)
if !success { if !success {
if err != nil { if err != nil {
return err return err
@ -1112,27 +1117,37 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
// Write converts the data into HTTP2 data frame and sends it out. Non-nil error // Write converts the data into HTTP2 data frame and sends it out. Non-nil error
// is returns if it fails (e.g., framing error, transport error). // is returns if it fails (e.g., framing error, transport error).
func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) error { func (t *http2Server) Write(s *Stream, hdr []byte, data mem.BufferSlice, opts *Options) error {
reader := data.Reader()
if !s.isHeaderSent() { // Headers haven't been written yet. if !s.isHeaderSent() { // Headers haven't been written yet.
if err := t.WriteHeader(s, nil); err != nil { if err := t.WriteHeader(s, nil); err != nil {
_ = reader.Close()
return err return err
} }
} else { } else {
// Writing headers checks for this condition. // Writing headers checks for this condition.
if s.getState() == streamDone { if s.getState() == streamDone {
_ = reader.Close()
return t.streamContextErr(s) return t.streamContextErr(s)
} }
} }
df := &dataFrame{ df := &dataFrame{
streamID: s.id, streamID: s.id,
h: hdr, h: hdr,
d: data, reader: reader,
onEachWrite: t.setResetPingStrikes, onEachWrite: t.setResetPingStrikes,
} }
if err := s.wq.get(int32(len(hdr) + len(data))); err != nil { if err := s.wq.get(int32(len(hdr) + df.reader.Remaining())); err != nil {
_ = reader.Close()
return t.streamContextErr(s) return t.streamContextErr(s)
} }
return t.controlBuf.put(df) if err := t.controlBuf.put(df); err != nil {
_ = reader.Close()
return err
}
return nil
} }
// keepalive running in a separate goroutine does the following: // keepalive running in a separate goroutine does the following:

View File

@ -317,28 +317,32 @@ func newBufWriter(conn net.Conn, batchSize int, pool *sync.Pool) *bufWriter {
return w return w
} }
func (w *bufWriter) Write(b []byte) (n int, err error) { func (w *bufWriter) Write(b []byte) (int, error) {
if w.err != nil { if w.err != nil {
return 0, w.err return 0, w.err
} }
if w.batchSize == 0 { // Buffer has been disabled. if w.batchSize == 0 { // Buffer has been disabled.
n, err = w.conn.Write(b) n, err := w.conn.Write(b)
return n, toIOError(err) return n, toIOError(err)
} }
if w.buf == nil { if w.buf == nil {
b := w.pool.Get().(*[]byte) b := w.pool.Get().(*[]byte)
w.buf = *b w.buf = *b
} }
written := 0
for len(b) > 0 { for len(b) > 0 {
nn := copy(w.buf[w.offset:], b) copied := copy(w.buf[w.offset:], b)
b = b[nn:] b = b[copied:]
w.offset += nn written += copied
n += nn w.offset += copied
if w.offset >= w.batchSize { if w.offset < w.batchSize {
err = w.flushKeepBuffer() continue
}
if err := w.flushKeepBuffer(); err != nil {
return written, err
} }
} }
return n, err return written, nil
} }
func (w *bufWriter) Flush() error { func (w *bufWriter) Flush() error {

View File

@ -107,8 +107,14 @@ func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr stri
} }
return nil, fmt.Errorf("failed to do connect handshake, response: %q", dump) return nil, fmt.Errorf("failed to do connect handshake, response: %q", dump)
} }
// The buffer could contain extra bytes from the target server, so we can't
return &bufConn{Conn: conn, r: r}, nil // discard it. However, in many cases where the server waits for the client
// to send the first message (e.g. when TLS is being used), the buffer will
// be empty, so we can avoid the overhead of reading through this buffer.
if r.Buffered() != 0 {
return &bufConn{Conn: conn, r: r}, nil
}
return conn, nil
} }
// proxyDial dials, connecting to a proxy first if necessary. Checks if a proxy // proxyDial dials, connecting to a proxy first if necessary. Checks if a proxy

View File

@ -22,7 +22,6 @@
package transport package transport
import ( import (
"bytes"
"context" "context"
"errors" "errors"
"fmt" "fmt"
@ -37,6 +36,7 @@ import (
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
@ -47,32 +47,10 @@ import (
const logLevel = 2 const logLevel = 2
type bufferPool struct {
pool sync.Pool
}
func newBufferPool() *bufferPool {
return &bufferPool{
pool: sync.Pool{
New: func() any {
return new(bytes.Buffer)
},
},
}
}
func (p *bufferPool) get() *bytes.Buffer {
return p.pool.Get().(*bytes.Buffer)
}
func (p *bufferPool) put(b *bytes.Buffer) {
p.pool.Put(b)
}
// recvMsg represents the received msg from the transport. All transport // recvMsg represents the received msg from the transport. All transport
// protocol specific info has been removed. // protocol specific info has been removed.
type recvMsg struct { type recvMsg struct {
buffer *bytes.Buffer buffer mem.Buffer
// nil: received some data // nil: received some data
// io.EOF: stream is completed. data is nil. // io.EOF: stream is completed. data is nil.
// other non-nil error: transport failure. data is nil. // other non-nil error: transport failure. data is nil.
@ -102,6 +80,9 @@ func newRecvBuffer() *recvBuffer {
func (b *recvBuffer) put(r recvMsg) { func (b *recvBuffer) put(r recvMsg) {
b.mu.Lock() b.mu.Lock()
if b.err != nil { if b.err != nil {
// drop the buffer on the floor. Since b.err is not nil, any subsequent reads
// will always return an error, making this buffer inaccessible.
r.buffer.Free()
b.mu.Unlock() b.mu.Unlock()
// An error had occurred earlier, don't accept more // An error had occurred earlier, don't accept more
// data or errors. // data or errors.
@ -148,45 +129,70 @@ type recvBufferReader struct {
ctx context.Context ctx context.Context
ctxDone <-chan struct{} // cache of ctx.Done() (for performance). ctxDone <-chan struct{} // cache of ctx.Done() (for performance).
recv *recvBuffer recv *recvBuffer
last *bytes.Buffer // Stores the remaining data in the previous calls. last mem.Buffer // Stores the remaining data in the previous calls.
err error err error
freeBuffer func(*bytes.Buffer)
} }
// Read reads the next len(p) bytes from last. If last is drained, it tries to func (r *recvBufferReader) ReadHeader(header []byte) (n int, err error) {
// read additional data from recv. It blocks if there no additional data available
// in recv. If Read returns any non-nil error, it will continue to return that error.
func (r *recvBufferReader) Read(p []byte) (n int, err error) {
if r.err != nil { if r.err != nil {
return 0, r.err return 0, r.err
} }
if r.last != nil { if r.last != nil {
// Read remaining data left in last call. n, r.last = mem.ReadUnsafe(header, r.last)
copied, _ := r.last.Read(p) return n, nil
if r.last.Len() == 0 {
r.freeBuffer(r.last)
r.last = nil
}
return copied, nil
} }
if r.closeStream != nil { if r.closeStream != nil {
n, r.err = r.readClient(p) n, r.err = r.readHeaderClient(header)
} else { } else {
n, r.err = r.read(p) n, r.err = r.readHeader(header)
} }
return n, r.err return n, r.err
} }
func (r *recvBufferReader) read(p []byte) (n int, err error) { // Read reads the next n bytes from last. If last is drained, it tries to read
// additional data from recv. It blocks if there no additional data available in
// recv. If Read returns any non-nil error, it will continue to return that
// error.
func (r *recvBufferReader) Read(n int) (buf mem.Buffer, err error) {
if r.err != nil {
return nil, r.err
}
if r.last != nil {
buf = r.last
if r.last.Len() > n {
buf, r.last = mem.SplitUnsafe(buf, n)
} else {
r.last = nil
}
return buf, nil
}
if r.closeStream != nil {
buf, r.err = r.readClient(n)
} else {
buf, r.err = r.read(n)
}
return buf, r.err
}
func (r *recvBufferReader) readHeader(header []byte) (n int, err error) {
select { select {
case <-r.ctxDone: case <-r.ctxDone:
return 0, ContextErr(r.ctx.Err()) return 0, ContextErr(r.ctx.Err())
case m := <-r.recv.get(): case m := <-r.recv.get():
return r.readAdditional(m, p) return r.readHeaderAdditional(m, header)
} }
} }
func (r *recvBufferReader) readClient(p []byte) (n int, err error) { func (r *recvBufferReader) read(n int) (buf mem.Buffer, err error) {
select {
case <-r.ctxDone:
return nil, ContextErr(r.ctx.Err())
case m := <-r.recv.get():
return r.readAdditional(m, n)
}
}
func (r *recvBufferReader) readHeaderClient(header []byte) (n int, err error) {
// If the context is canceled, then closes the stream with nil metadata. // If the context is canceled, then closes the stream with nil metadata.
// closeStream writes its error parameter to r.recv as a recvMsg. // closeStream writes its error parameter to r.recv as a recvMsg.
// r.readAdditional acts on that message and returns the necessary error. // r.readAdditional acts on that message and returns the necessary error.
@ -207,25 +213,67 @@ func (r *recvBufferReader) readClient(p []byte) (n int, err error) {
// faster. // faster.
r.closeStream(ContextErr(r.ctx.Err())) r.closeStream(ContextErr(r.ctx.Err()))
m := <-r.recv.get() m := <-r.recv.get()
return r.readAdditional(m, p) return r.readHeaderAdditional(m, header)
case m := <-r.recv.get(): case m := <-r.recv.get():
return r.readAdditional(m, p) return r.readHeaderAdditional(m, header)
} }
} }
func (r *recvBufferReader) readAdditional(m recvMsg, p []byte) (n int, err error) { func (r *recvBufferReader) readClient(n int) (buf mem.Buffer, err error) {
// If the context is canceled, then closes the stream with nil metadata.
// closeStream writes its error parameter to r.recv as a recvMsg.
// r.readAdditional acts on that message and returns the necessary error.
select {
case <-r.ctxDone:
// Note that this adds the ctx error to the end of recv buffer, and
// reads from the head. This will delay the error until recv buffer is
// empty, thus will delay ctx cancellation in Recv().
//
// It's done this way to fix a race between ctx cancel and trailer. The
// race was, stream.Recv() may return ctx error if ctxDone wins the
// race, but stream.Trailer() may return a non-nil md because the stream
// was not marked as done when trailer is received. This closeStream
// call will mark stream as done, thus fix the race.
//
// TODO: delaying ctx error seems like a unnecessary side effect. What
// we really want is to mark the stream as done, and return ctx error
// faster.
r.closeStream(ContextErr(r.ctx.Err()))
m := <-r.recv.get()
return r.readAdditional(m, n)
case m := <-r.recv.get():
return r.readAdditional(m, n)
}
}
func (r *recvBufferReader) readHeaderAdditional(m recvMsg, header []byte) (n int, err error) {
r.recv.load() r.recv.load()
if m.err != nil { if m.err != nil {
if m.buffer != nil {
m.buffer.Free()
}
return 0, m.err return 0, m.err
} }
copied, _ := m.buffer.Read(p)
if m.buffer.Len() == 0 { n, r.last = mem.ReadUnsafe(header, m.buffer)
r.freeBuffer(m.buffer)
r.last = nil return n, nil
} else { }
r.last = m.buffer
func (r *recvBufferReader) readAdditional(m recvMsg, n int) (b mem.Buffer, err error) {
r.recv.load()
if m.err != nil {
if m.buffer != nil {
m.buffer.Free()
}
return nil, m.err
} }
return copied, nil
if m.buffer.Len() > n {
m.buffer, r.last = mem.SplitUnsafe(m.buffer, n)
}
return m.buffer, nil
} }
type streamState uint32 type streamState uint32
@ -241,7 +289,7 @@ const (
type Stream struct { type Stream struct {
id uint32 id uint32
st ServerTransport // nil for client side Stream st ServerTransport // nil for client side Stream
ct *http2Client // nil for server side Stream ct ClientTransport // nil for server side Stream
ctx context.Context // the associated context of the stream ctx context.Context // the associated context of the stream
cancel context.CancelFunc // always nil for client side Stream cancel context.CancelFunc // always nil for client side Stream
done chan struct{} // closed at the end of stream to unblock writers. On the client side. done chan struct{} // closed at the end of stream to unblock writers. On the client side.
@ -251,7 +299,7 @@ type Stream struct {
recvCompress string recvCompress string
sendCompress string sendCompress string
buf *recvBuffer buf *recvBuffer
trReader io.Reader trReader *transportReader
fc *inFlow fc *inFlow
wq *writeQuota wq *writeQuota
@ -408,7 +456,7 @@ func (s *Stream) TrailersOnly() bool {
return s.noHeaders return s.noHeaders
} }
// Trailer returns the cached trailer metedata. Note that if it is not called // Trailer returns the cached trailer metadata. Note that if it is not called
// after the entire stream is done, it could return an empty MD. Client // after the entire stream is done, it could return an empty MD. Client
// side only. // side only.
// It can be safely read only after stream has ended that is either read // It can be safely read only after stream has ended that is either read
@ -499,36 +547,87 @@ func (s *Stream) write(m recvMsg) {
s.buf.put(m) s.buf.put(m)
} }
// Read reads all p bytes from the wire for this stream. func (s *Stream) ReadHeader(header []byte) (err error) {
func (s *Stream) Read(p []byte) (n int, err error) {
// Don't request a read if there was an error earlier // Don't request a read if there was an error earlier
if er := s.trReader.(*transportReader).er; er != nil { if er := s.trReader.er; er != nil {
return 0, er return er
} }
s.requestRead(len(p)) s.requestRead(len(header))
return io.ReadFull(s.trReader, p) for len(header) != 0 {
n, err := s.trReader.ReadHeader(header)
header = header[n:]
if len(header) == 0 {
err = nil
}
if err != nil {
if n > 0 && err == io.EOF {
err = io.ErrUnexpectedEOF
}
return err
}
}
return nil
} }
// tranportReader reads all the data available for this Stream from the transport and // Read reads n bytes from the wire for this stream.
func (s *Stream) Read(n int) (data mem.BufferSlice, err error) {
// Don't request a read if there was an error earlier
if er := s.trReader.er; er != nil {
return nil, er
}
s.requestRead(n)
for n != 0 {
buf, err := s.trReader.Read(n)
var bufLen int
if buf != nil {
bufLen = buf.Len()
}
n -= bufLen
if n == 0 {
err = nil
}
if err != nil {
if bufLen > 0 && err == io.EOF {
err = io.ErrUnexpectedEOF
}
data.Free()
return nil, err
}
data = append(data, buf)
}
return data, nil
}
// transportReader reads all the data available for this Stream from the transport and
// passes them into the decoder, which converts them into a gRPC message stream. // passes them into the decoder, which converts them into a gRPC message stream.
// The error is io.EOF when the stream is done or another non-nil error if // The error is io.EOF when the stream is done or another non-nil error if
// the stream broke. // the stream broke.
type transportReader struct { type transportReader struct {
reader io.Reader reader *recvBufferReader
// The handler to control the window update procedure for both this // The handler to control the window update procedure for both this
// particular stream and the associated transport. // particular stream and the associated transport.
windowHandler func(int) windowHandler func(int)
er error er error
} }
func (t *transportReader) Read(p []byte) (n int, err error) { func (t *transportReader) ReadHeader(header []byte) (int, error) {
n, err = t.reader.Read(p) n, err := t.reader.ReadHeader(header)
if err != nil { if err != nil {
t.er = err t.er = err
return return 0, err
} }
t.windowHandler(n) t.windowHandler(len(header))
return return n, nil
}
func (t *transportReader) Read(n int) (mem.Buffer, error) {
buf, err := t.reader.Read(n)
if err != nil {
t.er = err
return buf, err
}
t.windowHandler(buf.Len())
return buf, nil
} }
// BytesReceived indicates whether any bytes have been received on this stream. // BytesReceived indicates whether any bytes have been received on this stream.
@ -574,6 +673,7 @@ type ServerConfig struct {
ChannelzParent *channelz.Server ChannelzParent *channelz.Server
MaxHeaderListSize *uint32 MaxHeaderListSize *uint32
HeaderTableSize *uint32 HeaderTableSize *uint32
BufferPool mem.BufferPool
} }
// ConnectOptions covers all relevant options for communicating with the server. // ConnectOptions covers all relevant options for communicating with the server.
@ -612,6 +712,8 @@ type ConnectOptions struct {
MaxHeaderListSize *uint32 MaxHeaderListSize *uint32
// UseProxy specifies if a proxy should be used. // UseProxy specifies if a proxy should be used.
UseProxy bool UseProxy bool
// The mem.BufferPool to use when reading/writing to the wire.
BufferPool mem.BufferPool
} }
// NewClientTransport establishes the transport with the required ConnectOptions // NewClientTransport establishes the transport with the required ConnectOptions
@ -673,7 +775,7 @@ type ClientTransport interface {
// Write sends the data for the given stream. A nil stream indicates // Write sends the data for the given stream. A nil stream indicates
// the write is to be performed on the transport as a whole. // the write is to be performed on the transport as a whole.
Write(s *Stream, hdr []byte, data []byte, opts *Options) error Write(s *Stream, hdr []byte, data mem.BufferSlice, opts *Options) error
// NewStream creates a Stream for an RPC. // NewStream creates a Stream for an RPC.
NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error)
@ -725,7 +827,7 @@ type ServerTransport interface {
// Write sends the data for the given stream. // Write sends the data for the given stream.
// Write may not be called on all streams. // Write may not be called on all streams.
Write(s *Stream, hdr []byte, data []byte, opts *Options) error Write(s *Stream, hdr []byte, data mem.BufferSlice, opts *Options) error
// WriteStatus sends the status of a stream to the client. WriteStatus is // WriteStatus sends the status of a stream to the client. WriteStatus is
// the final call made on a stream and always occurs. // the final call made on a stream and always occurs.
@ -798,7 +900,7 @@ var (
// connection is draining. This could be caused by goaway or balancer // connection is draining. This could be caused by goaway or balancer
// removing the address. // removing the address.
errStreamDrain = status.Error(codes.Unavailable, "the connection is draining") errStreamDrain = status.Error(codes.Unavailable, "the connection is draining")
// errStreamDone is returned from write at the client side to indiacte application // errStreamDone is returned from write at the client side to indicate application
// layer of an error. // layer of an error.
errStreamDone = errors.New("the stream is done") errStreamDone = errors.New("the stream is done")
// StatusGoAway indicates that the server sent a GOAWAY that included this // StatusGoAway indicates that the server sent a GOAWAY that included this

194
vendor/google.golang.org/grpc/mem/buffer_pool.go generated vendored Normal file
View File

@ -0,0 +1,194 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package mem
import (
"sort"
"sync"
"google.golang.org/grpc/internal"
)
// BufferPool is a pool of buffers that can be shared and reused, resulting in
// decreased memory allocation.
type BufferPool interface {
// Get returns a buffer with specified length from the pool.
Get(length int) *[]byte
// Put returns a buffer to the pool.
Put(*[]byte)
}
var defaultBufferPoolSizes = []int{
256,
4 << 10, // 4KB (go page size)
16 << 10, // 16KB (max HTTP/2 frame size used by gRPC)
32 << 10, // 32KB (default buffer size for io.Copy)
1 << 20, // 1MB
}
var defaultBufferPool BufferPool
func init() {
defaultBufferPool = NewTieredBufferPool(defaultBufferPoolSizes...)
internal.SetDefaultBufferPoolForTesting = func(pool BufferPool) {
defaultBufferPool = pool
}
internal.SetBufferPoolingThresholdForTesting = func(threshold int) {
bufferPoolingThreshold = threshold
}
}
// DefaultBufferPool returns the current default buffer pool. It is a BufferPool
// created with NewBufferPool that uses a set of default sizes optimized for
// expected workflows.
func DefaultBufferPool() BufferPool {
return defaultBufferPool
}
// NewTieredBufferPool returns a BufferPool implementation that uses multiple
// underlying pools of the given pool sizes.
func NewTieredBufferPool(poolSizes ...int) BufferPool {
sort.Ints(poolSizes)
pools := make([]*sizedBufferPool, len(poolSizes))
for i, s := range poolSizes {
pools[i] = newSizedBufferPool(s)
}
return &tieredBufferPool{
sizedPools: pools,
}
}
// tieredBufferPool implements the BufferPool interface with multiple tiers of
// buffer pools for different sizes of buffers.
type tieredBufferPool struct {
sizedPools []*sizedBufferPool
fallbackPool simpleBufferPool
}
func (p *tieredBufferPool) Get(size int) *[]byte {
return p.getPool(size).Get(size)
}
func (p *tieredBufferPool) Put(buf *[]byte) {
p.getPool(cap(*buf)).Put(buf)
}
func (p *tieredBufferPool) getPool(size int) BufferPool {
poolIdx := sort.Search(len(p.sizedPools), func(i int) bool {
return p.sizedPools[i].defaultSize >= size
})
if poolIdx == len(p.sizedPools) {
return &p.fallbackPool
}
return p.sizedPools[poolIdx]
}
// sizedBufferPool is a BufferPool implementation that is optimized for specific
// buffer sizes. For example, HTTP/2 frames within gRPC have a default max size
// of 16kb and a sizedBufferPool can be configured to only return buffers with a
// capacity of 16kb. Note that however it does not support returning larger
// buffers and in fact panics if such a buffer is requested. Because of this,
// this BufferPool implementation is not meant to be used on its own and rather
// is intended to be embedded in a tieredBufferPool such that Get is only
// invoked when the required size is smaller than or equal to defaultSize.
type sizedBufferPool struct {
pool sync.Pool
defaultSize int
}
func (p *sizedBufferPool) Get(size int) *[]byte {
buf := p.pool.Get().(*[]byte)
b := *buf
clear(b[:cap(b)])
*buf = b[:size]
return buf
}
func (p *sizedBufferPool) Put(buf *[]byte) {
if cap(*buf) < p.defaultSize {
// Ignore buffers that are too small to fit in the pool. Otherwise, when
// Get is called it will panic as it tries to index outside the bounds
// of the buffer.
return
}
p.pool.Put(buf)
}
func newSizedBufferPool(size int) *sizedBufferPool {
return &sizedBufferPool{
pool: sync.Pool{
New: func() any {
buf := make([]byte, size)
return &buf
},
},
defaultSize: size,
}
}
var _ BufferPool = (*simpleBufferPool)(nil)
// simpleBufferPool is an implementation of the BufferPool interface that
// attempts to pool buffers with a sync.Pool. When Get is invoked, it tries to
// acquire a buffer from the pool but if that buffer is too small, it returns it
// to the pool and creates a new one.
type simpleBufferPool struct {
pool sync.Pool
}
func (p *simpleBufferPool) Get(size int) *[]byte {
bs, ok := p.pool.Get().(*[]byte)
if ok && cap(*bs) >= size {
*bs = (*bs)[:size]
return bs
}
// A buffer was pulled from the pool, but it is too small. Put it back in
// the pool and create one large enough.
if ok {
p.pool.Put(bs)
}
b := make([]byte, size)
return &b
}
func (p *simpleBufferPool) Put(buf *[]byte) {
p.pool.Put(buf)
}
var _ BufferPool = NopBufferPool{}
// NopBufferPool is a buffer pool that returns new buffers without pooling.
type NopBufferPool struct{}
// Get returns a buffer with specified length from the pool.
func (NopBufferPool) Get(length int) *[]byte {
b := make([]byte, length)
return &b
}
// Put returns a buffer to the pool.
func (NopBufferPool) Put(*[]byte) {
}

224
vendor/google.golang.org/grpc/mem/buffer_slice.go generated vendored Normal file
View File

@ -0,0 +1,224 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package mem
import (
"compress/flate"
"io"
)
// BufferSlice offers a means to represent data that spans one or more Buffer
// instances. A BufferSlice is meant to be immutable after creation, and methods
// like Ref create and return copies of the slice. This is why all methods have
// value receivers rather than pointer receivers.
//
// Note that any of the methods that read the underlying buffers such as Ref,
// Len or CopyTo etc., will panic if any underlying buffers have already been
// freed. It is recommended to not directly interact with any of the underlying
// buffers directly, rather such interactions should be mediated through the
// various methods on this type.
//
// By convention, any APIs that return (mem.BufferSlice, error) should reduce
// the burden on the caller by never returning a mem.BufferSlice that needs to
// be freed if the error is non-nil, unless explicitly stated.
type BufferSlice []Buffer
// Len returns the sum of the length of all the Buffers in this slice.
//
// # Warning
//
// Invoking the built-in len on a BufferSlice will return the number of buffers
// in the slice, and *not* the value returned by this function.
func (s BufferSlice) Len() int {
var length int
for _, b := range s {
length += b.Len()
}
return length
}
// Ref invokes Ref on each buffer in the slice.
func (s BufferSlice) Ref() {
for _, b := range s {
b.Ref()
}
}
// Free invokes Buffer.Free() on each Buffer in the slice.
func (s BufferSlice) Free() {
for _, b := range s {
b.Free()
}
}
// CopyTo copies each of the underlying Buffer's data into the given buffer,
// returning the number of bytes copied. Has the same semantics as the copy
// builtin in that it will copy as many bytes as it can, stopping when either dst
// is full or s runs out of data, returning the minimum of s.Len() and len(dst).
func (s BufferSlice) CopyTo(dst []byte) int {
off := 0
for _, b := range s {
off += copy(dst[off:], b.ReadOnlyData())
}
return off
}
// Materialize concatenates all the underlying Buffer's data into a single
// contiguous buffer using CopyTo.
func (s BufferSlice) Materialize() []byte {
l := s.Len()
if l == 0 {
return nil
}
out := make([]byte, l)
s.CopyTo(out)
return out
}
// MaterializeToBuffer functions like Materialize except that it writes the data
// to a single Buffer pulled from the given BufferPool. As a special case, if the
// input BufferSlice only actually has one Buffer, this function has nothing to
// do and simply returns said Buffer.
func (s BufferSlice) MaterializeToBuffer(pool BufferPool) Buffer {
if len(s) == 1 {
s[0].Ref()
return s[0]
}
sLen := s.Len()
if sLen == 0 {
return emptyBuffer{}
}
buf := pool.Get(sLen)
s.CopyTo(*buf)
return NewBuffer(buf, pool)
}
// Reader returns a new Reader for the input slice after taking references to
// each underlying buffer.
func (s BufferSlice) Reader() Reader {
s.Ref()
return &sliceReader{
data: s,
len: s.Len(),
}
}
// Reader exposes a BufferSlice's data as an io.Reader, allowing it to interface
// with other parts systems. It also provides an additional convenience method
// Remaining(), which returns the number of unread bytes remaining in the slice.
// Buffers will be freed as they are read.
type Reader interface {
flate.Reader
// Close frees the underlying BufferSlice and never returns an error. Subsequent
// calls to Read will return (0, io.EOF).
Close() error
// Remaining returns the number of unread bytes remaining in the slice.
Remaining() int
}
type sliceReader struct {
data BufferSlice
len int
// The index into data[0].ReadOnlyData().
bufferIdx int
}
func (r *sliceReader) Remaining() int {
return r.len
}
func (r *sliceReader) Close() error {
r.data.Free()
r.data = nil
r.len = 0
return nil
}
func (r *sliceReader) freeFirstBufferIfEmpty() bool {
if len(r.data) == 0 || r.bufferIdx != len(r.data[0].ReadOnlyData()) {
return false
}
r.data[0].Free()
r.data = r.data[1:]
r.bufferIdx = 0
return true
}
func (r *sliceReader) Read(buf []byte) (n int, _ error) {
if r.len == 0 {
return 0, io.EOF
}
for len(buf) != 0 && r.len != 0 {
// Copy as much as possible from the first Buffer in the slice into the
// given byte slice.
data := r.data[0].ReadOnlyData()
copied := copy(buf, data[r.bufferIdx:])
r.len -= copied // Reduce len by the number of bytes copied.
r.bufferIdx += copied // Increment the buffer index.
n += copied // Increment the total number of bytes read.
buf = buf[copied:] // Shrink the given byte slice.
// If we have copied all the data from the first Buffer, free it and advance to
// the next in the slice.
r.freeFirstBufferIfEmpty()
}
return n, nil
}
func (r *sliceReader) ReadByte() (byte, error) {
if r.len == 0 {
return 0, io.EOF
}
// There may be any number of empty buffers in the slice, clear them all until a
// non-empty buffer is reached. This is guaranteed to exit since r.len is not 0.
for r.freeFirstBufferIfEmpty() {
}
b := r.data[0].ReadOnlyData()[r.bufferIdx]
r.len--
r.bufferIdx++
// Free the first buffer in the slice if the last byte was read
r.freeFirstBufferIfEmpty()
return b, nil
}
var _ io.Writer = (*writer)(nil)
type writer struct {
buffers *BufferSlice
pool BufferPool
}
func (w *writer) Write(p []byte) (n int, err error) {
b := Copy(p, w.pool)
*w.buffers = append(*w.buffers, b)
return b.Len(), nil
}
// NewWriter wraps the given BufferSlice and BufferPool to implement the
// io.Writer interface. Every call to Write copies the contents of the given
// buffer into a new Buffer pulled from the given pool and the Buffer is added to
// the given BufferSlice.
func NewWriter(buffers *BufferSlice, pool BufferPool) io.Writer {
return &writer{buffers: buffers, pool: pool}
}

252
vendor/google.golang.org/grpc/mem/buffers.go generated vendored Normal file
View File

@ -0,0 +1,252 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package mem provides utilities that facilitate memory reuse in byte slices
// that are used as buffers.
//
// # Experimental
//
// Notice: All APIs in this package are EXPERIMENTAL and may be changed or
// removed in a later release.
package mem
import (
"fmt"
"sync"
"sync/atomic"
)
// A Buffer represents a reference counted piece of data (in bytes) that can be
// acquired by a call to NewBuffer() or Copy(). A reference to a Buffer may be
// released by calling Free(), which invokes the free function given at creation
// only after all references are released.
//
// Note that a Buffer is not safe for concurrent access and instead each
// goroutine should use its own reference to the data, which can be acquired via
// a call to Ref().
//
// Attempts to access the underlying data after releasing the reference to the
// Buffer will panic.
type Buffer interface {
// ReadOnlyData returns the underlying byte slice. Note that it is undefined
// behavior to modify the contents of this slice in any way.
ReadOnlyData() []byte
// Ref increases the reference counter for this Buffer.
Ref()
// Free decrements this Buffer's reference counter and frees the underlying
// byte slice if the counter reaches 0 as a result of this call.
Free()
// Len returns the Buffer's size.
Len() int
split(n int) (left, right Buffer)
read(buf []byte) (int, Buffer)
}
var (
bufferPoolingThreshold = 1 << 10
bufferObjectPool = sync.Pool{New: func() any { return new(buffer) }}
refObjectPool = sync.Pool{New: func() any { return new(atomic.Int32) }}
)
func IsBelowBufferPoolingThreshold(size int) bool {
return size <= bufferPoolingThreshold
}
type buffer struct {
origData *[]byte
data []byte
refs *atomic.Int32
pool BufferPool
}
func newBuffer() *buffer {
return bufferObjectPool.Get().(*buffer)
}
// NewBuffer creates a new Buffer from the given data, initializing the reference
// counter to 1. The data will then be returned to the given pool when all
// references to the returned Buffer are released. As a special case to avoid
// additional allocations, if the given buffer pool is nil, the returned buffer
// will be a "no-op" Buffer where invoking Buffer.Free() does nothing and the
// underlying data is never freed.
//
// Note that the backing array of the given data is not copied.
func NewBuffer(data *[]byte, pool BufferPool) Buffer {
if pool == nil || IsBelowBufferPoolingThreshold(len(*data)) {
return (SliceBuffer)(*data)
}
b := newBuffer()
b.origData = data
b.data = *data
b.pool = pool
b.refs = refObjectPool.Get().(*atomic.Int32)
b.refs.Add(1)
return b
}
// Copy creates a new Buffer from the given data, initializing the reference
// counter to 1.
//
// It acquires a []byte from the given pool and copies over the backing array
// of the given data. The []byte acquired from the pool is returned to the
// pool when all references to the returned Buffer are released.
func Copy(data []byte, pool BufferPool) Buffer {
if IsBelowBufferPoolingThreshold(len(data)) {
buf := make(SliceBuffer, len(data))
copy(buf, data)
return buf
}
buf := pool.Get(len(data))
copy(*buf, data)
return NewBuffer(buf, pool)
}
func (b *buffer) ReadOnlyData() []byte {
if b.refs == nil {
panic("Cannot read freed buffer")
}
return b.data
}
func (b *buffer) Ref() {
if b.refs == nil {
panic("Cannot ref freed buffer")
}
b.refs.Add(1)
}
func (b *buffer) Free() {
if b.refs == nil {
panic("Cannot free freed buffer")
}
refs := b.refs.Add(-1)
switch {
case refs > 0:
return
case refs == 0:
if b.pool != nil {
b.pool.Put(b.origData)
}
refObjectPool.Put(b.refs)
b.origData = nil
b.data = nil
b.refs = nil
b.pool = nil
bufferObjectPool.Put(b)
default:
panic("Cannot free freed buffer")
}
}
func (b *buffer) Len() int {
return len(b.ReadOnlyData())
}
func (b *buffer) split(n int) (Buffer, Buffer) {
if b.refs == nil {
panic("Cannot split freed buffer")
}
b.refs.Add(1)
split := newBuffer()
split.origData = b.origData
split.data = b.data[n:]
split.refs = b.refs
split.pool = b.pool
b.data = b.data[:n]
return b, split
}
func (b *buffer) read(buf []byte) (int, Buffer) {
if b.refs == nil {
panic("Cannot read freed buffer")
}
n := copy(buf, b.data)
if n == len(b.data) {
b.Free()
return n, nil
}
b.data = b.data[n:]
return n, b
}
// String returns a string representation of the buffer. May be used for
// debugging purposes.
func (b *buffer) String() string {
return fmt.Sprintf("mem.Buffer(%p, data: %p, length: %d)", b, b.ReadOnlyData(), len(b.ReadOnlyData()))
}
func ReadUnsafe(dst []byte, buf Buffer) (int, Buffer) {
return buf.read(dst)
}
// SplitUnsafe modifies the receiver to point to the first n bytes while it
// returns a new reference to the remaining bytes. The returned Buffer functions
// just like a normal reference acquired using Ref().
func SplitUnsafe(buf Buffer, n int) (left, right Buffer) {
return buf.split(n)
}
type emptyBuffer struct{}
func (e emptyBuffer) ReadOnlyData() []byte {
return nil
}
func (e emptyBuffer) Ref() {}
func (e emptyBuffer) Free() {}
func (e emptyBuffer) Len() int {
return 0
}
func (e emptyBuffer) split(n int) (left, right Buffer) {
return e, e
}
func (e emptyBuffer) read(buf []byte) (int, Buffer) {
return 0, e
}
type SliceBuffer []byte
func (s SliceBuffer) ReadOnlyData() []byte { return s }
func (s SliceBuffer) Ref() {}
func (s SliceBuffer) Free() {}
func (s SliceBuffer) Len() int { return len(s) }
func (s SliceBuffer) split(n int) (left, right Buffer) {
return s[:n], s[n:]
}
func (s SliceBuffer) read(buf []byte) (int, Buffer) {
n := copy(buf, s)
if n == len(s) {
return n, nil
}
return n, s[n:]
}

View File

@ -213,11 +213,6 @@ func FromIncomingContext(ctx context.Context) (MD, bool) {
// ValueFromIncomingContext returns the metadata value corresponding to the metadata // ValueFromIncomingContext returns the metadata value corresponding to the metadata
// key from the incoming metadata if it exists. Keys are matched in a case insensitive // key from the incoming metadata if it exists. Keys are matched in a case insensitive
// manner. // manner.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func ValueFromIncomingContext(ctx context.Context, key string) []string { func ValueFromIncomingContext(ctx context.Context, key string) []string {
md, ok := ctx.Value(mdIncomingKey{}).(MD) md, ok := ctx.Value(mdIncomingKey{}).(MD)
if !ok { if !ok {
@ -228,7 +223,7 @@ func ValueFromIncomingContext(ctx context.Context, key string) []string {
return copyOf(v) return copyOf(v)
} }
for k, v := range md { for k, v := range md {
// Case insenitive comparison: MD is a map, and there's no guarantee // Case insensitive comparison: MD is a map, and there's no guarantee
// that the MD attached to the context is created using our helper // that the MD attached to the context is created using our helper
// functions. // functions.
if strings.EqualFold(k, key) { if strings.EqualFold(k, key) {

View File

@ -20,6 +20,7 @@ package grpc
import ( import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )
@ -31,9 +32,10 @@ import (
// later release. // later release.
type PreparedMsg struct { type PreparedMsg struct {
// Struct for preparing msg before sending them // Struct for preparing msg before sending them
encodedData []byte encodedData mem.BufferSlice
hdr []byte hdr []byte
payload []byte payload mem.BufferSlice
pf payloadFormat
} }
// Encode marshalls and compresses the message using the codec and compressor for the stream. // Encode marshalls and compresses the message using the codec and compressor for the stream.
@ -57,11 +59,27 @@ func (p *PreparedMsg) Encode(s Stream, msg any) error {
if err != nil { if err != nil {
return err return err
} }
p.encodedData = data
compData, err := compress(data, rpcInfo.preloaderInfo.cp, rpcInfo.preloaderInfo.comp) materializedData := data.Materialize()
data.Free()
p.encodedData = mem.BufferSlice{mem.NewBuffer(&materializedData, nil)}
// TODO: it should be possible to grab the bufferPool from the underlying
// stream implementation with a type cast to its actual type (such as
// addrConnStream) and accessing the buffer pool directly.
var compData mem.BufferSlice
compData, p.pf, err = compress(p.encodedData, rpcInfo.preloaderInfo.cp, rpcInfo.preloaderInfo.comp, mem.DefaultBufferPool())
if err != nil { if err != nil {
return err return err
} }
p.hdr, p.payload = msgHeader(data, compData)
if p.pf.isCompressed() {
materializedCompData := compData.Materialize()
compData.Free()
compData = mem.BufferSlice{mem.NewBuffer(&materializedCompData, nil)}
}
p.hdr, p.payload = msgHeader(p.encodedData, compData, p.pf)
return nil return nil
} }

View File

@ -1,123 +0,0 @@
#!/bin/bash
# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -eu -o pipefail
WORKDIR=$(mktemp -d)
function finish {
rm -rf "$WORKDIR"
}
trap finish EXIT
export GOBIN=${WORKDIR}/bin
export PATH=${GOBIN}:${PATH}
mkdir -p ${GOBIN}
echo "remove existing generated files"
# grpc_testing_not_regenerate/*.pb.go is not re-generated,
# see grpc_testing_not_regenerate/README.md for details.
rm -f $(find . -name '*.pb.go' | grep -v 'grpc_testing_not_regenerate')
echo "go install google.golang.org/protobuf/cmd/protoc-gen-go"
(cd test/tools && go install google.golang.org/protobuf/cmd/protoc-gen-go)
echo "go install cmd/protoc-gen-go-grpc"
(cd cmd/protoc-gen-go-grpc && go install .)
echo "git clone https://github.com/grpc/grpc-proto"
git clone --quiet https://github.com/grpc/grpc-proto ${WORKDIR}/grpc-proto
echo "git clone https://github.com/protocolbuffers/protobuf"
git clone --quiet https://github.com/protocolbuffers/protobuf ${WORKDIR}/protobuf
# Pull in code.proto as a proto dependency
mkdir -p ${WORKDIR}/googleapis/google/rpc
echo "curl https://raw.githubusercontent.com/googleapis/googleapis/master/google/rpc/code.proto"
curl --silent https://raw.githubusercontent.com/googleapis/googleapis/master/google/rpc/code.proto > ${WORKDIR}/googleapis/google/rpc/code.proto
mkdir -p ${WORKDIR}/out
# Generates sources without the embed requirement
LEGACY_SOURCES=(
${WORKDIR}/grpc-proto/grpc/binlog/v1/binarylog.proto
${WORKDIR}/grpc-proto/grpc/channelz/v1/channelz.proto
${WORKDIR}/grpc-proto/grpc/health/v1/health.proto
${WORKDIR}/grpc-proto/grpc/lb/v1/load_balancer.proto
profiling/proto/service.proto
${WORKDIR}/grpc-proto/grpc/reflection/v1alpha/reflection.proto
${WORKDIR}/grpc-proto/grpc/reflection/v1/reflection.proto
)
# Generates only the new gRPC Service symbols
SOURCES=(
$(git ls-files --exclude-standard --cached --others "*.proto" | grep -v '^profiling/proto/service.proto$')
${WORKDIR}/grpc-proto/grpc/gcp/altscontext.proto
${WORKDIR}/grpc-proto/grpc/gcp/handshaker.proto
${WORKDIR}/grpc-proto/grpc/gcp/transport_security_common.proto
${WORKDIR}/grpc-proto/grpc/lookup/v1/rls.proto
${WORKDIR}/grpc-proto/grpc/lookup/v1/rls_config.proto
${WORKDIR}/grpc-proto/grpc/testing/*.proto
${WORKDIR}/grpc-proto/grpc/core/*.proto
)
# These options of the form 'Mfoo.proto=bar' instruct the codegen to use an
# import path of 'bar' in the generated code when 'foo.proto' is imported in
# one of the sources.
#
# Note that the protos listed here are all for testing purposes. All protos to
# be used externally should have a go_package option (and they don't need to be
# listed here).
OPTS=Mgrpc/core/stats.proto=google.golang.org/grpc/interop/grpc_testing/core,\
Mgrpc/testing/benchmark_service.proto=google.golang.org/grpc/interop/grpc_testing,\
Mgrpc/testing/stats.proto=google.golang.org/grpc/interop/grpc_testing,\
Mgrpc/testing/report_qps_scenario_service.proto=google.golang.org/grpc/interop/grpc_testing,\
Mgrpc/testing/messages.proto=google.golang.org/grpc/interop/grpc_testing,\
Mgrpc/testing/worker_service.proto=google.golang.org/grpc/interop/grpc_testing,\
Mgrpc/testing/control.proto=google.golang.org/grpc/interop/grpc_testing,\
Mgrpc/testing/test.proto=google.golang.org/grpc/interop/grpc_testing,\
Mgrpc/testing/payloads.proto=google.golang.org/grpc/interop/grpc_testing,\
Mgrpc/testing/empty.proto=google.golang.org/grpc/interop/grpc_testing
for src in ${SOURCES[@]}; do
echo "protoc ${src}"
protoc --go_out=${OPTS}:${WORKDIR}/out --go-grpc_out=${OPTS},use_generic_streams_experimental=true:${WORKDIR}/out \
-I"." \
-I${WORKDIR}/grpc-proto \
-I${WORKDIR}/googleapis \
-I${WORKDIR}/protobuf/src \
${src}
done
for src in ${LEGACY_SOURCES[@]}; do
echo "protoc ${src}"
protoc --go_out=${OPTS}:${WORKDIR}/out --go-grpc_out=${OPTS},require_unimplemented_servers=false:${WORKDIR}/out \
-I"." \
-I${WORKDIR}/grpc-proto \
-I${WORKDIR}/googleapis \
-I${WORKDIR}/protobuf/src \
${src}
done
# The go_package option in grpc/lookup/v1/rls.proto doesn't match the
# current location. Move it into the right place.
mkdir -p ${WORKDIR}/out/google.golang.org/grpc/internal/proto/grpc_lookup_v1
mv ${WORKDIR}/out/google.golang.org/grpc/lookup/grpc_lookup_v1/* ${WORKDIR}/out/google.golang.org/grpc/internal/proto/grpc_lookup_v1
# grpc_testing_not_regenerate/*.pb.go are not re-generated,
# see grpc_testing_not_regenerate/README.md for details.
rm ${WORKDIR}/out/google.golang.org/grpc/reflection/test/grpc_testing_not_regenerate/*.pb.go
cp -R ${WORKDIR}/out/google.golang.org/grpc/* .

View File

@ -66,7 +66,7 @@ func newCCResolverWrapper(cc *ClientConn) *ccResolverWrapper {
// any newly created ccResolverWrapper, except that close may be called instead. // any newly created ccResolverWrapper, except that close may be called instead.
func (ccr *ccResolverWrapper) start() error { func (ccr *ccResolverWrapper) start() error {
errCh := make(chan error) errCh := make(chan error)
ccr.serializer.Schedule(func(ctx context.Context) { ccr.serializer.TrySchedule(func(ctx context.Context) {
if ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
@ -85,7 +85,7 @@ func (ccr *ccResolverWrapper) start() error {
} }
func (ccr *ccResolverWrapper) resolveNow(o resolver.ResolveNowOptions) { func (ccr *ccResolverWrapper) resolveNow(o resolver.ResolveNowOptions) {
ccr.serializer.Schedule(func(ctx context.Context) { ccr.serializer.TrySchedule(func(ctx context.Context) {
if ctx.Err() != nil || ccr.resolver == nil { if ctx.Err() != nil || ccr.resolver == nil {
return return
} }
@ -102,7 +102,7 @@ func (ccr *ccResolverWrapper) close() {
ccr.closed = true ccr.closed = true
ccr.mu.Unlock() ccr.mu.Unlock()
ccr.serializer.Schedule(func(context.Context) { ccr.serializer.TrySchedule(func(context.Context) {
if ccr.resolver == nil { if ccr.resolver == nil {
return return
} }
@ -177,6 +177,9 @@ func (ccr *ccResolverWrapper) ParseServiceConfig(scJSON string) *serviceconfig.P
// addChannelzTraceEvent adds a channelz trace event containing the new // addChannelzTraceEvent adds a channelz trace event containing the new
// state received from resolver implementations. // state received from resolver implementations.
func (ccr *ccResolverWrapper) addChannelzTraceEvent(s resolver.State) { func (ccr *ccResolverWrapper) addChannelzTraceEvent(s resolver.State) {
if !logger.V(0) && !channelz.IsOn() {
return
}
var updates []string var updates []string
var oldSC, newSC *ServiceConfig var oldSC, newSC *ServiceConfig
var oldOK, newOK bool var oldOK, newOK bool

View File

@ -19,7 +19,6 @@
package grpc package grpc
import ( import (
"bytes"
"compress/gzip" "compress/gzip"
"context" "context"
"encoding/binary" "encoding/binary"
@ -35,6 +34,7 @@ import (
"google.golang.org/grpc/encoding" "google.golang.org/grpc/encoding"
"google.golang.org/grpc/encoding/proto" "google.golang.org/grpc/encoding/proto"
"google.golang.org/grpc/internal/transport" "google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
@ -271,17 +271,13 @@ func (o PeerCallOption) after(c *callInfo, attempt *csAttempt) {
} }
} }
// WaitForReady configures the action to take when an RPC is attempted on broken // WaitForReady configures the RPC's behavior when the client is in
// connections or unreachable servers. If waitForReady is false and the // TRANSIENT_FAILURE, which occurs when all addresses fail to connect. If
// connection is in the TRANSIENT_FAILURE state, the RPC will fail // waitForReady is false, the RPC will fail immediately. Otherwise, the client
// immediately. Otherwise, the RPC client will block the call until a // will wait until a connection becomes available or the RPC's deadline is
// connection is available (or the call is canceled or times out) and will // reached.
// retry the call if it fails due to a transient error. gRPC will not retry if
// data was written to the wire unless the server indicates it did not process
// the data. Please refer to
// https://github.com/grpc/grpc/blob/master/doc/wait-for-ready.md.
// //
// By default, RPCs don't "wait for ready". // By default, RPCs do not "wait for ready".
func WaitForReady(waitForReady bool) CallOption { func WaitForReady(waitForReady bool) CallOption {
return FailFastCallOption{FailFast: !waitForReady} return FailFastCallOption{FailFast: !waitForReady}
} }
@ -515,11 +511,51 @@ type ForceCodecCallOption struct {
} }
func (o ForceCodecCallOption) before(c *callInfo) error { func (o ForceCodecCallOption) before(c *callInfo) error {
c.codec = o.Codec c.codec = newCodecV1Bridge(o.Codec)
return nil return nil
} }
func (o ForceCodecCallOption) after(c *callInfo, attempt *csAttempt) {} func (o ForceCodecCallOption) after(c *callInfo, attempt *csAttempt) {}
// ForceCodecV2 returns a CallOption that will set codec to be used for all
// request and response messages for a call. The result of calling Name() will
// be used as the content-subtype after converting to lowercase, unless
// CallContentSubtype is also used.
//
// See Content-Type on
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details. Also see the documentation on RegisterCodec and
// CallContentSubtype for more details on the interaction between Codec and
// content-subtype.
//
// This function is provided for advanced users; prefer to use only
// CallContentSubtype to select a registered codec instead.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func ForceCodecV2(codec encoding.CodecV2) CallOption {
return ForceCodecV2CallOption{CodecV2: codec}
}
// ForceCodecV2CallOption is a CallOption that indicates the codec used for
// marshaling messages.
//
// # Experimental
//
// Notice: This type is EXPERIMENTAL and may be changed or removed in a
// later release.
type ForceCodecV2CallOption struct {
CodecV2 encoding.CodecV2
}
func (o ForceCodecV2CallOption) before(c *callInfo) error {
c.codec = o.CodecV2
return nil
}
func (o ForceCodecV2CallOption) after(c *callInfo, attempt *csAttempt) {}
// CallCustomCodec behaves like ForceCodec, but accepts a grpc.Codec instead of // CallCustomCodec behaves like ForceCodec, but accepts a grpc.Codec instead of
// an encoding.Codec. // an encoding.Codec.
// //
@ -540,7 +576,7 @@ type CustomCodecCallOption struct {
} }
func (o CustomCodecCallOption) before(c *callInfo) error { func (o CustomCodecCallOption) before(c *callInfo) error {
c.codec = o.Codec c.codec = newCodecV0Bridge(o.Codec)
return nil return nil
} }
func (o CustomCodecCallOption) after(c *callInfo, attempt *csAttempt) {} func (o CustomCodecCallOption) after(c *callInfo, attempt *csAttempt) {}
@ -581,19 +617,28 @@ const (
compressionMade payloadFormat = 1 // compressed compressionMade payloadFormat = 1 // compressed
) )
func (pf payloadFormat) isCompressed() bool {
return pf == compressionMade
}
type streamReader interface {
ReadHeader(header []byte) error
Read(n int) (mem.BufferSlice, error)
}
// parser reads complete gRPC messages from the underlying reader. // parser reads complete gRPC messages from the underlying reader.
type parser struct { type parser struct {
// r is the underlying reader. // r is the underlying reader.
// See the comment on recvMsg for the permissible // See the comment on recvMsg for the permissible
// error types. // error types.
r io.Reader r streamReader
// The header of a gRPC message. Find more detail at // The header of a gRPC message. Find more detail at
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md
header [5]byte header [5]byte
// recvBufferPool is the pool of shared receive buffers. // bufferPool is the pool of shared receive buffers.
recvBufferPool SharedBufferPool bufferPool mem.BufferPool
} }
// recvMsg reads a complete gRPC message from the stream. // recvMsg reads a complete gRPC message from the stream.
@ -608,14 +653,15 @@ type parser struct {
// - an error from the status package // - an error from the status package
// //
// No other error values or types must be returned, which also means // No other error values or types must be returned, which also means
// that the underlying io.Reader must not return an incompatible // that the underlying streamReader must not return an incompatible
// error. // error.
func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byte, err error) { func (p *parser) recvMsg(maxReceiveMessageSize int) (payloadFormat, mem.BufferSlice, error) {
if _, err := p.r.Read(p.header[:]); err != nil { err := p.r.ReadHeader(p.header[:])
if err != nil {
return 0, nil, err return 0, nil, err
} }
pf = payloadFormat(p.header[0]) pf := payloadFormat(p.header[0])
length := binary.BigEndian.Uint32(p.header[1:]) length := binary.BigEndian.Uint32(p.header[1:])
if length == 0 { if length == 0 {
@ -627,20 +673,21 @@ func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byt
if int(length) > maxReceiveMessageSize { if int(length) > maxReceiveMessageSize {
return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", length, maxReceiveMessageSize) return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", length, maxReceiveMessageSize)
} }
msg = p.recvBufferPool.Get(int(length))
if _, err := p.r.Read(msg); err != nil { data, err := p.r.Read(int(length))
if err != nil {
if err == io.EOF { if err == io.EOF {
err = io.ErrUnexpectedEOF err = io.ErrUnexpectedEOF
} }
return 0, nil, err return 0, nil, err
} }
return pf, msg, nil return pf, data, nil
} }
// encode serializes msg and returns a buffer containing the message, or an // encode serializes msg and returns a buffer containing the message, or an
// error if it is too large to be transmitted by grpc. If msg is nil, it // error if it is too large to be transmitted by grpc. If msg is nil, it
// generates an empty message. // generates an empty message.
func encode(c baseCodec, msg any) ([]byte, error) { func encode(c baseCodec, msg any) (mem.BufferSlice, error) {
if msg == nil { // NOTE: typed nils will not be caught by this check if msg == nil { // NOTE: typed nils will not be caught by this check
return nil, nil return nil, nil
} }
@ -648,7 +695,8 @@ func encode(c baseCodec, msg any) ([]byte, error) {
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error()) return nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error())
} }
if uint(len(b)) > math.MaxUint32 { if uint(b.Len()) > math.MaxUint32 {
b.Free()
return nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b)) return nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b))
} }
return b, nil return b, nil
@ -659,34 +707,41 @@ func encode(c baseCodec, msg any) ([]byte, error) {
// indicating no compression was done. // indicating no compression was done.
// //
// TODO(dfawley): eliminate cp parameter by wrapping Compressor in an encoding.Compressor. // TODO(dfawley): eliminate cp parameter by wrapping Compressor in an encoding.Compressor.
func compress(in []byte, cp Compressor, compressor encoding.Compressor) ([]byte, error) { func compress(in mem.BufferSlice, cp Compressor, compressor encoding.Compressor, pool mem.BufferPool) (mem.BufferSlice, payloadFormat, error) {
if compressor == nil && cp == nil { if (compressor == nil && cp == nil) || in.Len() == 0 {
return nil, nil return nil, compressionNone, nil
}
if len(in) == 0 {
return nil, nil
} }
var out mem.BufferSlice
w := mem.NewWriter(&out, pool)
wrapErr := func(err error) error { wrapErr := func(err error) error {
out.Free()
return status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) return status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
} }
cbuf := &bytes.Buffer{}
if compressor != nil { if compressor != nil {
z, err := compressor.Compress(cbuf) z, err := compressor.Compress(w)
if err != nil { if err != nil {
return nil, wrapErr(err) return nil, 0, wrapErr(err)
} }
if _, err := z.Write(in); err != nil { for _, b := range in {
return nil, wrapErr(err) if _, err := z.Write(b.ReadOnlyData()); err != nil {
return nil, 0, wrapErr(err)
}
} }
if err := z.Close(); err != nil { if err := z.Close(); err != nil {
return nil, wrapErr(err) return nil, 0, wrapErr(err)
} }
} else { } else {
if err := cp.Do(cbuf, in); err != nil { // This is obviously really inefficient since it fully materializes the data, but
return nil, wrapErr(err) // there is no way around this with the old Compressor API. At least it attempts
// to return the buffer to the provider, in the hopes it can be reused (maybe
// even by a subsequent call to this very function).
buf := in.MaterializeToBuffer(pool)
defer buf.Free()
if err := cp.Do(w, buf.ReadOnlyData()); err != nil {
return nil, 0, wrapErr(err)
} }
} }
return cbuf.Bytes(), nil return out, compressionMade, nil
} }
const ( const (
@ -697,33 +752,36 @@ const (
// msgHeader returns a 5-byte header for the message being transmitted and the // msgHeader returns a 5-byte header for the message being transmitted and the
// payload, which is compData if non-nil or data otherwise. // payload, which is compData if non-nil or data otherwise.
func msgHeader(data, compData []byte) (hdr []byte, payload []byte) { func msgHeader(data, compData mem.BufferSlice, pf payloadFormat) (hdr []byte, payload mem.BufferSlice) {
hdr = make([]byte, headerLen) hdr = make([]byte, headerLen)
if compData != nil { hdr[0] = byte(pf)
hdr[0] = byte(compressionMade)
data = compData var length uint32
if pf.isCompressed() {
length = uint32(compData.Len())
payload = compData
} else { } else {
hdr[0] = byte(compressionNone) length = uint32(data.Len())
payload = data
} }
// Write length of payload into buf // Write length of payload into buf
binary.BigEndian.PutUint32(hdr[payloadLen:], uint32(len(data))) binary.BigEndian.PutUint32(hdr[payloadLen:], length)
return hdr, data return hdr, payload
} }
func outPayload(client bool, msg any, data, payload []byte, t time.Time) *stats.OutPayload { func outPayload(client bool, msg any, dataLength, payloadLength int, t time.Time) *stats.OutPayload {
return &stats.OutPayload{ return &stats.OutPayload{
Client: client, Client: client,
Payload: msg, Payload: msg,
Data: data, Length: dataLength,
Length: len(data), WireLength: payloadLength + headerLen,
WireLength: len(payload) + headerLen, CompressedLength: payloadLength,
CompressedLength: len(payload),
SentTime: t, SentTime: t,
} }
} }
func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool) *status.Status { func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool, isServer bool) *status.Status {
switch pf { switch pf {
case compressionNone: case compressionNone:
case compressionMade: case compressionMade:
@ -731,7 +789,11 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool
return status.New(codes.Internal, "grpc: compressed flag set with identity or empty encoding") return status.New(codes.Internal, "grpc: compressed flag set with identity or empty encoding")
} }
if !haveCompressor { if !haveCompressor {
return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress) if isServer {
return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
} else {
return status.Newf(codes.Internal, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
}
} }
default: default:
return status.Newf(codes.Internal, "grpc: received unexpected payload format %d", pf) return status.Newf(codes.Internal, "grpc: received unexpected payload format %d", pf)
@ -741,104 +803,129 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool
type payloadInfo struct { type payloadInfo struct {
compressedLength int // The compressed length got from wire. compressedLength int // The compressed length got from wire.
uncompressedBytes []byte uncompressedBytes mem.BufferSlice
}
func (p *payloadInfo) free() {
if p != nil && p.uncompressedBytes != nil {
p.uncompressedBytes.Free()
}
} }
// recvAndDecompress reads a message from the stream, decompressing it if necessary. // recvAndDecompress reads a message from the stream, decompressing it if necessary.
// //
// Cancelling the returned cancel function releases the buffer back to the pool. So the caller should cancel as soon as // Cancelling the returned cancel function releases the buffer back to the pool. So the caller should cancel as soon as
// the buffer is no longer needed. // the buffer is no longer needed.
func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, // TODO: Refactor this function to reduce the number of arguments.
) (uncompressedBuf []byte, cancel func(), err error) { // See: https://google.github.io/styleguide/go/best-practices.html#function-argument-lists
pf, compressedBuf, err := p.recvMsg(maxReceiveMessageSize) func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool,
) (out mem.BufferSlice, err error) {
pf, compressed, err := p.recvMsg(maxReceiveMessageSize)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil { compressedLength := compressed.Len()
return nil, nil, st.Err()
if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil, isServer); st != nil {
compressed.Free()
return nil, st.Err()
} }
var size int var size int
if pf == compressionMade { if pf.isCompressed() {
defer compressed.Free()
// To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor, // To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor,
// use this decompressor as the default. // use this decompressor as the default.
if dc != nil { if dc != nil {
uncompressedBuf, err = dc.Do(bytes.NewReader(compressedBuf)) var uncompressedBuf []byte
uncompressedBuf, err = dc.Do(compressed.Reader())
if err == nil {
out = mem.BufferSlice{mem.NewBuffer(&uncompressedBuf, nil)}
}
size = len(uncompressedBuf) size = len(uncompressedBuf)
} else { } else {
uncompressedBuf, size, err = decompress(compressor, compressedBuf, maxReceiveMessageSize) out, size, err = decompress(compressor, compressed, maxReceiveMessageSize, p.bufferPool)
} }
if err != nil { if err != nil {
return nil, nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err) return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
} }
if size > maxReceiveMessageSize { if size > maxReceiveMessageSize {
out.Free()
// TODO: Revisit the error code. Currently keep it consistent with java // TODO: Revisit the error code. Currently keep it consistent with java
// implementation. // implementation.
return nil, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize) return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize)
} }
} else { } else {
uncompressedBuf = compressedBuf out = compressed
} }
if payInfo != nil { if payInfo != nil {
payInfo.compressedLength = len(compressedBuf) payInfo.compressedLength = compressedLength
payInfo.uncompressedBytes = uncompressedBuf out.Ref()
payInfo.uncompressedBytes = out
cancel = func() {}
} else {
cancel = func() {
p.recvBufferPool.Put(&compressedBuf)
}
} }
return uncompressedBuf, cancel, nil return out, nil
} }
// Using compressor, decompress d, returning data and size. // Using compressor, decompress d, returning data and size.
// Optionally, if data will be over maxReceiveMessageSize, just return the size. // Optionally, if data will be over maxReceiveMessageSize, just return the size.
func decompress(compressor encoding.Compressor, d []byte, maxReceiveMessageSize int) ([]byte, int, error) { func decompress(compressor encoding.Compressor, d mem.BufferSlice, maxReceiveMessageSize int, pool mem.BufferPool) (mem.BufferSlice, int, error) {
dcReader, err := compressor.Decompress(bytes.NewReader(d)) dcReader, err := compressor.Decompress(d.Reader())
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
if sizer, ok := compressor.(interface {
DecompressedSize(compressedBytes []byte) int // TODO: Can/should this still be preserved with the new BufferSlice API? Are
}); ok { // there any actual benefits to allocating a single large buffer instead of
if size := sizer.DecompressedSize(d); size >= 0 { // multiple smaller ones?
if size > maxReceiveMessageSize { //if sizer, ok := compressor.(interface {
return nil, size, nil // DecompressedSize(compressedBytes []byte) int
} //}); ok {
// size is used as an estimate to size the buffer, but we // if size := sizer.DecompressedSize(d); size >= 0 {
// will read more data if available. // if size > maxReceiveMessageSize {
// +MinRead so ReadFrom will not reallocate if size is correct. // return nil, size, nil
// // }
// TODO: If we ensure that the buffer size is the same as the DecompressedSize, // // size is used as an estimate to size the buffer, but we
// we can also utilize the recv buffer pool here. // // will read more data if available.
buf := bytes.NewBuffer(make([]byte, 0, size+bytes.MinRead)) // // +MinRead so ReadFrom will not reallocate if size is correct.
bytesRead, err := buf.ReadFrom(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1)) // //
return buf.Bytes(), int(bytesRead), err // // TODO: If we ensure that the buffer size is the same as the DecompressedSize,
} // // we can also utilize the recv buffer pool here.
// buf := bytes.NewBuffer(make([]byte, 0, size+bytes.MinRead))
// bytesRead, err := buf.ReadFrom(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
// return buf.Bytes(), int(bytesRead), err
// }
//}
var out mem.BufferSlice
_, err = io.Copy(mem.NewWriter(&out, pool), io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
if err != nil {
out.Free()
return nil, 0, err
} }
// Read from LimitReader with limit max+1. So if the underlying return out, out.Len(), nil
// reader is over limit, the result will be bigger than max.
d, err = io.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
return d, len(d), err
} }
// For the two compressor parameters, both should not be set, but if they are, // For the two compressor parameters, both should not be set, but if they are,
// dc takes precedence over compressor. // dc takes precedence over compressor.
// TODO(dfawley): wrap the old compressor/decompressor using the new API? // TODO(dfawley): wrap the old compressor/decompressor using the new API?
func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) error { func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool) error {
buf, cancel, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor) data, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor, isServer)
if err != nil { if err != nil {
return err return err
} }
defer cancel()
if err := c.Unmarshal(buf, m); err != nil { // If the codec wants its own reference to the data, it can get it. Otherwise, always
// free the buffers.
defer data.Free()
if err := c.Unmarshal(data, m); err != nil {
return status.Errorf(codes.Internal, "grpc: failed to unmarshal the received message: %v", err) return status.Errorf(codes.Internal, "grpc: failed to unmarshal the received message: %v", err)
} }
return nil return nil
} }
@ -941,7 +1028,7 @@ func setCallInfoCodec(c *callInfo) error {
// encoding.Codec (Name vs. String method name). We only support // encoding.Codec (Name vs. String method name). We only support
// setting content subtype from encoding.Codec to avoid a behavior // setting content subtype from encoding.Codec to avoid a behavior
// change with the deprecated version. // change with the deprecated version.
if ec, ok := c.codec.(encoding.Codec); ok { if ec, ok := c.codec.(encoding.CodecV2); ok {
c.contentSubtype = strings.ToLower(ec.Name()) c.contentSubtype = strings.ToLower(ec.Name())
} }
} }
@ -950,12 +1037,12 @@ func setCallInfoCodec(c *callInfo) error {
if c.contentSubtype == "" { if c.contentSubtype == "" {
// No codec specified in CallOptions; use proto by default. // No codec specified in CallOptions; use proto by default.
c.codec = encoding.GetCodec(proto.Name) c.codec = getCodec(proto.Name)
return nil return nil
} }
// c.contentSubtype is already lowercased in CallContentSubtype // c.contentSubtype is already lowercased in CallContentSubtype
c.codec = encoding.GetCodec(c.contentSubtype) c.codec = getCodec(c.contentSubtype)
if c.codec == nil { if c.codec == nil {
return status.Errorf(codes.Internal, "no codec registered for content-subtype %s", c.contentSubtype) return status.Errorf(codes.Internal, "no codec registered for content-subtype %s", c.contentSubtype)
} }

View File

@ -45,6 +45,7 @@ import (
"google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/internal/transport" "google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
@ -80,7 +81,7 @@ func init() {
} }
internal.BinaryLogger = binaryLogger internal.BinaryLogger = binaryLogger
internal.JoinServerOptions = newJoinServerOption internal.JoinServerOptions = newJoinServerOption
internal.RecvBufferPool = recvBufferPool internal.BufferPool = bufferPool
} }
var statusOK = status.New(codes.OK, "") var statusOK = status.New(codes.OK, "")
@ -170,7 +171,7 @@ type serverOptions struct {
maxHeaderListSize *uint32 maxHeaderListSize *uint32
headerTableSize *uint32 headerTableSize *uint32
numServerWorkers uint32 numServerWorkers uint32
recvBufferPool SharedBufferPool bufferPool mem.BufferPool
waitForHandlers bool waitForHandlers bool
} }
@ -181,7 +182,7 @@ var defaultServerOptions = serverOptions{
connectionTimeout: 120 * time.Second, connectionTimeout: 120 * time.Second,
writeBufferSize: defaultWriteBufSize, writeBufferSize: defaultWriteBufSize,
readBufferSize: defaultReadBufSize, readBufferSize: defaultReadBufSize,
recvBufferPool: nopBufferPool{}, bufferPool: mem.DefaultBufferPool(),
} }
var globalServerOptions []ServerOption var globalServerOptions []ServerOption
@ -313,7 +314,7 @@ func KeepaliveEnforcementPolicy(kep keepalive.EnforcementPolicy) ServerOption {
// Will be supported throughout 1.x. // Will be supported throughout 1.x.
func CustomCodec(codec Codec) ServerOption { func CustomCodec(codec Codec) ServerOption {
return newFuncServerOption(func(o *serverOptions) { return newFuncServerOption(func(o *serverOptions) {
o.codec = codec o.codec = newCodecV0Bridge(codec)
}) })
} }
@ -342,7 +343,22 @@ func CustomCodec(codec Codec) ServerOption {
// later release. // later release.
func ForceServerCodec(codec encoding.Codec) ServerOption { func ForceServerCodec(codec encoding.Codec) ServerOption {
return newFuncServerOption(func(o *serverOptions) { return newFuncServerOption(func(o *serverOptions) {
o.codec = codec o.codec = newCodecV1Bridge(codec)
})
}
// ForceServerCodecV2 is the equivalent of ForceServerCodec, but for the new
// CodecV2 interface.
//
// Will be supported throughout 1.x.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func ForceServerCodecV2(codecV2 encoding.CodecV2) ServerOption {
return newFuncServerOption(func(o *serverOptions) {
o.codec = codecV2
}) })
} }
@ -592,26 +608,9 @@ func WaitForHandlers(w bool) ServerOption {
}) })
} }
// RecvBufferPool returns a ServerOption that configures the server func bufferPool(bufferPool mem.BufferPool) ServerOption {
// to use the provided shared buffer pool for parsing incoming messages. Depending
// on the application's workload, this could result in reduced memory allocation.
//
// If you are unsure about how to implement a memory pool but want to utilize one,
// begin with grpc.NewSharedBufferPool.
//
// Note: The shared buffer pool feature will not be active if any of the following
// options are used: StatsHandler, EnableTracing, or binary logging. In such
// cases, the shared buffer pool will be ignored.
//
// Deprecated: use experimental.WithRecvBufferPool instead. Will be deleted in
// v1.60.0 or later.
func RecvBufferPool(bufferPool SharedBufferPool) ServerOption {
return recvBufferPool(bufferPool)
}
func recvBufferPool(bufferPool SharedBufferPool) ServerOption {
return newFuncServerOption(func(o *serverOptions) { return newFuncServerOption(func(o *serverOptions) {
o.recvBufferPool = bufferPool o.bufferPool = bufferPool
}) })
} }
@ -622,7 +621,7 @@ func recvBufferPool(bufferPool SharedBufferPool) ServerOption {
// workload (assuming a QPS of a few thousand requests/sec). // workload (assuming a QPS of a few thousand requests/sec).
const serverWorkerResetThreshold = 1 << 16 const serverWorkerResetThreshold = 1 << 16
// serverWorkers blocks on a *transport.Stream channel forever and waits for // serverWorker blocks on a *transport.Stream channel forever and waits for
// data to be fed by serveStreams. This allows multiple requests to be // data to be fed by serveStreams. This allows multiple requests to be
// processed by the same goroutine, removing the need for expensive stack // processed by the same goroutine, removing the need for expensive stack
// re-allocations (see the runtime.morestack problem [1]). // re-allocations (see the runtime.morestack problem [1]).
@ -980,6 +979,7 @@ func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport {
ChannelzParent: s.channelz, ChannelzParent: s.channelz,
MaxHeaderListSize: s.opts.maxHeaderListSize, MaxHeaderListSize: s.opts.maxHeaderListSize,
HeaderTableSize: s.opts.headerTableSize, HeaderTableSize: s.opts.headerTableSize,
BufferPool: s.opts.bufferPool,
} }
st, err := transport.NewServerTransport(c, config) st, err := transport.NewServerTransport(c, config)
if err != nil { if err != nil {
@ -1072,7 +1072,7 @@ var _ http.Handler = (*Server)(nil)
// Notice: This API is EXPERIMENTAL and may be changed or removed in a // Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release. // later release.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandlers) st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandlers, s.opts.bufferPool)
if err != nil { if err != nil {
// Errors returned from transport.NewServerHandlerTransport have // Errors returned from transport.NewServerHandlerTransport have
// already been written to w. // already been written to w.
@ -1142,20 +1142,35 @@ func (s *Server) sendResponse(ctx context.Context, t transport.ServerTransport,
channelz.Error(logger, s.channelz, "grpc: server failed to encode response: ", err) channelz.Error(logger, s.channelz, "grpc: server failed to encode response: ", err)
return err return err
} }
compData, err := compress(data, cp, comp)
compData, pf, err := compress(data, cp, comp, s.opts.bufferPool)
if err != nil { if err != nil {
data.Free()
channelz.Error(logger, s.channelz, "grpc: server failed to compress response: ", err) channelz.Error(logger, s.channelz, "grpc: server failed to compress response: ", err)
return err return err
} }
hdr, payload := msgHeader(data, compData)
hdr, payload := msgHeader(data, compData, pf)
defer func() {
compData.Free()
data.Free()
// payload does not need to be freed here, it is either data or compData, both of
// which are already freed.
}()
dataLen := data.Len()
payloadLen := payload.Len()
// TODO(dfawley): should we be checking len(data) instead? // TODO(dfawley): should we be checking len(data) instead?
if len(payload) > s.opts.maxSendMessageSize { if payloadLen > s.opts.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(payload), s.opts.maxSendMessageSize) return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", payloadLen, s.opts.maxSendMessageSize)
} }
err = t.Write(stream, hdr, payload, opts) err = t.Write(stream, hdr, payload, opts)
if err == nil { if err == nil {
for _, sh := range s.opts.statsHandlers { if len(s.opts.statsHandlers) != 0 {
sh.HandleRPC(ctx, outPayload(false, msg, data, payload, time.Now())) for _, sh := range s.opts.statsHandlers {
sh.HandleRPC(ctx, outPayload(false, msg, dataLen, payloadLen, time.Now()))
}
} }
} }
return err return err
@ -1334,9 +1349,10 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTranspor
var payInfo *payloadInfo var payInfo *payloadInfo
if len(shs) != 0 || len(binlogs) != 0 { if len(shs) != 0 || len(binlogs) != 0 {
payInfo = &payloadInfo{} payInfo = &payloadInfo{}
defer payInfo.free()
} }
d, cancel, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp) d, err := recvAndDecompress(&parser{r: stream, bufferPool: s.opts.bufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp, true)
if err != nil { if err != nil {
if e := t.WriteStatus(stream, status.Convert(err)); e != nil { if e := t.WriteStatus(stream, status.Convert(err)); e != nil {
channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e) channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e)
@ -1347,24 +1363,22 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTranspor
t.IncrMsgRecv() t.IncrMsgRecv()
} }
df := func(v any) error { df := func(v any) error {
defer cancel()
if err := s.getCodec(stream.ContentSubtype()).Unmarshal(d, v); err != nil { if err := s.getCodec(stream.ContentSubtype()).Unmarshal(d, v); err != nil {
return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err) return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
} }
for _, sh := range shs { for _, sh := range shs {
sh.HandleRPC(ctx, &stats.InPayload{ sh.HandleRPC(ctx, &stats.InPayload{
RecvTime: time.Now(), RecvTime: time.Now(),
Payload: v, Payload: v,
Length: len(d), Length: d.Len(),
WireLength: payInfo.compressedLength + headerLen, WireLength: payInfo.compressedLength + headerLen,
CompressedLength: payInfo.compressedLength, CompressedLength: payInfo.compressedLength,
Data: d,
}) })
} }
if len(binlogs) != 0 { if len(binlogs) != 0 {
cm := &binarylog.ClientMessage{ cm := &binarylog.ClientMessage{
Message: d, Message: d.Materialize(),
} }
for _, binlog := range binlogs { for _, binlog := range binlogs {
binlog.Log(ctx, cm) binlog.Log(ctx, cm)
@ -1548,7 +1562,7 @@ func (s *Server) processStreamingRPC(ctx context.Context, t transport.ServerTran
ctx: ctx, ctx: ctx,
t: t, t: t,
s: stream, s: stream,
p: &parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, p: &parser{r: stream, bufferPool: s.opts.bufferPool},
codec: s.getCodec(stream.ContentSubtype()), codec: s.getCodec(stream.ContentSubtype()),
maxReceiveMessageSize: s.opts.maxReceiveMessageSize, maxReceiveMessageSize: s.opts.maxReceiveMessageSize,
maxSendMessageSize: s.opts.maxSendMessageSize, maxSendMessageSize: s.opts.maxSendMessageSize,
@ -1963,12 +1977,12 @@ func (s *Server) getCodec(contentSubtype string) baseCodec {
return s.opts.codec return s.opts.codec
} }
if contentSubtype == "" { if contentSubtype == "" {
return encoding.GetCodec(proto.Name) return getCodec(proto.Name)
} }
codec := encoding.GetCodec(contentSubtype) codec := getCodec(contentSubtype)
if codec == nil { if codec == nil {
logger.Warningf("Unsupported codec %q. Defaulting to %q for now. This will start to fail in future releases.", contentSubtype, proto.Name) logger.Warningf("Unsupported codec %q. Defaulting to %q for now. This will start to fail in future releases.", contentSubtype, proto.Name)
return encoding.GetCodec(proto.Name) return getCodec(proto.Name)
} }
return codec return codec
} }

View File

@ -1,154 +0,0 @@
/*
*
* Copyright 2023 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import "sync"
// SharedBufferPool is a pool of buffers that can be shared, resulting in
// decreased memory allocation. Currently, in gRPC-go, it is only utilized
// for parsing incoming messages.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
type SharedBufferPool interface {
// Get returns a buffer with specified length from the pool.
//
// The returned byte slice may be not zero initialized.
Get(length int) []byte
// Put returns a buffer to the pool.
Put(*[]byte)
}
// NewSharedBufferPool creates a simple SharedBufferPool with buckets
// of different sizes to optimize memory usage. This prevents the pool from
// wasting large amounts of memory, even when handling messages of varying sizes.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func NewSharedBufferPool() SharedBufferPool {
return &simpleSharedBufferPool{
pools: [poolArraySize]simpleSharedBufferChildPool{
newBytesPool(level0PoolMaxSize),
newBytesPool(level1PoolMaxSize),
newBytesPool(level2PoolMaxSize),
newBytesPool(level3PoolMaxSize),
newBytesPool(level4PoolMaxSize),
newBytesPool(0),
},
}
}
// simpleSharedBufferPool is a simple implementation of SharedBufferPool.
type simpleSharedBufferPool struct {
pools [poolArraySize]simpleSharedBufferChildPool
}
func (p *simpleSharedBufferPool) Get(size int) []byte {
return p.pools[p.poolIdx(size)].Get(size)
}
func (p *simpleSharedBufferPool) Put(bs *[]byte) {
p.pools[p.poolIdx(cap(*bs))].Put(bs)
}
func (p *simpleSharedBufferPool) poolIdx(size int) int {
switch {
case size <= level0PoolMaxSize:
return level0PoolIdx
case size <= level1PoolMaxSize:
return level1PoolIdx
case size <= level2PoolMaxSize:
return level2PoolIdx
case size <= level3PoolMaxSize:
return level3PoolIdx
case size <= level4PoolMaxSize:
return level4PoolIdx
default:
return levelMaxPoolIdx
}
}
const (
level0PoolMaxSize = 16 // 16 B
level1PoolMaxSize = level0PoolMaxSize * 16 // 256 B
level2PoolMaxSize = level1PoolMaxSize * 16 // 4 KB
level3PoolMaxSize = level2PoolMaxSize * 16 // 64 KB
level4PoolMaxSize = level3PoolMaxSize * 16 // 1 MB
)
const (
level0PoolIdx = iota
level1PoolIdx
level2PoolIdx
level3PoolIdx
level4PoolIdx
levelMaxPoolIdx
poolArraySize
)
type simpleSharedBufferChildPool interface {
Get(size int) []byte
Put(any)
}
type bufferPool struct {
sync.Pool
defaultSize int
}
func (p *bufferPool) Get(size int) []byte {
bs := p.Pool.Get().(*[]byte)
if cap(*bs) < size {
p.Pool.Put(bs)
return make([]byte, size)
}
return (*bs)[:size]
}
func newBytesPool(size int) simpleSharedBufferChildPool {
return &bufferPool{
Pool: sync.Pool{
New: func() any {
bs := make([]byte, size)
return &bs
},
},
defaultSize: size,
}
}
// nopBufferPool is a buffer pool just makes new buffer without pooling.
type nopBufferPool struct {
}
func (nopBufferPool) Get(length int) []byte {
return make([]byte, length)
}
func (nopBufferPool) Put(*[]byte) {
}

View File

@ -77,9 +77,6 @@ type InPayload struct {
// the call to HandleRPC which provides the InPayload returns and must be // the call to HandleRPC which provides the InPayload returns and must be
// copied if needed later. // copied if needed later.
Payload any Payload any
// Data is the serialized message payload.
// Deprecated: Data will be removed in the next release.
Data []byte
// Length is the size of the uncompressed payload data. Does not include any // Length is the size of the uncompressed payload data. Does not include any
// framing (gRPC or HTTP/2). // framing (gRPC or HTTP/2).
@ -150,9 +147,6 @@ type OutPayload struct {
// the call to HandleRPC which provides the OutPayload returns and must be // the call to HandleRPC which provides the OutPayload returns and must be
// copied if needed later. // copied if needed later.
Payload any Payload any
// Data is the serialized message payload.
// Deprecated: Data will be removed in the next release.
Data []byte
// Length is the size of the uncompressed payload data. Does not include any // Length is the size of the uncompressed payload data. Does not include any
// framing (gRPC or HTTP/2). // framing (gRPC or HTTP/2).
Length int Length int

View File

@ -41,6 +41,7 @@ import (
"google.golang.org/grpc/internal/serviceconfig" "google.golang.org/grpc/internal/serviceconfig"
istatus "google.golang.org/grpc/internal/status" istatus "google.golang.org/grpc/internal/status"
"google.golang.org/grpc/internal/transport" "google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
@ -359,7 +360,7 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client
cs.attempt = a cs.attempt = a
return nil return nil
} }
if err := cs.withRetry(op, func() { cs.bufferForRetryLocked(0, op) }); err != nil { if err := cs.withRetry(op, func() { cs.bufferForRetryLocked(0, op, nil) }); err != nil {
return nil, err return nil, err
} }
@ -517,7 +518,7 @@ func (a *csAttempt) newStream() error {
} }
a.s = s a.s = s
a.ctx = s.Context() a.ctx = s.Context()
a.p = &parser{r: s, recvBufferPool: a.cs.cc.dopts.recvBufferPool} a.p = &parser{r: s, bufferPool: a.cs.cc.dopts.copts.BufferPool}
return nil return nil
} }
@ -566,10 +567,15 @@ type clientStream struct {
// place where we need to check if the attempt is nil. // place where we need to check if the attempt is nil.
attempt *csAttempt attempt *csAttempt
// TODO(hedging): hedging will have multiple attempts simultaneously. // TODO(hedging): hedging will have multiple attempts simultaneously.
committed bool // active attempt committed for retry? committed bool // active attempt committed for retry?
onCommit func() onCommit func()
buffer []func(a *csAttempt) error // operations to replay on retry replayBuffer []replayOp // operations to replay on retry
bufferSize int // current size of buffer replayBufferSize int // current size of replayBuffer
}
type replayOp struct {
op func(a *csAttempt) error
cleanup func()
} }
// csAttempt implements a single transport stream attempt within a // csAttempt implements a single transport stream attempt within a
@ -607,7 +613,12 @@ func (cs *clientStream) commitAttemptLocked() {
cs.onCommit() cs.onCommit()
} }
cs.committed = true cs.committed = true
cs.buffer = nil for _, op := range cs.replayBuffer {
if op.cleanup != nil {
op.cleanup()
}
}
cs.replayBuffer = nil
} }
func (cs *clientStream) commitAttempt() { func (cs *clientStream) commitAttempt() {
@ -732,7 +743,7 @@ func (cs *clientStream) retryLocked(attempt *csAttempt, lastErr error) error {
// the stream is canceled. // the stream is canceled.
return err return err
} }
// Note that the first op in the replay buffer always sets cs.attempt // Note that the first op in replayBuffer always sets cs.attempt
// if it is able to pick a transport and create a stream. // if it is able to pick a transport and create a stream.
if lastErr = cs.replayBufferLocked(attempt); lastErr == nil { if lastErr = cs.replayBufferLocked(attempt); lastErr == nil {
return nil return nil
@ -761,7 +772,7 @@ func (cs *clientStream) withRetry(op func(a *csAttempt) error, onSuccess func())
// already be status errors. // already be status errors.
return toRPCErr(op(cs.attempt)) return toRPCErr(op(cs.attempt))
} }
if len(cs.buffer) == 0 { if len(cs.replayBuffer) == 0 {
// For the first op, which controls creation of the stream and // For the first op, which controls creation of the stream and
// assigns cs.attempt, we need to create a new attempt inline // assigns cs.attempt, we need to create a new attempt inline
// before executing the first op. On subsequent ops, the attempt // before executing the first op. On subsequent ops, the attempt
@ -851,25 +862,26 @@ func (cs *clientStream) Trailer() metadata.MD {
} }
func (cs *clientStream) replayBufferLocked(attempt *csAttempt) error { func (cs *clientStream) replayBufferLocked(attempt *csAttempt) error {
for _, f := range cs.buffer { for _, f := range cs.replayBuffer {
if err := f(attempt); err != nil { if err := f.op(attempt); err != nil {
return err return err
} }
} }
return nil return nil
} }
func (cs *clientStream) bufferForRetryLocked(sz int, op func(a *csAttempt) error) { func (cs *clientStream) bufferForRetryLocked(sz int, op func(a *csAttempt) error, cleanup func()) {
// Note: we still will buffer if retry is disabled (for transparent retries). // Note: we still will buffer if retry is disabled (for transparent retries).
if cs.committed { if cs.committed {
return return
} }
cs.bufferSize += sz cs.replayBufferSize += sz
if cs.bufferSize > cs.callInfo.maxRetryRPCBufferSize { if cs.replayBufferSize > cs.callInfo.maxRetryRPCBufferSize {
cs.commitAttemptLocked() cs.commitAttemptLocked()
cleanup()
return return
} }
cs.buffer = append(cs.buffer, op) cs.replayBuffer = append(cs.replayBuffer, replayOp{op: op, cleanup: cleanup})
} }
func (cs *clientStream) SendMsg(m any) (err error) { func (cs *clientStream) SendMsg(m any) (err error) {
@ -891,23 +903,50 @@ func (cs *clientStream) SendMsg(m any) (err error) {
} }
// load hdr, payload, data // load hdr, payload, data
hdr, payload, data, err := prepareMsg(m, cs.codec, cs.cp, cs.comp) hdr, data, payload, pf, err := prepareMsg(m, cs.codec, cs.cp, cs.comp, cs.cc.dopts.copts.BufferPool)
if err != nil { if err != nil {
return err return err
} }
defer func() {
data.Free()
// only free payload if compression was made, and therefore it is a different set
// of buffers from data.
if pf.isCompressed() {
payload.Free()
}
}()
dataLen := data.Len()
payloadLen := payload.Len()
// TODO(dfawley): should we be checking len(data) instead? // TODO(dfawley): should we be checking len(data) instead?
if len(payload) > *cs.callInfo.maxSendMessageSize { if payloadLen > *cs.callInfo.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payload), *cs.callInfo.maxSendMessageSize) return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", payloadLen, *cs.callInfo.maxSendMessageSize)
} }
// always take an extra ref in case data == payload (i.e. when the data isn't
// compressed). The original ref will always be freed by the deferred free above.
payload.Ref()
op := func(a *csAttempt) error { op := func(a *csAttempt) error {
return a.sendMsg(m, hdr, payload, data) return a.sendMsg(m, hdr, payload, dataLen, payloadLen)
}
// onSuccess is invoked when the op is captured for a subsequent retry. If the
// stream was established by a previous message and therefore retries are
// disabled, onSuccess will not be invoked, and payloadRef can be freed
// immediately.
onSuccessCalled := false
err = cs.withRetry(op, func() {
cs.bufferForRetryLocked(len(hdr)+payloadLen, op, payload.Free)
onSuccessCalled = true
})
if !onSuccessCalled {
payload.Free()
} }
err = cs.withRetry(op, func() { cs.bufferForRetryLocked(len(hdr)+len(payload), op) })
if len(cs.binlogs) != 0 && err == nil { if len(cs.binlogs) != 0 && err == nil {
cm := &binarylog.ClientMessage{ cm := &binarylog.ClientMessage{
OnClientSide: true, OnClientSide: true,
Message: data, Message: data.Materialize(),
} }
for _, binlog := range cs.binlogs { for _, binlog := range cs.binlogs {
binlog.Log(cs.ctx, cm) binlog.Log(cs.ctx, cm)
@ -924,6 +963,7 @@ func (cs *clientStream) RecvMsg(m any) error {
var recvInfo *payloadInfo var recvInfo *payloadInfo
if len(cs.binlogs) != 0 { if len(cs.binlogs) != 0 {
recvInfo = &payloadInfo{} recvInfo = &payloadInfo{}
defer recvInfo.free()
} }
err := cs.withRetry(func(a *csAttempt) error { err := cs.withRetry(func(a *csAttempt) error {
return a.recvMsg(m, recvInfo) return a.recvMsg(m, recvInfo)
@ -931,7 +971,7 @@ func (cs *clientStream) RecvMsg(m any) error {
if len(cs.binlogs) != 0 && err == nil { if len(cs.binlogs) != 0 && err == nil {
sm := &binarylog.ServerMessage{ sm := &binarylog.ServerMessage{
OnClientSide: true, OnClientSide: true,
Message: recvInfo.uncompressedBytes, Message: recvInfo.uncompressedBytes.Materialize(),
} }
for _, binlog := range cs.binlogs { for _, binlog := range cs.binlogs {
binlog.Log(cs.ctx, sm) binlog.Log(cs.ctx, sm)
@ -958,7 +998,7 @@ func (cs *clientStream) CloseSend() error {
// RecvMsg. This also matches historical behavior. // RecvMsg. This also matches historical behavior.
return nil return nil
} }
cs.withRetry(op, func() { cs.bufferForRetryLocked(0, op) }) cs.withRetry(op, func() { cs.bufferForRetryLocked(0, op, nil) })
if len(cs.binlogs) != 0 { if len(cs.binlogs) != 0 {
chc := &binarylog.ClientHalfClose{ chc := &binarylog.ClientHalfClose{
OnClientSide: true, OnClientSide: true,
@ -1034,7 +1074,7 @@ func (cs *clientStream) finish(err error) {
cs.cancel() cs.cancel()
} }
func (a *csAttempt) sendMsg(m any, hdr, payld, data []byte) error { func (a *csAttempt) sendMsg(m any, hdr []byte, payld mem.BufferSlice, dataLength, payloadLength int) error {
cs := a.cs cs := a.cs
if a.trInfo != nil { if a.trInfo != nil {
a.mu.Lock() a.mu.Lock()
@ -1052,8 +1092,10 @@ func (a *csAttempt) sendMsg(m any, hdr, payld, data []byte) error {
} }
return io.EOF return io.EOF
} }
for _, sh := range a.statsHandlers { if len(a.statsHandlers) != 0 {
sh.HandleRPC(a.ctx, outPayload(true, m, data, payld, time.Now())) for _, sh := range a.statsHandlers {
sh.HandleRPC(a.ctx, outPayload(true, m, dataLength, payloadLength, time.Now()))
}
} }
if channelz.IsOn() { if channelz.IsOn() {
a.t.IncrMsgSent() a.t.IncrMsgSent()
@ -1065,6 +1107,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
cs := a.cs cs := a.cs
if len(a.statsHandlers) != 0 && payInfo == nil { if len(a.statsHandlers) != 0 && payInfo == nil {
payInfo = &payloadInfo{} payInfo = &payloadInfo{}
defer payInfo.free()
} }
if !a.decompSet { if !a.decompSet {
@ -1083,8 +1126,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
// Only initialize this state once per stream. // Only initialize this state once per stream.
a.decompSet = true a.decompSet = true
} }
err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decomp) if err := recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decomp, false); err != nil {
if err != nil {
if err == io.EOF { if err == io.EOF {
if statusErr := a.s.Status().Err(); statusErr != nil { if statusErr := a.s.Status().Err(); statusErr != nil {
return statusErr return statusErr
@ -1103,14 +1145,12 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
} }
for _, sh := range a.statsHandlers { for _, sh := range a.statsHandlers {
sh.HandleRPC(a.ctx, &stats.InPayload{ sh.HandleRPC(a.ctx, &stats.InPayload{
Client: true, Client: true,
RecvTime: time.Now(), RecvTime: time.Now(),
Payload: m, Payload: m,
// TODO truncate large payload.
Data: payInfo.uncompressedBytes,
WireLength: payInfo.compressedLength + headerLen, WireLength: payInfo.compressedLength + headerLen,
CompressedLength: payInfo.compressedLength, CompressedLength: payInfo.compressedLength,
Length: len(payInfo.uncompressedBytes), Length: payInfo.uncompressedBytes.Len(),
}) })
} }
if channelz.IsOn() { if channelz.IsOn() {
@ -1122,14 +1162,12 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
} }
// Special handling for non-server-stream rpcs. // Special handling for non-server-stream rpcs.
// This recv expects EOF or errors, so we don't collect inPayload. // This recv expects EOF or errors, so we don't collect inPayload.
err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decomp) if err := recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decomp, false); err == io.EOF {
if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
}
if err == io.EOF {
return a.s.Status().Err() // non-server streaming Recv returns nil on success return a.s.Status().Err() // non-server streaming Recv returns nil on success
} else if err != nil {
return toRPCErr(err)
} }
return toRPCErr(err) return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
} }
func (a *csAttempt) finish(err error) { func (a *csAttempt) finish(err error) {
@ -1185,12 +1223,12 @@ func (a *csAttempt) finish(err error) {
a.mu.Unlock() a.mu.Unlock()
} }
// newClientStream creates a ClientStream with the specified transport, on the // newNonRetryClientStream creates a ClientStream with the specified transport, on the
// given addrConn. // given addrConn.
// //
// It's expected that the given transport is either the same one in addrConn, or // It's expected that the given transport is either the same one in addrConn, or
// is already closed. To avoid race, transport is specified separately, instead // is already closed. To avoid race, transport is specified separately, instead
// of using ac.transpot. // of using ac.transport.
// //
// Main difference between this and ClientConn.NewStream: // Main difference between this and ClientConn.NewStream:
// - no retry // - no retry
@ -1276,7 +1314,7 @@ func newNonRetryClientStream(ctx context.Context, desc *StreamDesc, method strin
return nil, err return nil, err
} }
as.s = s as.s = s
as.p = &parser{r: s, recvBufferPool: ac.dopts.recvBufferPool} as.p = &parser{r: s, bufferPool: ac.dopts.copts.BufferPool}
ac.incrCallsStarted() ac.incrCallsStarted()
if desc != unaryStreamDesc { if desc != unaryStreamDesc {
// Listen on stream context to cleanup when the stream context is // Listen on stream context to cleanup when the stream context is
@ -1373,17 +1411,26 @@ func (as *addrConnStream) SendMsg(m any) (err error) {
} }
// load hdr, payload, data // load hdr, payload, data
hdr, payld, _, err := prepareMsg(m, as.codec, as.cp, as.comp) hdr, data, payload, pf, err := prepareMsg(m, as.codec, as.cp, as.comp, as.ac.dopts.copts.BufferPool)
if err != nil { if err != nil {
return err return err
} }
defer func() {
data.Free()
// only free payload if compression was made, and therefore it is a different set
// of buffers from data.
if pf.isCompressed() {
payload.Free()
}
}()
// TODO(dfawley): should we be checking len(data) instead? // TODO(dfawley): should we be checking len(data) instead?
if len(payld) > *as.callInfo.maxSendMessageSize { if payload.Len() > *as.callInfo.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payld), *as.callInfo.maxSendMessageSize) return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", payload.Len(), *as.callInfo.maxSendMessageSize)
} }
if err := as.t.Write(as.s, hdr, payld, &transport.Options{Last: !as.desc.ClientStreams}); err != nil { if err := as.t.Write(as.s, hdr, payload, &transport.Options{Last: !as.desc.ClientStreams}); err != nil {
if !as.desc.ClientStreams { if !as.desc.ClientStreams {
// For non-client-streaming RPCs, we return nil instead of EOF on error // For non-client-streaming RPCs, we return nil instead of EOF on error
// because the generated code requires it. finish is not called; RecvMsg() // because the generated code requires it. finish is not called; RecvMsg()
@ -1423,8 +1470,7 @@ func (as *addrConnStream) RecvMsg(m any) (err error) {
// Only initialize this state once per stream. // Only initialize this state once per stream.
as.decompSet = true as.decompSet = true
} }
err = recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp) if err := recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp, false); err != nil {
if err != nil {
if err == io.EOF { if err == io.EOF {
if statusErr := as.s.Status().Err(); statusErr != nil { if statusErr := as.s.Status().Err(); statusErr != nil {
return statusErr return statusErr
@ -1444,14 +1490,12 @@ func (as *addrConnStream) RecvMsg(m any) (err error) {
// Special handling for non-server-stream rpcs. // Special handling for non-server-stream rpcs.
// This recv expects EOF or errors, so we don't collect inPayload. // This recv expects EOF or errors, so we don't collect inPayload.
err = recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp) if err := recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp, false); err == io.EOF {
if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
}
if err == io.EOF {
return as.s.Status().Err() // non-server streaming Recv returns nil on success return as.s.Status().Err() // non-server streaming Recv returns nil on success
} else if err != nil {
return toRPCErr(err)
} }
return toRPCErr(err) return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
} }
func (as *addrConnStream) finish(err error) { func (as *addrConnStream) finish(err error) {
@ -1645,18 +1689,31 @@ func (ss *serverStream) SendMsg(m any) (err error) {
} }
// load hdr, payload, data // load hdr, payload, data
hdr, payload, data, err := prepareMsg(m, ss.codec, ss.cp, ss.comp) hdr, data, payload, pf, err := prepareMsg(m, ss.codec, ss.cp, ss.comp, ss.p.bufferPool)
if err != nil { if err != nil {
return err return err
} }
defer func() {
data.Free()
// only free payload if compression was made, and therefore it is a different set
// of buffers from data.
if pf.isCompressed() {
payload.Free()
}
}()
dataLen := data.Len()
payloadLen := payload.Len()
// TODO(dfawley): should we be checking len(data) instead? // TODO(dfawley): should we be checking len(data) instead?
if len(payload) > ss.maxSendMessageSize { if payloadLen > ss.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payload), ss.maxSendMessageSize) return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", payloadLen, ss.maxSendMessageSize)
} }
if err := ss.t.Write(ss.s, hdr, payload, &transport.Options{Last: false}); err != nil { if err := ss.t.Write(ss.s, hdr, payload, &transport.Options{Last: false}); err != nil {
return toRPCErr(err) return toRPCErr(err)
} }
if len(ss.binlogs) != 0 { if len(ss.binlogs) != 0 {
if !ss.serverHeaderBinlogged { if !ss.serverHeaderBinlogged {
h, _ := ss.s.Header() h, _ := ss.s.Header()
@ -1669,7 +1726,7 @@ func (ss *serverStream) SendMsg(m any) (err error) {
} }
} }
sm := &binarylog.ServerMessage{ sm := &binarylog.ServerMessage{
Message: data, Message: data.Materialize(),
} }
for _, binlog := range ss.binlogs { for _, binlog := range ss.binlogs {
binlog.Log(ss.ctx, sm) binlog.Log(ss.ctx, sm)
@ -1677,7 +1734,7 @@ func (ss *serverStream) SendMsg(m any) (err error) {
} }
if len(ss.statsHandler) != 0 { if len(ss.statsHandler) != 0 {
for _, sh := range ss.statsHandler { for _, sh := range ss.statsHandler {
sh.HandleRPC(ss.s.Context(), outPayload(false, m, data, payload, time.Now())) sh.HandleRPC(ss.s.Context(), outPayload(false, m, dataLen, payloadLen, time.Now()))
} }
} }
return nil return nil
@ -1714,8 +1771,9 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
var payInfo *payloadInfo var payInfo *payloadInfo
if len(ss.statsHandler) != 0 || len(ss.binlogs) != 0 { if len(ss.statsHandler) != 0 || len(ss.binlogs) != 0 {
payInfo = &payloadInfo{} payInfo = &payloadInfo{}
defer payInfo.free()
} }
if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, payInfo, ss.decomp); err != nil { if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, payInfo, ss.decomp, true); err != nil {
if err == io.EOF { if err == io.EOF {
if len(ss.binlogs) != 0 { if len(ss.binlogs) != 0 {
chc := &binarylog.ClientHalfClose{} chc := &binarylog.ClientHalfClose{}
@ -1733,11 +1791,9 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
if len(ss.statsHandler) != 0 { if len(ss.statsHandler) != 0 {
for _, sh := range ss.statsHandler { for _, sh := range ss.statsHandler {
sh.HandleRPC(ss.s.Context(), &stats.InPayload{ sh.HandleRPC(ss.s.Context(), &stats.InPayload{
RecvTime: time.Now(), RecvTime: time.Now(),
Payload: m, Payload: m,
// TODO truncate large payload. Length: payInfo.uncompressedBytes.Len(),
Data: payInfo.uncompressedBytes,
Length: len(payInfo.uncompressedBytes),
WireLength: payInfo.compressedLength + headerLen, WireLength: payInfo.compressedLength + headerLen,
CompressedLength: payInfo.compressedLength, CompressedLength: payInfo.compressedLength,
}) })
@ -1745,7 +1801,7 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
} }
if len(ss.binlogs) != 0 { if len(ss.binlogs) != 0 {
cm := &binarylog.ClientMessage{ cm := &binarylog.ClientMessage{
Message: payInfo.uncompressedBytes, Message: payInfo.uncompressedBytes.Materialize(),
} }
for _, binlog := range ss.binlogs { for _, binlog := range ss.binlogs {
binlog.Log(ss.ctx, cm) binlog.Log(ss.ctx, cm)
@ -1760,23 +1816,26 @@ func MethodFromServerStream(stream ServerStream) (string, bool) {
return Method(stream.Context()) return Method(stream.Context())
} }
// prepareMsg returns the hdr, payload and data // prepareMsg returns the hdr, payload and data using the compressors passed or
// using the compressors passed or using the // using the passed preparedmsg. The returned boolean indicates whether
// passed preparedmsg // compression was made and therefore whether the payload needs to be freed in
func prepareMsg(m any, codec baseCodec, cp Compressor, comp encoding.Compressor) (hdr, payload, data []byte, err error) { // addition to the returned data. Freeing the payload if the returned boolean is
// false can lead to undefined behavior.
func prepareMsg(m any, codec baseCodec, cp Compressor, comp encoding.Compressor, pool mem.BufferPool) (hdr []byte, data, payload mem.BufferSlice, pf payloadFormat, err error) {
if preparedMsg, ok := m.(*PreparedMsg); ok { if preparedMsg, ok := m.(*PreparedMsg); ok {
return preparedMsg.hdr, preparedMsg.payload, preparedMsg.encodedData, nil return preparedMsg.hdr, preparedMsg.encodedData, preparedMsg.payload, preparedMsg.pf, nil
} }
// The input interface is not a prepared msg. // The input interface is not a prepared msg.
// Marshal and Compress the data at this point // Marshal and Compress the data at this point
data, err = encode(codec, m) data, err = encode(codec, m)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, 0, err
} }
compData, err := compress(data, cp, comp) compData, pf, err := compress(data, cp, comp, pool)
if err != nil { if err != nil {
return nil, nil, nil, err data.Free()
return nil, nil, nil, 0, err
} }
hdr, payload = msgHeader(data, compData) hdr, payload = msgHeader(data, compData, pf)
return hdr, payload, data, nil return hdr, data, payload, pf, nil
} }

View File

@ -19,4 +19,4 @@
package grpc package grpc
// Version is the current grpc version. // Version is the current grpc version.
const Version = "1.65.0" const Version = "1.66.0"

8
vendor/modules.txt vendored
View File

@ -826,7 +826,7 @@ golang.org/x/tools/go/ast/inspector
# gomodules.xyz/jsonpatch/v2 v2.4.0 => github.com/gomodules/jsonpatch/v2 v2.2.0 # gomodules.xyz/jsonpatch/v2 v2.4.0 => github.com/gomodules/jsonpatch/v2 v2.2.0
## explicit; go 1.12 ## explicit; go 1.12
gomodules.xyz/jsonpatch/v2 gomodules.xyz/jsonpatch/v2
# google.golang.org/genproto/googleapis/api v0.0.0-20240528184218-531527333157 # google.golang.org/genproto/googleapis/api v0.0.0-20240604185151-ef581f913117
## explicit; go 1.20 ## explicit; go 1.20
google.golang.org/genproto/googleapis/api google.golang.org/genproto/googleapis/api
google.golang.org/genproto/googleapis/api/annotations google.golang.org/genproto/googleapis/api/annotations
@ -836,7 +836,7 @@ google.golang.org/genproto/googleapis/api/httpbody
## explicit; go 1.20 ## explicit; go 1.20
google.golang.org/genproto/googleapis/rpc/errdetails google.golang.org/genproto/googleapis/rpc/errdetails
google.golang.org/genproto/googleapis/rpc/status google.golang.org/genproto/googleapis/rpc/status
# google.golang.org/grpc v1.65.0 # google.golang.org/grpc v1.66.0
## explicit; go 1.21 ## explicit; go 1.21
google.golang.org/grpc google.golang.org/grpc
google.golang.org/grpc/attributes google.golang.org/grpc/attributes
@ -855,7 +855,9 @@ google.golang.org/grpc/credentials/insecure
google.golang.org/grpc/encoding google.golang.org/grpc/encoding
google.golang.org/grpc/encoding/gzip google.golang.org/grpc/encoding/gzip
google.golang.org/grpc/encoding/proto google.golang.org/grpc/encoding/proto
google.golang.org/grpc/experimental/stats
google.golang.org/grpc/grpclog google.golang.org/grpc/grpclog
google.golang.org/grpc/grpclog/internal
google.golang.org/grpc/health/grpc_health_v1 google.golang.org/grpc/health/grpc_health_v1
google.golang.org/grpc/internal google.golang.org/grpc/internal
google.golang.org/grpc/internal/backoff google.golang.org/grpc/internal/backoff
@ -878,11 +880,13 @@ google.golang.org/grpc/internal/resolver/dns/internal
google.golang.org/grpc/internal/resolver/passthrough google.golang.org/grpc/internal/resolver/passthrough
google.golang.org/grpc/internal/resolver/unix google.golang.org/grpc/internal/resolver/unix
google.golang.org/grpc/internal/serviceconfig google.golang.org/grpc/internal/serviceconfig
google.golang.org/grpc/internal/stats
google.golang.org/grpc/internal/status google.golang.org/grpc/internal/status
google.golang.org/grpc/internal/syscall google.golang.org/grpc/internal/syscall
google.golang.org/grpc/internal/transport google.golang.org/grpc/internal/transport
google.golang.org/grpc/internal/transport/networktype google.golang.org/grpc/internal/transport/networktype
google.golang.org/grpc/keepalive google.golang.org/grpc/keepalive
google.golang.org/grpc/mem
google.golang.org/grpc/metadata google.golang.org/grpc/metadata
google.golang.org/grpc/peer google.golang.org/grpc/peer
google.golang.org/grpc/resolver google.golang.org/grpc/resolver