rebase: bump google.golang.org/grpc from 1.59.0 to 1.60.1

Bumps [google.golang.org/grpc](https://github.com/grpc/grpc-go) from 1.59.0 to 1.60.1.
- [Release notes](https://github.com/grpc/grpc-go/releases)
- [Commits](https://github.com/grpc/grpc-go/compare/v1.59.0...v1.60.1)

---
updated-dependencies:
- dependency-name: google.golang.org/grpc
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
This commit is contained in:
dependabot[bot] 2024-01-04 07:36:42 +00:00 committed by mergify[bot]
parent c807059618
commit 0ec64b7552
52 changed files with 1970 additions and 1670 deletions

6
go.mod
View File

@ -30,7 +30,7 @@ require (
golang.org/x/crypto v0.17.0 golang.org/x/crypto v0.17.0
golang.org/x/net v0.19.0 golang.org/x/net v0.19.0
golang.org/x/sys v0.15.0 golang.org/x/sys v0.15.0
google.golang.org/grpc v1.59.0 google.golang.org/grpc v1.60.1
google.golang.org/protobuf v1.32.0 google.golang.org/protobuf v1.32.0
// //
// when updating k8s.io/kubernetes, make sure to update the replace section too // when updating k8s.io/kubernetes, make sure to update the replace section too
@ -149,14 +149,14 @@ require (
go.uber.org/multierr v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.25.0 // indirect go.uber.org/zap v1.25.0 // indirect
golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 // indirect golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 // indirect
golang.org/x/oauth2 v0.12.0 // indirect golang.org/x/oauth2 v0.13.0 // indirect
golang.org/x/sync v0.4.0 // indirect golang.org/x/sync v0.4.0 // indirect
golang.org/x/term v0.15.0 // indirect golang.org/x/term v0.15.0 // indirect
golang.org/x/text v0.14.0 // indirect golang.org/x/text v0.14.0 // indirect
golang.org/x/time v0.3.0 // indirect golang.org/x/time v0.3.0 // indirect
golang.org/x/tools v0.14.0 // indirect golang.org/x/tools v0.14.0 // indirect
gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect
google.golang.org/appengine v1.6.7 // indirect google.golang.org/appengine v1.6.8 // indirect
google.golang.org/genproto v0.0.0-20231030173426-d783a09b4405 // indirect google.golang.org/genproto v0.0.0-20231030173426-d783a09b4405 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20231016165738-49dd2c1f3d0b // indirect google.golang.org/genproto/googleapis/api v0.0.0-20231016165738-49dd2c1f3d0b // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20231106174013-bbf56f31fb17 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20231106174013-bbf56f31fb17 // indirect

11
go.sum
View File

@ -1913,8 +1913,8 @@ golang.org/x/oauth2 v0.6.0/go.mod h1:ycmewcwgD4Rpr3eZJLSB4Kyyljb3qDh40vJ8STE5HKw
golang.org/x/oauth2 v0.7.0/go.mod h1:hPLQkd9LyjfXTiRohC/41GhcFqxisoUQ99sCUOHO9x4= golang.org/x/oauth2 v0.7.0/go.mod h1:hPLQkd9LyjfXTiRohC/41GhcFqxisoUQ99sCUOHO9x4=
golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE= golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE=
golang.org/x/oauth2 v0.10.0/go.mod h1:kTpgurOux7LqtuxjuyZa4Gj2gdezIt/jQtGnNFfypQI= golang.org/x/oauth2 v0.10.0/go.mod h1:kTpgurOux7LqtuxjuyZa4Gj2gdezIt/jQtGnNFfypQI=
golang.org/x/oauth2 v0.12.0 h1:smVPGxink+n1ZI5pkQa8y6fZT0RW0MgCO5bFpepy4B4= golang.org/x/oauth2 v0.13.0 h1:jDDenyj+WgFtmV3zYVoi8aE2BwtXFLWOA67ZfNWftiY=
golang.org/x/oauth2 v0.12.0/go.mod h1:A74bZ3aGXgCY0qaIC9Ahg6Lglin4AMAco8cIv9baba4= golang.org/x/oauth2 v0.13.0/go.mod h1:/JMhi4ZRXAf4HG9LiNmxvk+45+96RUlVThiH8FzNBn0=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -2266,8 +2266,9 @@ google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7
google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0=
google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c=
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM=
google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
@ -2485,8 +2486,8 @@ google.golang.org/grpc v1.56.2/go.mod h1:I9bI3vqKfayGqPUAwGdOSu7kt6oIJLixfffKrpX
google.golang.org/grpc v1.57.0/go.mod h1:Sd+9RMTACXwmub0zcNY2c4arhtrbBYD1AUHI/dt16Mo= google.golang.org/grpc v1.57.0/go.mod h1:Sd+9RMTACXwmub0zcNY2c4arhtrbBYD1AUHI/dt16Mo=
google.golang.org/grpc v1.58.2/go.mod h1:tgX3ZQDlNJGU96V6yHh1T/JeoBQ2TXdr43YbYSsCJk0= google.golang.org/grpc v1.58.2/go.mod h1:tgX3ZQDlNJGU96V6yHh1T/JeoBQ2TXdr43YbYSsCJk0=
google.golang.org/grpc v1.58.3/go.mod h1:tgX3ZQDlNJGU96V6yHh1T/JeoBQ2TXdr43YbYSsCJk0= google.golang.org/grpc v1.58.3/go.mod h1:tgX3ZQDlNJGU96V6yHh1T/JeoBQ2TXdr43YbYSsCJk0=
google.golang.org/grpc v1.59.0 h1:Z5Iec2pjwb+LEOqzpB2MR12/eKFhDPhuqW91O+4bwUk= google.golang.org/grpc v1.60.1 h1:26+wFr+cNqSGFcOXcabYC0lUVJVRa2Sb2ortSK7VrEU=
google.golang.org/grpc v1.59.0/go.mod h1:aUPDwccQo6OTjy7Hct4AfBPD1GptF4fyUjIkQ9YtF98= google.golang.org/grpc v1.60.1/go.mod h1:OlCHIeLYqSSsLi6i49B5QGdzaMZK9+M7LXN2FKz4eGM=
google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw= google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=

198
vendor/golang.org/x/oauth2/deviceauth.go generated vendored Normal file
View File

@ -0,0 +1,198 @@
package oauth2
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"golang.org/x/oauth2/internal"
)
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
const (
errAuthorizationPending = "authorization_pending"
errSlowDown = "slow_down"
errAccessDenied = "access_denied"
errExpiredToken = "expired_token"
)
// DeviceAuthResponse describes a successful RFC 8628 Device Authorization Response
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
type DeviceAuthResponse struct {
// DeviceCode
DeviceCode string `json:"device_code"`
// UserCode is the code the user should enter at the verification uri
UserCode string `json:"user_code"`
// VerificationURI is where user should enter the user code
VerificationURI string `json:"verification_uri"`
// VerificationURIComplete (if populated) includes the user code in the verification URI. This is typically shown to the user in non-textual form, such as a QR code.
VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
// Expiry is when the device code and user code expire
Expiry time.Time `json:"expires_in,omitempty"`
// Interval is the duration in seconds that Poll should wait between requests
Interval int64 `json:"interval,omitempty"`
}
func (d DeviceAuthResponse) MarshalJSON() ([]byte, error) {
type Alias DeviceAuthResponse
var expiresIn int64
if !d.Expiry.IsZero() {
expiresIn = int64(time.Until(d.Expiry).Seconds())
}
return json.Marshal(&struct {
ExpiresIn int64 `json:"expires_in,omitempty"`
*Alias
}{
ExpiresIn: expiresIn,
Alias: (*Alias)(&d),
})
}
func (c *DeviceAuthResponse) UnmarshalJSON(data []byte) error {
type Alias DeviceAuthResponse
aux := &struct {
ExpiresIn int64 `json:"expires_in"`
// workaround misspelling of verification_uri
VerificationURL string `json:"verification_url"`
*Alias
}{
Alias: (*Alias)(c),
}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}
if aux.ExpiresIn != 0 {
c.Expiry = time.Now().UTC().Add(time.Second * time.Duration(aux.ExpiresIn))
}
if c.VerificationURI == "" {
c.VerificationURI = aux.VerificationURL
}
return nil
}
// DeviceAuth returns a device auth struct which contains a device code
// and authorization information provided for users to enter on another device.
func (c *Config) DeviceAuth(ctx context.Context, opts ...AuthCodeOption) (*DeviceAuthResponse, error) {
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.1
v := url.Values{
"client_id": {c.ClientID},
}
if len(c.Scopes) > 0 {
v.Set("scope", strings.Join(c.Scopes, " "))
}
for _, opt := range opts {
opt.setValue(v)
}
return retrieveDeviceAuth(ctx, c, v)
}
func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuthResponse, error) {
if c.Endpoint.DeviceAuthURL == "" {
return nil, errors.New("endpoint missing DeviceAuthURL")
}
req, err := http.NewRequest("POST", c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
t := time.Now()
r, err := internal.ContextClient(ctx).Do(req)
if err != nil {
return nil, err
}
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("oauth2: cannot auth device: %v", err)
}
if code := r.StatusCode; code < 200 || code > 299 {
return nil, &RetrieveError{
Response: r,
Body: body,
}
}
da := &DeviceAuthResponse{}
err = json.Unmarshal(body, &da)
if err != nil {
return nil, fmt.Errorf("unmarshal %s", err)
}
if !da.Expiry.IsZero() {
// Make a small adjustment to account for time taken by the request
da.Expiry = da.Expiry.Add(-time.Since(t))
}
return da, nil
}
// DeviceAccessToken polls the server to exchange a device code for a token.
func (c *Config) DeviceAccessToken(ctx context.Context, da *DeviceAuthResponse, opts ...AuthCodeOption) (*Token, error) {
if !da.Expiry.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(ctx, da.Expiry)
defer cancel()
}
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
v := url.Values{
"client_id": {c.ClientID},
"grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
"device_code": {da.DeviceCode},
}
if len(c.Scopes) > 0 {
v.Set("scope", strings.Join(c.Scopes, " "))
}
for _, opt := range opts {
opt.setValue(v)
}
// "If no value is provided, clients MUST use 5 as the default."
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
interval := da.Interval
if interval == 0 {
interval = 5
}
ticker := time.NewTicker(time.Duration(interval) * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-ticker.C:
tok, err := retrieveToken(ctx, c, v)
if err == nil {
return tok, nil
}
e, ok := err.(*RetrieveError)
if !ok {
return nil, err
}
switch e.ErrorCode {
case errSlowDown:
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
// "the interval MUST be increased by 5 seconds for this and all subsequent requests"
interval += 5
ticker.Reset(time.Duration(interval) * time.Second)
case errAuthorizationPending:
// Do nothing.
case errAccessDenied, errExpiredToken:
fallthrough
default:
return tok, err
}
}
}
}

29
vendor/golang.org/x/oauth2/oauth2.go generated vendored
View File

@ -75,8 +75,9 @@ type TokenSource interface {
// Endpoint represents an OAuth 2.0 provider's authorization and token // Endpoint represents an OAuth 2.0 provider's authorization and token
// endpoint URLs. // endpoint URLs.
type Endpoint struct { type Endpoint struct {
AuthURL string AuthURL string
TokenURL string DeviceAuthURL string
TokenURL string
// AuthStyle optionally specifies how the endpoint wants the // AuthStyle optionally specifies how the endpoint wants the
// client ID & client secret sent. The zero value means to // client ID & client secret sent. The zero value means to
@ -143,15 +144,19 @@ func SetAuthURLParam(key, value string) AuthCodeOption {
// AuthCodeURL returns a URL to OAuth 2.0 provider's consent page // AuthCodeURL returns a URL to OAuth 2.0 provider's consent page
// that asks for permissions for the required scopes explicitly. // that asks for permissions for the required scopes explicitly.
// //
// State is a token to protect the user from CSRF attacks. You must // State is an opaque value used by the client to maintain state between the
// always provide a non-empty string and validate that it matches the // request and callback. The authorization server includes this value when
// state query parameter on your redirect callback. // redirecting the user agent back to the client.
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
// //
// Opts may include AccessTypeOnline or AccessTypeOffline, as well // Opts may include AccessTypeOnline or AccessTypeOffline, as well
// as ApprovalForce. // as ApprovalForce.
// It can also be used to pass the PKCE challenge. //
// See https://www.oauth.com/oauth2-servers/pkce/ for more info. // To protect against CSRF attacks, opts should include a PKCE challenge
// (S256ChallengeOption). Not all servers support PKCE. An alternative is to
// generate a random state parameter and verify it after exchange.
// See https://datatracker.ietf.org/doc/html/rfc6749#section-10.12 (predating
// PKCE), https://www.oauth.com/oauth2-servers/pkce/ and
// https://www.ietf.org/archive/id/draft-ietf-oauth-v2-1-09.html#name-cross-site-request-forgery (describing both approaches)
func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string { func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
var buf bytes.Buffer var buf bytes.Buffer
buf.WriteString(c.Endpoint.AuthURL) buf.WriteString(c.Endpoint.AuthURL)
@ -166,7 +171,6 @@ func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
v.Set("scope", strings.Join(c.Scopes, " ")) v.Set("scope", strings.Join(c.Scopes, " "))
} }
if state != "" { if state != "" {
// TODO(light): Docs say never to omit state; don't allow empty.
v.Set("state", state) v.Set("state", state)
} }
for _, opt := range opts { for _, opt := range opts {
@ -211,10 +215,11 @@ func (c *Config) PasswordCredentialsToken(ctx context.Context, username, passwor
// The provided context optionally controls which HTTP client is used. See the HTTPClient variable. // The provided context optionally controls which HTTP client is used. See the HTTPClient variable.
// //
// The code will be in the *http.Request.FormValue("code"). Before // The code will be in the *http.Request.FormValue("code"). Before
// calling Exchange, be sure to validate FormValue("state"). // calling Exchange, be sure to validate FormValue("state") if you are
// using it to protect against CSRF attacks.
// //
// Opts may include the PKCE verifier code if previously used in AuthCodeURL. // If using PKCE to protect against CSRF attacks, opts should include a
// See https://www.oauth.com/oauth2-servers/pkce/ for more info. // VerifierOption.
func (c *Config) Exchange(ctx context.Context, code string, opts ...AuthCodeOption) (*Token, error) { func (c *Config) Exchange(ctx context.Context, code string, opts ...AuthCodeOption) (*Token, error) {
v := url.Values{ v := url.Values{
"grant_type": {"authorization_code"}, "grant_type": {"authorization_code"},

68
vendor/golang.org/x/oauth2/pkce.go generated vendored Normal file
View File

@ -0,0 +1,68 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package oauth2
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"net/url"
)
const (
codeChallengeKey = "code_challenge"
codeChallengeMethodKey = "code_challenge_method"
codeVerifierKey = "code_verifier"
)
// GenerateVerifier generates a PKCE code verifier with 32 octets of randomness.
// This follows recommendations in RFC 7636.
//
// A fresh verifier should be generated for each authorization.
// S256ChallengeOption(verifier) should then be passed to Config.AuthCodeURL
// (or Config.DeviceAccess) and VerifierOption(verifier) to Config.Exchange
// (or Config.DeviceAccessToken).
func GenerateVerifier() string {
// "RECOMMENDED that the output of a suitable random number generator be
// used to create a 32-octet sequence. The octet sequence is then
// base64url-encoded to produce a 43-octet URL-safe string to use as the
// code verifier."
// https://datatracker.ietf.org/doc/html/rfc7636#section-4.1
data := make([]byte, 32)
if _, err := rand.Read(data); err != nil {
panic(err)
}
return base64.RawURLEncoding.EncodeToString(data)
}
// VerifierOption returns a PKCE code verifier AuthCodeOption. It should be
// passed to Config.Exchange or Config.DeviceAccessToken only.
func VerifierOption(verifier string) AuthCodeOption {
return setParam{k: codeVerifierKey, v: verifier}
}
// S256ChallengeFromVerifier returns a PKCE code challenge derived from verifier with method S256.
//
// Prefer to use S256ChallengeOption where possible.
func S256ChallengeFromVerifier(verifier string) string {
sha := sha256.Sum256([]byte(verifier))
return base64.RawURLEncoding.EncodeToString(sha[:])
}
// S256ChallengeOption derives a PKCE code challenge derived from verifier with
// method S256. It should be passed to Config.AuthCodeURL or Config.DeviceAccess
// only.
func S256ChallengeOption(verifier string) AuthCodeOption {
return challengeOption{
challenge_method: "S256",
challenge: S256ChallengeFromVerifier(verifier),
}
}
type challengeOption struct{ challenge_method, challenge string }
func (p challengeOption) setValue(m url.Values) {
m.Set(codeChallengeMethodKey, p.challenge_method)
m.Set(codeChallengeKey, p.challenge)
}

View File

@ -2,12 +2,14 @@
// Use of this source code is governed by the Apache 2.0 // Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build !appengine
// +build !appengine // +build !appengine
package internal package internal
import ( import (
"bytes" "bytes"
"context"
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -24,7 +26,6 @@ import (
"time" "time"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
netcontext "golang.org/x/net/context"
basepb "google.golang.org/appengine/internal/base" basepb "google.golang.org/appengine/internal/base"
logpb "google.golang.org/appengine/internal/log" logpb "google.golang.org/appengine/internal/log"
@ -32,8 +33,7 @@ import (
) )
const ( const (
apiPath = "/rpc_http" apiPath = "/rpc_http"
defaultTicketSuffix = "/default.20150612t184001.0"
) )
var ( var (
@ -65,21 +65,22 @@ var (
IdleConnTimeout: 90 * time.Second, IdleConnTimeout: 90 * time.Second,
}, },
} }
defaultTicketOnce sync.Once
defaultTicket string
backgroundContextOnce sync.Once
backgroundContext netcontext.Context
) )
func apiURL() *url.URL { func apiURL(ctx context.Context) *url.URL {
host, port := "appengine.googleapis.internal", "10001" host, port := "appengine.googleapis.internal", "10001"
if h := os.Getenv("API_HOST"); h != "" { if h := os.Getenv("API_HOST"); h != "" {
host = h host = h
} }
if hostOverride := ctx.Value(apiHostOverrideKey); hostOverride != nil {
host = hostOverride.(string)
}
if p := os.Getenv("API_PORT"); p != "" { if p := os.Getenv("API_PORT"); p != "" {
port = p port = p
} }
if portOverride := ctx.Value(apiPortOverrideKey); portOverride != nil {
port = portOverride.(string)
}
return &url.URL{ return &url.URL{
Scheme: "http", Scheme: "http",
Host: host + ":" + port, Host: host + ":" + port,
@ -87,82 +88,97 @@ func apiURL() *url.URL {
} }
} }
func handleHTTP(w http.ResponseWriter, r *http.Request) { // Middleware wraps an http handler so that it can make GAE API calls
c := &context{ func Middleware(next http.Handler) http.Handler {
req: r, return handleHTTPMiddleware(executeRequestSafelyMiddleware(next))
outHeader: w.Header(),
apiURL: apiURL(),
}
r = r.WithContext(withContext(r.Context(), c))
c.req = r
stopFlushing := make(chan int)
// Patch up RemoteAddr so it looks reasonable.
if addr := r.Header.Get(userIPHeader); addr != "" {
r.RemoteAddr = addr
} else if addr = r.Header.Get(remoteAddrHeader); addr != "" {
r.RemoteAddr = addr
} else {
// Should not normally reach here, but pick a sensible default anyway.
r.RemoteAddr = "127.0.0.1"
}
// The address in the headers will most likely be of these forms:
// 123.123.123.123
// 2001:db8::1
// net/http.Request.RemoteAddr is specified to be in "IP:port" form.
if _, _, err := net.SplitHostPort(r.RemoteAddr); err != nil {
// Assume the remote address is only a host; add a default port.
r.RemoteAddr = net.JoinHostPort(r.RemoteAddr, "80")
}
// Start goroutine responsible for flushing app logs.
// This is done after adding c to ctx.m (and stopped before removing it)
// because flushing logs requires making an API call.
go c.logFlusher(stopFlushing)
executeRequestSafely(c, r)
c.outHeader = nil // make sure header changes aren't respected any more
stopFlushing <- 1 // any logging beyond this point will be dropped
// Flush any pending logs asynchronously.
c.pendingLogs.Lock()
flushes := c.pendingLogs.flushes
if len(c.pendingLogs.lines) > 0 {
flushes++
}
c.pendingLogs.Unlock()
flushed := make(chan struct{})
go func() {
defer close(flushed)
// Force a log flush, because with very short requests we
// may not ever flush logs.
c.flushLog(true)
}()
w.Header().Set(logFlushHeader, strconv.Itoa(flushes))
// Avoid nil Write call if c.Write is never called.
if c.outCode != 0 {
w.WriteHeader(c.outCode)
}
if c.outBody != nil {
w.Write(c.outBody)
}
// Wait for the last flush to complete before returning,
// otherwise the security ticket will not be valid.
<-flushed
} }
func executeRequestSafely(c *context, r *http.Request) { func handleHTTPMiddleware(next http.Handler) http.Handler {
defer func() { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if x := recover(); x != nil { c := &aeContext{
logf(c, 4, "%s", renderPanic(x)) // 4 == critical req: r,
c.outCode = 500 outHeader: w.Header(),
} }
}() r = r.WithContext(withContext(r.Context(), c))
c.req = r
http.DefaultServeMux.ServeHTTP(c, r) stopFlushing := make(chan int)
// Patch up RemoteAddr so it looks reasonable.
if addr := r.Header.Get(userIPHeader); addr != "" {
r.RemoteAddr = addr
} else if addr = r.Header.Get(remoteAddrHeader); addr != "" {
r.RemoteAddr = addr
} else {
// Should not normally reach here, but pick a sensible default anyway.
r.RemoteAddr = "127.0.0.1"
}
// The address in the headers will most likely be of these forms:
// 123.123.123.123
// 2001:db8::1
// net/http.Request.RemoteAddr is specified to be in "IP:port" form.
if _, _, err := net.SplitHostPort(r.RemoteAddr); err != nil {
// Assume the remote address is only a host; add a default port.
r.RemoteAddr = net.JoinHostPort(r.RemoteAddr, "80")
}
if logToLogservice() {
// Start goroutine responsible for flushing app logs.
// This is done after adding c to ctx.m (and stopped before removing it)
// because flushing logs requires making an API call.
go c.logFlusher(stopFlushing)
}
next.ServeHTTP(c, r)
c.outHeader = nil // make sure header changes aren't respected any more
flushed := make(chan struct{})
if logToLogservice() {
stopFlushing <- 1 // any logging beyond this point will be dropped
// Flush any pending logs asynchronously.
c.pendingLogs.Lock()
flushes := c.pendingLogs.flushes
if len(c.pendingLogs.lines) > 0 {
flushes++
}
c.pendingLogs.Unlock()
go func() {
defer close(flushed)
// Force a log flush, because with very short requests we
// may not ever flush logs.
c.flushLog(true)
}()
w.Header().Set(logFlushHeader, strconv.Itoa(flushes))
}
// Avoid nil Write call if c.Write is never called.
if c.outCode != 0 {
w.WriteHeader(c.outCode)
}
if c.outBody != nil {
w.Write(c.outBody)
}
if logToLogservice() {
// Wait for the last flush to complete before returning,
// otherwise the security ticket will not be valid.
<-flushed
}
})
}
func executeRequestSafelyMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if x := recover(); x != nil {
c := w.(*aeContext)
logf(c, 4, "%s", renderPanic(x)) // 4 == critical
c.outCode = 500
}
}()
next.ServeHTTP(w, r)
})
} }
func renderPanic(x interface{}) string { func renderPanic(x interface{}) string {
@ -204,9 +220,9 @@ func renderPanic(x interface{}) string {
return string(buf) return string(buf)
} }
// context represents the context of an in-flight HTTP request. // aeContext represents the aeContext of an in-flight HTTP request.
// It implements the appengine.Context and http.ResponseWriter interfaces. // It implements the appengine.Context and http.ResponseWriter interfaces.
type context struct { type aeContext struct {
req *http.Request req *http.Request
outCode int outCode int
@ -218,8 +234,6 @@ type context struct {
lines []*logpb.UserAppLogLine lines []*logpb.UserAppLogLine
flushes int flushes int
} }
apiURL *url.URL
} }
var contextKey = "holds a *context" var contextKey = "holds a *context"
@ -227,8 +241,8 @@ var contextKey = "holds a *context"
// jointContext joins two contexts in a superficial way. // jointContext joins two contexts in a superficial way.
// It takes values and timeouts from a base context, and only values from another context. // It takes values and timeouts from a base context, and only values from another context.
type jointContext struct { type jointContext struct {
base netcontext.Context base context.Context
valuesOnly netcontext.Context valuesOnly context.Context
} }
func (c jointContext) Deadline() (time.Time, bool) { func (c jointContext) Deadline() (time.Time, bool) {
@ -252,94 +266,54 @@ func (c jointContext) Value(key interface{}) interface{} {
// fromContext returns the App Engine context or nil if ctx is not // fromContext returns the App Engine context or nil if ctx is not
// derived from an App Engine context. // derived from an App Engine context.
func fromContext(ctx netcontext.Context) *context { func fromContext(ctx context.Context) *aeContext {
c, _ := ctx.Value(&contextKey).(*context) c, _ := ctx.Value(&contextKey).(*aeContext)
return c return c
} }
func withContext(parent netcontext.Context, c *context) netcontext.Context { func withContext(parent context.Context, c *aeContext) context.Context {
ctx := netcontext.WithValue(parent, &contextKey, c) ctx := context.WithValue(parent, &contextKey, c)
if ns := c.req.Header.Get(curNamespaceHeader); ns != "" { if ns := c.req.Header.Get(curNamespaceHeader); ns != "" {
ctx = withNamespace(ctx, ns) ctx = withNamespace(ctx, ns)
} }
return ctx return ctx
} }
func toContext(c *context) netcontext.Context { func toContext(c *aeContext) context.Context {
return withContext(netcontext.Background(), c) return withContext(context.Background(), c)
} }
func IncomingHeaders(ctx netcontext.Context) http.Header { func IncomingHeaders(ctx context.Context) http.Header {
if c := fromContext(ctx); c != nil { if c := fromContext(ctx); c != nil {
return c.req.Header return c.req.Header
} }
return nil return nil
} }
func ReqContext(req *http.Request) netcontext.Context { func ReqContext(req *http.Request) context.Context {
return req.Context() return req.Context()
} }
func WithContext(parent netcontext.Context, req *http.Request) netcontext.Context { func WithContext(parent context.Context, req *http.Request) context.Context {
return jointContext{ return jointContext{
base: parent, base: parent,
valuesOnly: req.Context(), valuesOnly: req.Context(),
} }
} }
// DefaultTicket returns a ticket used for background context or dev_appserver.
func DefaultTicket() string {
defaultTicketOnce.Do(func() {
if IsDevAppServer() {
defaultTicket = "testapp" + defaultTicketSuffix
return
}
appID := partitionlessAppID()
escAppID := strings.Replace(strings.Replace(appID, ":", "_", -1), ".", "_", -1)
majVersion := VersionID(nil)
if i := strings.Index(majVersion, "."); i > 0 {
majVersion = majVersion[:i]
}
defaultTicket = fmt.Sprintf("%s/%s.%s.%s", escAppID, ModuleName(nil), majVersion, InstanceID())
})
return defaultTicket
}
func BackgroundContext() netcontext.Context {
backgroundContextOnce.Do(func() {
// Compute background security ticket.
ticket := DefaultTicket()
c := &context{
req: &http.Request{
Header: http.Header{
ticketHeader: []string{ticket},
},
},
apiURL: apiURL(),
}
backgroundContext = toContext(c)
// TODO(dsymonds): Wire up the shutdown handler to do a final flush.
go c.logFlusher(make(chan int))
})
return backgroundContext
}
// RegisterTestRequest registers the HTTP request req for testing, such that // RegisterTestRequest registers the HTTP request req for testing, such that
// any API calls are sent to the provided URL. It returns a closure to delete // any API calls are sent to the provided URL.
// the registration.
// It should only be used by aetest package. // It should only be used by aetest package.
func RegisterTestRequest(req *http.Request, apiURL *url.URL, decorate func(netcontext.Context) netcontext.Context) (*http.Request, func()) { func RegisterTestRequest(req *http.Request, apiURL *url.URL, appID string) *http.Request {
c := &context{ ctx := req.Context()
req: req, ctx = withAPIHostOverride(ctx, apiURL.Hostname())
apiURL: apiURL, ctx = withAPIPortOverride(ctx, apiURL.Port())
} ctx = WithAppIDOverride(ctx, appID)
ctx := withContext(decorate(req.Context()), c)
req = req.WithContext(ctx) // use the unregistered request as a placeholder so that withContext can read the headers
c.req = req c := &aeContext{req: req}
return req, func() {} c.req = req.WithContext(withContext(ctx, c))
return c.req
} }
var errTimeout = &CallError{ var errTimeout = &CallError{
@ -348,7 +322,7 @@ var errTimeout = &CallError{
Timeout: true, Timeout: true,
} }
func (c *context) Header() http.Header { return c.outHeader } func (c *aeContext) Header() http.Header { return c.outHeader }
// Copied from $GOROOT/src/pkg/net/http/transfer.go. Some response status // Copied from $GOROOT/src/pkg/net/http/transfer.go. Some response status
// codes do not permit a response body (nor response entity headers such as // codes do not permit a response body (nor response entity headers such as
@ -365,7 +339,7 @@ func bodyAllowedForStatus(status int) bool {
return true return true
} }
func (c *context) Write(b []byte) (int, error) { func (c *aeContext) Write(b []byte) (int, error) {
if c.outCode == 0 { if c.outCode == 0 {
c.WriteHeader(http.StatusOK) c.WriteHeader(http.StatusOK)
} }
@ -376,7 +350,7 @@ func (c *context) Write(b []byte) (int, error) {
return len(b), nil return len(b), nil
} }
func (c *context) WriteHeader(code int) { func (c *aeContext) WriteHeader(code int) {
if c.outCode != 0 { if c.outCode != 0 {
logf(c, 3, "WriteHeader called multiple times on request.") // error level logf(c, 3, "WriteHeader called multiple times on request.") // error level
return return
@ -384,10 +358,11 @@ func (c *context) WriteHeader(code int) {
c.outCode = code c.outCode = code
} }
func (c *context) post(body []byte, timeout time.Duration) (b []byte, err error) { func post(ctx context.Context, body []byte, timeout time.Duration) (b []byte, err error) {
apiURL := apiURL(ctx)
hreq := &http.Request{ hreq := &http.Request{
Method: "POST", Method: "POST",
URL: c.apiURL, URL: apiURL,
Header: http.Header{ Header: http.Header{
apiEndpointHeader: apiEndpointHeaderValue, apiEndpointHeader: apiEndpointHeaderValue,
apiMethodHeader: apiMethodHeaderValue, apiMethodHeader: apiMethodHeaderValue,
@ -396,13 +371,16 @@ func (c *context) post(body []byte, timeout time.Duration) (b []byte, err error)
}, },
Body: ioutil.NopCloser(bytes.NewReader(body)), Body: ioutil.NopCloser(bytes.NewReader(body)),
ContentLength: int64(len(body)), ContentLength: int64(len(body)),
Host: c.apiURL.Host, Host: apiURL.Host,
} }
if info := c.req.Header.Get(dapperHeader); info != "" { c := fromContext(ctx)
hreq.Header.Set(dapperHeader, info) if c != nil {
} if info := c.req.Header.Get(dapperHeader); info != "" {
if info := c.req.Header.Get(traceHeader); info != "" { hreq.Header.Set(dapperHeader, info)
hreq.Header.Set(traceHeader, info) }
if info := c.req.Header.Get(traceHeader); info != "" {
hreq.Header.Set(traceHeader, info)
}
} }
tr := apiHTTPClient.Transport.(*http.Transport) tr := apiHTTPClient.Transport.(*http.Transport)
@ -444,7 +422,7 @@ func (c *context) post(body []byte, timeout time.Duration) (b []byte, err error)
return hrespBody, nil return hrespBody, nil
} }
func Call(ctx netcontext.Context, service, method string, in, out proto.Message) error { func Call(ctx context.Context, service, method string, in, out proto.Message) error {
if ns := NamespaceFromContext(ctx); ns != "" { if ns := NamespaceFromContext(ctx); ns != "" {
if fn, ok := NamespaceMods[service]; ok { if fn, ok := NamespaceMods[service]; ok {
fn(in, ns) fn(in, ns)
@ -463,15 +441,11 @@ func Call(ctx netcontext.Context, service, method string, in, out proto.Message)
} }
c := fromContext(ctx) c := fromContext(ctx)
if c == nil {
// Give a good error message rather than a panic lower down.
return errNotAppEngineContext
}
// Apply transaction modifications if we're in a transaction. // Apply transaction modifications if we're in a transaction.
if t := transactionFromContext(ctx); t != nil { if t := transactionFromContext(ctx); t != nil {
if t.finished { if t.finished {
return errors.New("transaction context has expired") return errors.New("transaction aeContext has expired")
} }
applyTransaction(in, &t.transaction) applyTransaction(in, &t.transaction)
} }
@ -487,20 +461,13 @@ func Call(ctx netcontext.Context, service, method string, in, out proto.Message)
return err return err
} }
ticket := c.req.Header.Get(ticketHeader) ticket := ""
// Use a test ticket under test environment. if c != nil {
if ticket == "" { ticket = c.req.Header.Get(ticketHeader)
if appid := ctx.Value(&appIDOverrideKey); appid != nil { if dri := c.req.Header.Get(devRequestIdHeader); IsDevAppServer() && dri != "" {
ticket = appid.(string) + defaultTicketSuffix ticket = dri
} }
} }
// Fall back to use background ticket when the request ticket is not available in Flex or dev_appserver.
if ticket == "" {
ticket = DefaultTicket()
}
if dri := c.req.Header.Get(devRequestIdHeader); IsDevAppServer() && dri != "" {
ticket = dri
}
req := &remotepb.Request{ req := &remotepb.Request{
ServiceName: &service, ServiceName: &service,
Method: &method, Method: &method,
@ -512,7 +479,7 @@ func Call(ctx netcontext.Context, service, method string, in, out proto.Message)
return err return err
} }
hrespBody, err := c.post(hreqBody, timeout) hrespBody, err := post(ctx, hreqBody, timeout)
if err != nil { if err != nil {
return err return err
} }
@ -549,11 +516,11 @@ func Call(ctx netcontext.Context, service, method string, in, out proto.Message)
return proto.Unmarshal(res.Response, out) return proto.Unmarshal(res.Response, out)
} }
func (c *context) Request() *http.Request { func (c *aeContext) Request() *http.Request {
return c.req return c.req
} }
func (c *context) addLogLine(ll *logpb.UserAppLogLine) { func (c *aeContext) addLogLine(ll *logpb.UserAppLogLine) {
// Truncate long log lines. // Truncate long log lines.
// TODO(dsymonds): Check if this is still necessary. // TODO(dsymonds): Check if this is still necessary.
const lim = 8 << 10 const lim = 8 << 10
@ -575,18 +542,20 @@ var logLevelName = map[int64]string{
4: "CRITICAL", 4: "CRITICAL",
} }
func logf(c *context, level int64, format string, args ...interface{}) { func logf(c *aeContext, level int64, format string, args ...interface{}) {
if c == nil { if c == nil {
panic("not an App Engine context") panic("not an App Engine aeContext")
} }
s := fmt.Sprintf(format, args...) s := fmt.Sprintf(format, args...)
s = strings.TrimRight(s, "\n") // Remove any trailing newline characters. s = strings.TrimRight(s, "\n") // Remove any trailing newline characters.
c.addLogLine(&logpb.UserAppLogLine{ if logToLogservice() {
TimestampUsec: proto.Int64(time.Now().UnixNano() / 1e3), c.addLogLine(&logpb.UserAppLogLine{
Level: &level, TimestampUsec: proto.Int64(time.Now().UnixNano() / 1e3),
Message: &s, Level: &level,
}) Message: &s,
// Only duplicate log to stderr if not running on App Engine second generation })
}
// Log to stdout if not deployed
if !IsSecondGen() { if !IsSecondGen() {
log.Print(logLevelName[level] + ": " + s) log.Print(logLevelName[level] + ": " + s)
} }
@ -594,7 +563,7 @@ func logf(c *context, level int64, format string, args ...interface{}) {
// flushLog attempts to flush any pending logs to the appserver. // flushLog attempts to flush any pending logs to the appserver.
// It should not be called concurrently. // It should not be called concurrently.
func (c *context) flushLog(force bool) (flushed bool) { func (c *aeContext) flushLog(force bool) (flushed bool) {
c.pendingLogs.Lock() c.pendingLogs.Lock()
// Grab up to 30 MB. We can get away with up to 32 MB, but let's be cautious. // Grab up to 30 MB. We can get away with up to 32 MB, but let's be cautious.
n, rem := 0, 30<<20 n, rem := 0, 30<<20
@ -655,7 +624,7 @@ const (
forceFlushInterval = 60 * time.Second forceFlushInterval = 60 * time.Second
) )
func (c *context) logFlusher(stop <-chan int) { func (c *aeContext) logFlusher(stop <-chan int) {
lastFlush := time.Now() lastFlush := time.Now()
tick := time.NewTicker(flushInterval) tick := time.NewTicker(flushInterval)
for { for {
@ -673,6 +642,12 @@ func (c *context) logFlusher(stop <-chan int) {
} }
} }
func ContextForTesting(req *http.Request) netcontext.Context { func ContextForTesting(req *http.Request) context.Context {
return toContext(&context{req: req}) return toContext(&aeContext{req: req})
}
func logToLogservice() bool {
// TODO: replace logservice with json structured logs to $LOG_DIR/app.log.json
// where $LOG_DIR is /var/log in prod and some tmpdir in dev
return os.Getenv("LOG_TO_LOGSERVICE") != "0"
} }

View File

@ -2,11 +2,13 @@
// Use of this source code is governed by the Apache 2.0 // Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build appengine
// +build appengine // +build appengine
package internal package internal
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -17,20 +19,19 @@ import (
basepb "appengine_internal/base" basepb "appengine_internal/base"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
netcontext "golang.org/x/net/context"
) )
var contextKey = "holds an appengine.Context" var contextKey = "holds an appengine.Context"
// fromContext returns the App Engine context or nil if ctx is not // fromContext returns the App Engine context or nil if ctx is not
// derived from an App Engine context. // derived from an App Engine context.
func fromContext(ctx netcontext.Context) appengine.Context { func fromContext(ctx context.Context) appengine.Context {
c, _ := ctx.Value(&contextKey).(appengine.Context) c, _ := ctx.Value(&contextKey).(appengine.Context)
return c return c
} }
// This is only for classic App Engine adapters. // This is only for classic App Engine adapters.
func ClassicContextFromContext(ctx netcontext.Context) (appengine.Context, error) { func ClassicContextFromContext(ctx context.Context) (appengine.Context, error) {
c := fromContext(ctx) c := fromContext(ctx)
if c == nil { if c == nil {
return nil, errNotAppEngineContext return nil, errNotAppEngineContext
@ -38,8 +39,8 @@ func ClassicContextFromContext(ctx netcontext.Context) (appengine.Context, error
return c, nil return c, nil
} }
func withContext(parent netcontext.Context, c appengine.Context) netcontext.Context { func withContext(parent context.Context, c appengine.Context) context.Context {
ctx := netcontext.WithValue(parent, &contextKey, c) ctx := context.WithValue(parent, &contextKey, c)
s := &basepb.StringProto{} s := &basepb.StringProto{}
c.Call("__go__", "GetNamespace", &basepb.VoidProto{}, s, nil) c.Call("__go__", "GetNamespace", &basepb.VoidProto{}, s, nil)
@ -50,7 +51,7 @@ func withContext(parent netcontext.Context, c appengine.Context) netcontext.Cont
return ctx return ctx
} }
func IncomingHeaders(ctx netcontext.Context) http.Header { func IncomingHeaders(ctx context.Context) http.Header {
if c := fromContext(ctx); c != nil { if c := fromContext(ctx); c != nil {
if req, ok := c.Request().(*http.Request); ok { if req, ok := c.Request().(*http.Request); ok {
return req.Header return req.Header
@ -59,11 +60,11 @@ func IncomingHeaders(ctx netcontext.Context) http.Header {
return nil return nil
} }
func ReqContext(req *http.Request) netcontext.Context { func ReqContext(req *http.Request) context.Context {
return WithContext(netcontext.Background(), req) return WithContext(context.Background(), req)
} }
func WithContext(parent netcontext.Context, req *http.Request) netcontext.Context { func WithContext(parent context.Context, req *http.Request) context.Context {
c := appengine.NewContext(req) c := appengine.NewContext(req)
return withContext(parent, c) return withContext(parent, c)
} }
@ -83,11 +84,11 @@ func (t *testingContext) Call(service, method string, _, _ appengine_internal.Pr
} }
func (t *testingContext) Request() interface{} { return t.req } func (t *testingContext) Request() interface{} { return t.req }
func ContextForTesting(req *http.Request) netcontext.Context { func ContextForTesting(req *http.Request) context.Context {
return withContext(netcontext.Background(), &testingContext{req: req}) return withContext(context.Background(), &testingContext{req: req})
} }
func Call(ctx netcontext.Context, service, method string, in, out proto.Message) error { func Call(ctx context.Context, service, method string, in, out proto.Message) error {
if ns := NamespaceFromContext(ctx); ns != "" { if ns := NamespaceFromContext(ctx); ns != "" {
if fn, ok := NamespaceMods[service]; ok { if fn, ok := NamespaceMods[service]; ok {
fn(in, ns) fn(in, ns)
@ -144,8 +145,8 @@ func Call(ctx netcontext.Context, service, method string, in, out proto.Message)
return err return err
} }
func handleHTTP(w http.ResponseWriter, r *http.Request) { func Middleware(next http.Handler) http.Handler {
panic("handleHTTP called; this should be impossible") panic("Middleware called; this should be impossible")
} }
func logf(c appengine.Context, level int64, format string, args ...interface{}) { func logf(c appengine.Context, level int64, format string, args ...interface{}) {

View File

@ -5,20 +5,26 @@
package internal package internal
import ( import (
"context"
"errors" "errors"
"os" "os"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
netcontext "golang.org/x/net/context"
) )
type ctxKey string
func (c ctxKey) String() string {
return "appengine context key: " + string(c)
}
var errNotAppEngineContext = errors.New("not an App Engine context") var errNotAppEngineContext = errors.New("not an App Engine context")
type CallOverrideFunc func(ctx netcontext.Context, service, method string, in, out proto.Message) error type CallOverrideFunc func(ctx context.Context, service, method string, in, out proto.Message) error
var callOverrideKey = "holds []CallOverrideFunc" var callOverrideKey = "holds []CallOverrideFunc"
func WithCallOverride(ctx netcontext.Context, f CallOverrideFunc) netcontext.Context { func WithCallOverride(ctx context.Context, f CallOverrideFunc) context.Context {
// We avoid appending to any existing call override // We avoid appending to any existing call override
// so we don't risk overwriting a popped stack below. // so we don't risk overwriting a popped stack below.
var cofs []CallOverrideFunc var cofs []CallOverrideFunc
@ -26,10 +32,10 @@ func WithCallOverride(ctx netcontext.Context, f CallOverrideFunc) netcontext.Con
cofs = append(cofs, uf...) cofs = append(cofs, uf...)
} }
cofs = append(cofs, f) cofs = append(cofs, f)
return netcontext.WithValue(ctx, &callOverrideKey, cofs) return context.WithValue(ctx, &callOverrideKey, cofs)
} }
func callOverrideFromContext(ctx netcontext.Context) (CallOverrideFunc, netcontext.Context, bool) { func callOverrideFromContext(ctx context.Context) (CallOverrideFunc, context.Context, bool) {
cofs, _ := ctx.Value(&callOverrideKey).([]CallOverrideFunc) cofs, _ := ctx.Value(&callOverrideKey).([]CallOverrideFunc)
if len(cofs) == 0 { if len(cofs) == 0 {
return nil, nil, false return nil, nil, false
@ -37,7 +43,7 @@ func callOverrideFromContext(ctx netcontext.Context) (CallOverrideFunc, netconte
// We found a list of overrides; grab the last, and reconstitute a // We found a list of overrides; grab the last, and reconstitute a
// context that will hide it. // context that will hide it.
f := cofs[len(cofs)-1] f := cofs[len(cofs)-1]
ctx = netcontext.WithValue(ctx, &callOverrideKey, cofs[:len(cofs)-1]) ctx = context.WithValue(ctx, &callOverrideKey, cofs[:len(cofs)-1])
return f, ctx, true return f, ctx, true
} }
@ -45,23 +51,35 @@ type logOverrideFunc func(level int64, format string, args ...interface{})
var logOverrideKey = "holds a logOverrideFunc" var logOverrideKey = "holds a logOverrideFunc"
func WithLogOverride(ctx netcontext.Context, f logOverrideFunc) netcontext.Context { func WithLogOverride(ctx context.Context, f logOverrideFunc) context.Context {
return netcontext.WithValue(ctx, &logOverrideKey, f) return context.WithValue(ctx, &logOverrideKey, f)
} }
var appIDOverrideKey = "holds a string, being the full app ID" var appIDOverrideKey = "holds a string, being the full app ID"
func WithAppIDOverride(ctx netcontext.Context, appID string) netcontext.Context { func WithAppIDOverride(ctx context.Context, appID string) context.Context {
return netcontext.WithValue(ctx, &appIDOverrideKey, appID) return context.WithValue(ctx, &appIDOverrideKey, appID)
}
var apiHostOverrideKey = ctxKey("holds a string, being the alternate API_HOST")
func withAPIHostOverride(ctx context.Context, apiHost string) context.Context {
return context.WithValue(ctx, apiHostOverrideKey, apiHost)
}
var apiPortOverrideKey = ctxKey("holds a string, being the alternate API_PORT")
func withAPIPortOverride(ctx context.Context, apiPort string) context.Context {
return context.WithValue(ctx, apiPortOverrideKey, apiPort)
} }
var namespaceKey = "holds the namespace string" var namespaceKey = "holds the namespace string"
func withNamespace(ctx netcontext.Context, ns string) netcontext.Context { func withNamespace(ctx context.Context, ns string) context.Context {
return netcontext.WithValue(ctx, &namespaceKey, ns) return context.WithValue(ctx, &namespaceKey, ns)
} }
func NamespaceFromContext(ctx netcontext.Context) string { func NamespaceFromContext(ctx context.Context) string {
// If there's no namespace, return the empty string. // If there's no namespace, return the empty string.
ns, _ := ctx.Value(&namespaceKey).(string) ns, _ := ctx.Value(&namespaceKey).(string)
return ns return ns
@ -70,14 +88,14 @@ func NamespaceFromContext(ctx netcontext.Context) string {
// FullyQualifiedAppID returns the fully-qualified application ID. // FullyQualifiedAppID returns the fully-qualified application ID.
// This may contain a partition prefix (e.g. "s~" for High Replication apps), // This may contain a partition prefix (e.g. "s~" for High Replication apps),
// or a domain prefix (e.g. "example.com:"). // or a domain prefix (e.g. "example.com:").
func FullyQualifiedAppID(ctx netcontext.Context) string { func FullyQualifiedAppID(ctx context.Context) string {
if id, ok := ctx.Value(&appIDOverrideKey).(string); ok { if id, ok := ctx.Value(&appIDOverrideKey).(string); ok {
return id return id
} }
return fullyQualifiedAppID(ctx) return fullyQualifiedAppID(ctx)
} }
func Logf(ctx netcontext.Context, level int64, format string, args ...interface{}) { func Logf(ctx context.Context, level int64, format string, args ...interface{}) {
if f, ok := ctx.Value(&logOverrideKey).(logOverrideFunc); ok { if f, ok := ctx.Value(&logOverrideKey).(logOverrideFunc); ok {
f(level, format, args...) f(level, format, args...)
return return
@ -90,7 +108,7 @@ func Logf(ctx netcontext.Context, level int64, format string, args ...interface{
} }
// NamespacedContext wraps a Context to support namespaces. // NamespacedContext wraps a Context to support namespaces.
func NamespacedContext(ctx netcontext.Context, namespace string) netcontext.Context { func NamespacedContext(ctx context.Context, namespace string) context.Context {
return withNamespace(ctx, namespace) return withNamespace(ctx, namespace)
} }

View File

@ -5,9 +5,8 @@
package internal package internal
import ( import (
"context"
"os" "os"
netcontext "golang.org/x/net/context"
) )
var ( var (
@ -23,7 +22,7 @@ var (
// AppID is the implementation of the wrapper function of the same name in // AppID is the implementation of the wrapper function of the same name in
// ../identity.go. See that file for commentary. // ../identity.go. See that file for commentary.
func AppID(c netcontext.Context) string { func AppID(c context.Context) string {
return appID(FullyQualifiedAppID(c)) return appID(FullyQualifiedAppID(c))
} }
@ -35,7 +34,7 @@ func IsStandard() bool {
return appengineStandard || IsSecondGen() return appengineStandard || IsSecondGen()
} }
// IsStandard is the implementation of the wrapper function of the same name in // IsSecondGen is the implementation of the wrapper function of the same name in
// ../appengine.go. See that file for commentary. // ../appengine.go. See that file for commentary.
func IsSecondGen() bool { func IsSecondGen() bool {
// Second-gen runtimes set $GAE_ENV so we use that to check if we're on a second-gen runtime. // Second-gen runtimes set $GAE_ENV so we use that to check if we're on a second-gen runtime.

View File

@ -2,21 +2,22 @@
// Use of this source code is governed by the Apache 2.0 // Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build appengine
// +build appengine // +build appengine
package internal package internal
import ( import (
"appengine" "context"
netcontext "golang.org/x/net/context" "appengine"
) )
func init() { func init() {
appengineStandard = true appengineStandard = true
} }
func DefaultVersionHostname(ctx netcontext.Context) string { func DefaultVersionHostname(ctx context.Context) string {
c := fromContext(ctx) c := fromContext(ctx)
if c == nil { if c == nil {
panic(errNotAppEngineContext) panic(errNotAppEngineContext)
@ -24,12 +25,12 @@ func DefaultVersionHostname(ctx netcontext.Context) string {
return appengine.DefaultVersionHostname(c) return appengine.DefaultVersionHostname(c)
} }
func Datacenter(_ netcontext.Context) string { return appengine.Datacenter() } func Datacenter(_ context.Context) string { return appengine.Datacenter() }
func ServerSoftware() string { return appengine.ServerSoftware() } func ServerSoftware() string { return appengine.ServerSoftware() }
func InstanceID() string { return appengine.InstanceID() } func InstanceID() string { return appengine.InstanceID() }
func IsDevAppServer() bool { return appengine.IsDevAppServer() } func IsDevAppServer() bool { return appengine.IsDevAppServer() }
func RequestID(ctx netcontext.Context) string { func RequestID(ctx context.Context) string {
c := fromContext(ctx) c := fromContext(ctx)
if c == nil { if c == nil {
panic(errNotAppEngineContext) panic(errNotAppEngineContext)
@ -37,14 +38,14 @@ func RequestID(ctx netcontext.Context) string {
return appengine.RequestID(c) return appengine.RequestID(c)
} }
func ModuleName(ctx netcontext.Context) string { func ModuleName(ctx context.Context) string {
c := fromContext(ctx) c := fromContext(ctx)
if c == nil { if c == nil {
panic(errNotAppEngineContext) panic(errNotAppEngineContext)
} }
return appengine.ModuleName(c) return appengine.ModuleName(c)
} }
func VersionID(ctx netcontext.Context) string { func VersionID(ctx context.Context) string {
c := fromContext(ctx) c := fromContext(ctx)
if c == nil { if c == nil {
panic(errNotAppEngineContext) panic(errNotAppEngineContext)
@ -52,7 +53,7 @@ func VersionID(ctx netcontext.Context) string {
return appengine.VersionID(c) return appengine.VersionID(c)
} }
func fullyQualifiedAppID(ctx netcontext.Context) string { func fullyQualifiedAppID(ctx context.Context) string {
c := fromContext(ctx) c := fromContext(ctx)
if c == nil { if c == nil {
panic(errNotAppEngineContext) panic(errNotAppEngineContext)

View File

@ -2,6 +2,7 @@
// Use of this source code is governed by the Apache 2.0 // Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build appenginevm
// +build appenginevm // +build appenginevm
package internal package internal

View File

@ -2,17 +2,17 @@
// Use of this source code is governed by the Apache 2.0 // Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build !appengine
// +build !appengine // +build !appengine
package internal package internal
import ( import (
"context"
"log" "log"
"net/http" "net/http"
"os" "os"
"strings" "strings"
netcontext "golang.org/x/net/context"
) )
// These functions are implementations of the wrapper functions // These functions are implementations of the wrapper functions
@ -24,7 +24,7 @@ const (
hDatacenter = "X-AppEngine-Datacenter" hDatacenter = "X-AppEngine-Datacenter"
) )
func ctxHeaders(ctx netcontext.Context) http.Header { func ctxHeaders(ctx context.Context) http.Header {
c := fromContext(ctx) c := fromContext(ctx)
if c == nil { if c == nil {
return nil return nil
@ -32,15 +32,15 @@ func ctxHeaders(ctx netcontext.Context) http.Header {
return c.Request().Header return c.Request().Header
} }
func DefaultVersionHostname(ctx netcontext.Context) string { func DefaultVersionHostname(ctx context.Context) string {
return ctxHeaders(ctx).Get(hDefaultVersionHostname) return ctxHeaders(ctx).Get(hDefaultVersionHostname)
} }
func RequestID(ctx netcontext.Context) string { func RequestID(ctx context.Context) string {
return ctxHeaders(ctx).Get(hRequestLogId) return ctxHeaders(ctx).Get(hRequestLogId)
} }
func Datacenter(ctx netcontext.Context) string { func Datacenter(ctx context.Context) string {
if dc := ctxHeaders(ctx).Get(hDatacenter); dc != "" { if dc := ctxHeaders(ctx).Get(hDatacenter); dc != "" {
return dc return dc
} }
@ -71,7 +71,7 @@ func ServerSoftware() string {
// TODO(dsymonds): Remove the metadata fetches. // TODO(dsymonds): Remove the metadata fetches.
func ModuleName(_ netcontext.Context) string { func ModuleName(_ context.Context) string {
if s := os.Getenv("GAE_MODULE_NAME"); s != "" { if s := os.Getenv("GAE_MODULE_NAME"); s != "" {
return s return s
} }
@ -81,7 +81,7 @@ func ModuleName(_ netcontext.Context) string {
return string(mustGetMetadata("instance/attributes/gae_backend_name")) return string(mustGetMetadata("instance/attributes/gae_backend_name"))
} }
func VersionID(_ netcontext.Context) string { func VersionID(_ context.Context) string {
if s1, s2 := os.Getenv("GAE_MODULE_VERSION"), os.Getenv("GAE_MINOR_VERSION"); s1 != "" && s2 != "" { if s1, s2 := os.Getenv("GAE_MODULE_VERSION"), os.Getenv("GAE_MINOR_VERSION"); s1 != "" && s2 != "" {
return s1 + "." + s2 return s1 + "." + s2
} }
@ -112,7 +112,7 @@ func partitionlessAppID() string {
return string(mustGetMetadata("instance/attributes/gae_project")) return string(mustGetMetadata("instance/attributes/gae_project"))
} }
func fullyQualifiedAppID(_ netcontext.Context) string { func fullyQualifiedAppID(_ context.Context) string {
if s := os.Getenv("GAE_APPLICATION"); s != "" { if s := os.Getenv("GAE_APPLICATION"); s != "" {
return s return s
} }
@ -130,5 +130,5 @@ func fullyQualifiedAppID(_ netcontext.Context) string {
} }
func IsDevAppServer() bool { func IsDevAppServer() bool {
return os.Getenv("RUN_WITH_DEVAPPSERVER") != "" return os.Getenv("RUN_WITH_DEVAPPSERVER") != "" || os.Getenv("GAE_ENV") == "localdev"
} }

View File

@ -2,6 +2,7 @@
// Use of this source code is governed by the Apache 2.0 // Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build appengine
// +build appengine // +build appengine
package internal package internal

View File

@ -2,6 +2,7 @@
// Use of this source code is governed by the Apache 2.0 // Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build !appengine
// +build !appengine // +build !appengine
package internal package internal
@ -29,7 +30,7 @@ func Main() {
if IsDevAppServer() { if IsDevAppServer() {
host = "127.0.0.1" host = "127.0.0.1"
} }
if err := http.ListenAndServe(host+":"+port, http.HandlerFunc(handleHTTP)); err != nil { if err := http.ListenAndServe(host+":"+port, Middleware(http.DefaultServeMux)); err != nil {
log.Fatalf("http.ListenAndServe: %v", err) log.Fatalf("http.ListenAndServe: %v", err)
} }
} }

View File

@ -7,11 +7,11 @@ package internal
// This file implements hooks for applying datastore transactions. // This file implements hooks for applying datastore transactions.
import ( import (
"context"
"errors" "errors"
"reflect" "reflect"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
netcontext "golang.org/x/net/context"
basepb "google.golang.org/appengine/internal/base" basepb "google.golang.org/appengine/internal/base"
pb "google.golang.org/appengine/internal/datastore" pb "google.golang.org/appengine/internal/datastore"
@ -38,13 +38,13 @@ func applyTransaction(pb proto.Message, t *pb.Transaction) {
var transactionKey = "used for *Transaction" var transactionKey = "used for *Transaction"
func transactionFromContext(ctx netcontext.Context) *transaction { func transactionFromContext(ctx context.Context) *transaction {
t, _ := ctx.Value(&transactionKey).(*transaction) t, _ := ctx.Value(&transactionKey).(*transaction)
return t return t
} }
func withTransaction(ctx netcontext.Context, t *transaction) netcontext.Context { func withTransaction(ctx context.Context, t *transaction) context.Context {
return netcontext.WithValue(ctx, &transactionKey, t) return context.WithValue(ctx, &transactionKey, t)
} }
type transaction struct { type transaction struct {
@ -54,7 +54,7 @@ type transaction struct {
var ErrConcurrentTransaction = errors.New("internal: concurrent transaction") var ErrConcurrentTransaction = errors.New("internal: concurrent transaction")
func RunTransactionOnce(c netcontext.Context, f func(netcontext.Context) error, xg bool, readOnly bool, previousTransaction *pb.Transaction) (*pb.Transaction, error) { func RunTransactionOnce(c context.Context, f func(context.Context) error, xg bool, readOnly bool, previousTransaction *pb.Transaction) (*pb.Transaction, error) {
if transactionFromContext(c) != nil { if transactionFromContext(c) != nil {
return nil, errors.New("nested transactions are not supported") return nil, errors.New("nested transactions are not supported")
} }

View File

@ -7,6 +7,7 @@
package urlfetch // import "google.golang.org/appengine/urlfetch" package urlfetch // import "google.golang.org/appengine/urlfetch"
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -18,7 +19,6 @@ import (
"time" "time"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"golang.org/x/net/context"
"google.golang.org/appengine/internal" "google.golang.org/appengine/internal"
pb "google.golang.org/appengine/internal/urlfetch" pb "google.golang.org/appengine/internal/urlfetch"
@ -44,11 +44,10 @@ type Transport struct {
var _ http.RoundTripper = (*Transport)(nil) var _ http.RoundTripper = (*Transport)(nil)
// Client returns an *http.Client using a default urlfetch Transport. This // Client returns an *http.Client using a default urlfetch Transport. This
// client will have the default deadline of 5 seconds, and will check the // client will check the validity of SSL certificates.
// validity of SSL certificates.
// //
// Any deadline of the provided context will be used for requests through this client; // Any deadline of the provided context will be used for requests through this client.
// if the client does not have a deadline then a 5 second default is used. // If the client does not have a deadline, then an App Engine default of 60 second is used.
func Client(ctx context.Context) *http.Client { func Client(ctx context.Context) *http.Client {
return &http.Client{ return &http.Client{
Transport: &Transport{ Transport: &Transport{

View File

@ -32,21 +32,13 @@ import (
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
) )
type ccbMode int
const (
ccbModeActive = iota
ccbModeIdle
ccbModeClosed
ccbModeExitingIdle
)
// ccBalancerWrapper sits between the ClientConn and the Balancer. // ccBalancerWrapper sits between the ClientConn and the Balancer.
// //
// ccBalancerWrapper implements methods corresponding to the ones on the // ccBalancerWrapper implements methods corresponding to the ones on the
// balancer.Balancer interface. The ClientConn is free to call these methods // balancer.Balancer interface. The ClientConn is free to call these methods
// concurrently and the ccBalancerWrapper ensures that calls from the ClientConn // concurrently and the ccBalancerWrapper ensures that calls from the ClientConn
// to the Balancer happen synchronously and in order. // to the Balancer happen in order by performing them in the serializer, without
// any mutexes held.
// //
// ccBalancerWrapper also implements the balancer.ClientConn interface and is // ccBalancerWrapper also implements the balancer.ClientConn interface and is
// passed to the Balancer implementations. It invokes unexported methods on the // passed to the Balancer implementations. It invokes unexported methods on the
@ -57,87 +49,75 @@ const (
type ccBalancerWrapper struct { type ccBalancerWrapper struct {
// The following fields are initialized when the wrapper is created and are // The following fields are initialized when the wrapper is created and are
// read-only afterwards, and therefore can be accessed without a mutex. // read-only afterwards, and therefore can be accessed without a mutex.
cc *ClientConn cc *ClientConn
opts balancer.BuildOptions opts balancer.BuildOptions
serializer *grpcsync.CallbackSerializer
serializerCancel context.CancelFunc
// Outgoing (gRPC --> balancer) calls are guaranteed to execute in a // The following fields are only accessed within the serializer or during
// mutually exclusive manner as they are scheduled in the serializer. Fields // initialization.
// accessed *only* in these serializer callbacks, can therefore be accessed
// without a mutex.
balancer *gracefulswitch.Balancer
curBalancerName string curBalancerName string
balancer *gracefulswitch.Balancer
// mu guards access to the below fields. Access to the serializer and its // The following field is protected by mu. Caller must take cc.mu before
// cancel function needs to be mutex protected because they are overwritten // taking mu.
// when the wrapper exits idle mode. mu sync.Mutex
mu sync.Mutex closed bool
serializer *grpcsync.CallbackSerializer // To serialize all outoing calls.
serializerCancel context.CancelFunc // To close the seralizer at close/enterIdle time.
mode ccbMode // Tracks the current mode of the wrapper.
} }
// newCCBalancerWrapper creates a new balancer wrapper. The underlying balancer // newCCBalancerWrapper creates a new balancer wrapper in idle state. The
// is not created until the switchTo() method is invoked. // underlying balancer is not created until the switchTo() method is invoked.
func newCCBalancerWrapper(cc *ClientConn, bopts balancer.BuildOptions) *ccBalancerWrapper { func newCCBalancerWrapper(cc *ClientConn) *ccBalancerWrapper {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(cc.ctx)
ccb := &ccBalancerWrapper{ ccb := &ccBalancerWrapper{
cc: cc, cc: cc,
opts: bopts, opts: balancer.BuildOptions{
DialCreds: cc.dopts.copts.TransportCredentials,
CredsBundle: cc.dopts.copts.CredsBundle,
Dialer: cc.dopts.copts.Dialer,
Authority: cc.authority,
CustomUserAgent: cc.dopts.copts.UserAgent,
ChannelzParentID: cc.channelzID,
Target: cc.parsedTarget,
},
serializer: grpcsync.NewCallbackSerializer(ctx), serializer: grpcsync.NewCallbackSerializer(ctx),
serializerCancel: cancel, serializerCancel: cancel,
} }
ccb.balancer = gracefulswitch.NewBalancer(ccb, bopts) ccb.balancer = gracefulswitch.NewBalancer(ccb, ccb.opts)
return ccb return ccb
} }
// updateClientConnState is invoked by grpc to push a ClientConnState update to // updateClientConnState is invoked by grpc to push a ClientConnState update to
// the underlying balancer. // the underlying balancer. This is always executed from the serializer, so
// it is safe to call into the balancer here.
func (ccb *ccBalancerWrapper) updateClientConnState(ccs *balancer.ClientConnState) error { func (ccb *ccBalancerWrapper) updateClientConnState(ccs *balancer.ClientConnState) error {
ccb.mu.Lock() errCh := make(chan error)
errCh := make(chan error, 1) ok := ccb.serializer.Schedule(func(ctx context.Context) {
// Here and everywhere else where Schedule() is called, it is done with the defer close(errCh)
// lock held. But the lock guards only the scheduling part. The actual if ctx.Err() != nil || ccb.balancer == nil {
// callback is called asynchronously without the lock being held. return
ok := ccb.serializer.Schedule(func(_ context.Context) { }
errCh <- ccb.balancer.UpdateClientConnState(*ccs) err := ccb.balancer.UpdateClientConnState(*ccs)
if logger.V(2) && err != nil {
logger.Infof("error from balancer.UpdateClientConnState: %v", err)
}
errCh <- err
}) })
if !ok { if !ok {
// If we are unable to schedule a function with the serializer, it return nil
// indicates that it has been closed. A serializer is only closed when
// the wrapper is closed or is in idle.
ccb.mu.Unlock()
return fmt.Errorf("grpc: cannot send state update to a closed or idle balancer")
} }
ccb.mu.Unlock() return <-errCh
// We get here only if the above call to Schedule succeeds, in which case it
// is guaranteed that the scheduled function will run. Therefore it is safe
// to block on this channel.
err := <-errCh
if logger.V(2) && err != nil {
logger.Infof("error from balancer.UpdateClientConnState: %v", err)
}
return err
}
// updateSubConnState is invoked by grpc to push a subConn state update to the
// underlying balancer.
func (ccb *ccBalancerWrapper) updateSubConnState(sc balancer.SubConn, s connectivity.State, err error) {
ccb.mu.Lock()
ccb.serializer.Schedule(func(_ context.Context) {
// Even though it is optional for balancers, gracefulswitch ensures
// opts.StateListener is set, so this cannot ever be nil.
sc.(*acBalancerWrapper).stateListener(balancer.SubConnState{ConnectivityState: s, ConnectionError: err})
})
ccb.mu.Unlock()
} }
// resolverError is invoked by grpc to push a resolver error to the underlying
// balancer. The call to the balancer is executed from the serializer.
func (ccb *ccBalancerWrapper) resolverError(err error) { func (ccb *ccBalancerWrapper) resolverError(err error) {
ccb.mu.Lock() ccb.serializer.Schedule(func(ctx context.Context) {
ccb.serializer.Schedule(func(_ context.Context) { if ctx.Err() != nil || ccb.balancer == nil {
return
}
ccb.balancer.ResolverError(err) ccb.balancer.ResolverError(err)
}) })
ccb.mu.Unlock()
} }
// switchTo is invoked by grpc to instruct the balancer wrapper to switch to the // switchTo is invoked by grpc to instruct the balancer wrapper to switch to the
@ -151,8 +131,10 @@ func (ccb *ccBalancerWrapper) resolverError(err error) {
// the ccBalancerWrapper keeps track of the current LB policy name, and skips // the ccBalancerWrapper keeps track of the current LB policy name, and skips
// the graceful balancer switching process if the name does not change. // the graceful balancer switching process if the name does not change.
func (ccb *ccBalancerWrapper) switchTo(name string) { func (ccb *ccBalancerWrapper) switchTo(name string) {
ccb.mu.Lock() ccb.serializer.Schedule(func(ctx context.Context) {
ccb.serializer.Schedule(func(_ context.Context) { if ctx.Err() != nil || ccb.balancer == nil {
return
}
// TODO: Other languages use case-sensitive balancer registries. We should // TODO: Other languages use case-sensitive balancer registries. We should
// switch as well. See: https://github.com/grpc/grpc-go/issues/5288. // switch as well. See: https://github.com/grpc/grpc-go/issues/5288.
if strings.EqualFold(ccb.curBalancerName, name) { if strings.EqualFold(ccb.curBalancerName, name) {
@ -160,7 +142,6 @@ func (ccb *ccBalancerWrapper) switchTo(name string) {
} }
ccb.buildLoadBalancingPolicy(name) ccb.buildLoadBalancingPolicy(name)
}) })
ccb.mu.Unlock()
} }
// buildLoadBalancingPolicy performs the following: // buildLoadBalancingPolicy performs the following:
@ -187,115 +168,49 @@ func (ccb *ccBalancerWrapper) buildLoadBalancingPolicy(name string) {
ccb.curBalancerName = builder.Name() ccb.curBalancerName = builder.Name()
} }
// close initiates async shutdown of the wrapper. cc.mu must be held when
// calling this function. To determine the wrapper has finished shutting down,
// the channel should block on ccb.serializer.Done() without cc.mu held.
func (ccb *ccBalancerWrapper) close() { func (ccb *ccBalancerWrapper) close() {
channelz.Info(logger, ccb.cc.channelzID, "ccBalancerWrapper: closing")
ccb.closeBalancer(ccbModeClosed)
}
// enterIdleMode is invoked by grpc when the channel enters idle mode upon
// expiry of idle_timeout. This call blocks until the balancer is closed.
func (ccb *ccBalancerWrapper) enterIdleMode() {
channelz.Info(logger, ccb.cc.channelzID, "ccBalancerWrapper: entering idle mode")
ccb.closeBalancer(ccbModeIdle)
}
// closeBalancer is invoked when the channel is being closed or when it enters
// idle mode upon expiry of idle_timeout.
func (ccb *ccBalancerWrapper) closeBalancer(m ccbMode) {
ccb.mu.Lock() ccb.mu.Lock()
if ccb.mode == ccbModeClosed || ccb.mode == ccbModeIdle { ccb.closed = true
ccb.mu.Unlock()
return
}
ccb.mode = m
done := ccb.serializer.Done()
b := ccb.balancer
ok := ccb.serializer.Schedule(func(_ context.Context) {
// Close the serializer to ensure that no more calls from gRPC are sent
// to the balancer.
ccb.serializerCancel()
// Empty the current balancer name because we don't have a balancer
// anymore and also so that we act on the next call to switchTo by
// creating a new balancer specified by the new resolver.
ccb.curBalancerName = ""
})
if !ok {
ccb.mu.Unlock()
return
}
ccb.mu.Unlock() ccb.mu.Unlock()
channelz.Info(logger, ccb.cc.channelzID, "ccBalancerWrapper: closing")
// Give enqueued callbacks a chance to finish before closing the balancer. ccb.serializer.Schedule(func(context.Context) {
<-done if ccb.balancer == nil {
b.Close()
}
// exitIdleMode is invoked by grpc when the channel exits idle mode either
// because of an RPC or because of an invocation of the Connect() API. This
// recreates the balancer that was closed previously when entering idle mode.
//
// If the channel is not in idle mode, we know for a fact that we are here as a
// result of the user calling the Connect() method on the ClientConn. In this
// case, we can simply forward the call to the underlying balancer, instructing
// it to reconnect to the backends.
func (ccb *ccBalancerWrapper) exitIdleMode() {
ccb.mu.Lock()
if ccb.mode == ccbModeClosed {
// Request to exit idle is a no-op when wrapper is already closed.
ccb.mu.Unlock()
return
}
if ccb.mode == ccbModeIdle {
// Recreate the serializer which was closed when we entered idle.
ctx, cancel := context.WithCancel(context.Background())
ccb.serializer = grpcsync.NewCallbackSerializer(ctx)
ccb.serializerCancel = cancel
}
// The ClientConn guarantees that mutual exclusion between close() and
// exitIdleMode(), and since we just created a new serializer, we can be
// sure that the below function will be scheduled.
done := make(chan struct{})
ccb.serializer.Schedule(func(_ context.Context) {
defer close(done)
ccb.mu.Lock()
defer ccb.mu.Unlock()
if ccb.mode != ccbModeIdle {
ccb.balancer.ExitIdle()
return return
} }
ccb.balancer.Close()
// Gracefulswitch balancer does not support a switchTo operation after ccb.balancer = nil
// being closed. Hence we need to create a new one here.
ccb.balancer = gracefulswitch.NewBalancer(ccb, ccb.opts)
ccb.mode = ccbModeActive
channelz.Info(logger, ccb.cc.channelzID, "ccBalancerWrapper: exiting idle mode")
}) })
ccb.mu.Unlock() ccb.serializerCancel()
<-done
} }
func (ccb *ccBalancerWrapper) isIdleOrClosed() bool { // exitIdle invokes the balancer's exitIdle method in the serializer.
ccb.mu.Lock() func (ccb *ccBalancerWrapper) exitIdle() {
defer ccb.mu.Unlock() ccb.serializer.Schedule(func(ctx context.Context) {
return ccb.mode == ccbModeIdle || ccb.mode == ccbModeClosed if ctx.Err() != nil || ccb.balancer == nil {
return
}
ccb.balancer.ExitIdle()
})
} }
func (ccb *ccBalancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) { func (ccb *ccBalancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) {
if ccb.isIdleOrClosed() { ccb.cc.mu.Lock()
return nil, fmt.Errorf("grpc: cannot create SubConn when balancer is closed or idle") defer ccb.cc.mu.Unlock()
ccb.mu.Lock()
if ccb.closed {
ccb.mu.Unlock()
return nil, fmt.Errorf("balancer is being closed; no new SubConns allowed")
} }
ccb.mu.Unlock()
if len(addrs) == 0 { if len(addrs) == 0 {
return nil, fmt.Errorf("grpc: cannot create SubConn with empty address list") return nil, fmt.Errorf("grpc: cannot create SubConn with empty address list")
} }
ac, err := ccb.cc.newAddrConn(addrs, opts) ac, err := ccb.cc.newAddrConnLocked(addrs, opts)
if err != nil { if err != nil {
channelz.Warningf(logger, ccb.cc.channelzID, "acBalancerWrapper: NewSubConn: failed to newAddrConn: %v", err) channelz.Warningf(logger, ccb.cc.channelzID, "acBalancerWrapper: NewSubConn: failed to newAddrConn: %v", err)
return nil, err return nil, err
@ -316,10 +231,6 @@ func (ccb *ccBalancerWrapper) RemoveSubConn(sc balancer.SubConn) {
} }
func (ccb *ccBalancerWrapper) UpdateAddresses(sc balancer.SubConn, addrs []resolver.Address) { func (ccb *ccBalancerWrapper) UpdateAddresses(sc balancer.SubConn, addrs []resolver.Address) {
if ccb.isIdleOrClosed() {
return
}
acbw, ok := sc.(*acBalancerWrapper) acbw, ok := sc.(*acBalancerWrapper)
if !ok { if !ok {
return return
@ -328,25 +239,39 @@ func (ccb *ccBalancerWrapper) UpdateAddresses(sc balancer.SubConn, addrs []resol
} }
func (ccb *ccBalancerWrapper) UpdateState(s balancer.State) { func (ccb *ccBalancerWrapper) UpdateState(s balancer.State) {
if ccb.isIdleOrClosed() { ccb.cc.mu.Lock()
defer ccb.cc.mu.Unlock()
ccb.mu.Lock()
if ccb.closed {
ccb.mu.Unlock()
return return
} }
ccb.mu.Unlock()
// Update picker before updating state. Even though the ordering here does // Update picker before updating state. Even though the ordering here does
// not matter, it can lead to multiple calls of Pick in the common start-up // not matter, it can lead to multiple calls of Pick in the common start-up
// case where we wait for ready and then perform an RPC. If the picker is // case where we wait for ready and then perform an RPC. If the picker is
// updated later, we could call the "connecting" picker when the state is // updated later, we could call the "connecting" picker when the state is
// updated, and then call the "ready" picker after the picker gets updated. // updated, and then call the "ready" picker after the picker gets updated.
ccb.cc.blockingpicker.updatePicker(s.Picker)
// Note that there is no need to check if the balancer wrapper was closed,
// as we know the graceful switch LB policy will not call cc if it has been
// closed.
ccb.cc.pickerWrapper.updatePicker(s.Picker)
ccb.cc.csMgr.updateState(s.ConnectivityState) ccb.cc.csMgr.updateState(s.ConnectivityState)
} }
func (ccb *ccBalancerWrapper) ResolveNow(o resolver.ResolveNowOptions) { func (ccb *ccBalancerWrapper) ResolveNow(o resolver.ResolveNowOptions) {
if ccb.isIdleOrClosed() { ccb.cc.mu.RLock()
defer ccb.cc.mu.RUnlock()
ccb.mu.Lock()
if ccb.closed {
ccb.mu.Unlock()
return return
} }
ccb.mu.Unlock()
ccb.cc.resolveNow(o) ccb.cc.resolveNowLocked(o)
} }
func (ccb *ccBalancerWrapper) Target() string { func (ccb *ccBalancerWrapper) Target() string {
@ -364,6 +289,20 @@ type acBalancerWrapper struct {
producers map[balancer.ProducerBuilder]*refCountedProducer producers map[balancer.ProducerBuilder]*refCountedProducer
} }
// updateState is invoked by grpc to push a subConn state update to the
// underlying balancer.
func (acbw *acBalancerWrapper) updateState(s connectivity.State, err error) {
acbw.ccb.serializer.Schedule(func(ctx context.Context) {
if ctx.Err() != nil || acbw.ccb.balancer == nil {
return
}
// Even though it is optional for balancers, gracefulswitch ensures
// opts.StateListener is set, so this cannot ever be nil.
// TODO: delete this comment when UpdateSubConnState is removed.
acbw.stateListener(balancer.SubConnState{ConnectivityState: s, ConnectionError: err})
})
}
func (acbw *acBalancerWrapper) String() string { func (acbw *acBalancerWrapper) String() string {
return fmt.Sprintf("SubConn(id:%d)", acbw.ac.channelzID.Int()) return fmt.Sprintf("SubConn(id:%d)", acbw.ac.channelzID.Int())
} }
@ -377,20 +316,7 @@ func (acbw *acBalancerWrapper) Connect() {
} }
func (acbw *acBalancerWrapper) Shutdown() { func (acbw *acBalancerWrapper) Shutdown() {
ccb := acbw.ccb acbw.ccb.cc.removeAddrConn(acbw.ac, errConnDrain)
if ccb.isIdleOrClosed() {
// It it safe to ignore this call when the balancer is closed or in idle
// because the ClientConn takes care of closing the connections.
//
// Not returning early from here when the balancer is closed or in idle
// leads to a deadlock though, because of the following sequence of
// calls when holding cc.mu:
// cc.exitIdleMode --> ccb.enterIdleMode --> gsw.Close -->
// ccb.RemoveAddrConn --> cc.removeAddrConn
return
}
ccb.cc.removeAddrConn(acbw.ac, errConnDrain)
} }
// NewStream begins a streaming RPC on the addrConn. If the addrConn is not // NewStream begins a streaming RPC on the addrConn. If the addrConn is not

View File

@ -33,9 +33,7 @@ import (
"google.golang.org/grpc/balancer/base" "google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity" "google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal" "google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/backoff"
"google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/idle" "google.golang.org/grpc/internal/idle"
@ -48,9 +46,9 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
_ "google.golang.org/grpc/balancer/roundrobin" // To register roundrobin. _ "google.golang.org/grpc/balancer/roundrobin" // To register roundrobin.
_ "google.golang.org/grpc/internal/resolver/dns" // To register dns resolver.
_ "google.golang.org/grpc/internal/resolver/passthrough" // To register passthrough resolver. _ "google.golang.org/grpc/internal/resolver/passthrough" // To register passthrough resolver.
_ "google.golang.org/grpc/internal/resolver/unix" // To register unix resolver. _ "google.golang.org/grpc/internal/resolver/unix" // To register unix resolver.
_ "google.golang.org/grpc/resolver/dns" // To register dns resolver.
) )
const ( const (
@ -119,23 +117,8 @@ func (dcs *defaultConfigSelector) SelectConfig(rpcInfo iresolver.RPCInfo) (*ires
}, nil }, nil
} }
// DialContext creates a client connection to the given target. By default, it's // newClient returns a new client in idle mode.
// a non-blocking dial (the function won't wait for connections to be func newClient(target string, opts ...DialOption) (conn *ClientConn, err error) {
// established, and connecting happens in the background). To make it a blocking
// dial, use WithBlock() dial option.
//
// In the non-blocking case, the ctx does not act against the connection. It
// only controls the setup steps.
//
// In the blocking case, ctx can be used to cancel or expire the pending
// connection. Once this function returns, the cancellation and expiration of
// ctx will be noop. Users should call ClientConn.Close to terminate all the
// pending operations after this function returns.
//
// The target name syntax is defined in
// https://github.com/grpc/grpc/blob/master/doc/naming.md.
// e.g. to use dns resolver, a "dns:///" prefix should be applied to the target.
func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) {
cc := &ClientConn{ cc := &ClientConn{
target: target, target: target,
conns: make(map[*addrConn]struct{}), conns: make(map[*addrConn]struct{}),
@ -143,23 +126,11 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
czData: new(channelzData), czData: new(channelzData),
} }
// We start the channel off in idle mode, but kick it out of idle at the end
// of this method, instead of waiting for the first RPC. Other gRPC
// implementations do wait for the first RPC to kick the channel out of
// idle. But doing so would be a major behavior change for our users who are
// used to seeing the channel active after Dial.
//
// Taking this approach of kicking it out of idle at the end of this method
// allows us to share the code between channel creation and exiting idle
// mode. This will also make it easy for us to switch to starting the
// channel off in idle, if at all we ever get to do that.
cc.idlenessState = ccIdlenessStateIdle
cc.retryThrottler.Store((*retryThrottler)(nil)) cc.retryThrottler.Store((*retryThrottler)(nil))
cc.safeConfigSelector.UpdateConfigSelector(&defaultConfigSelector{nil}) cc.safeConfigSelector.UpdateConfigSelector(&defaultConfigSelector{nil})
cc.ctx, cc.cancel = context.WithCancel(context.Background()) cc.ctx, cc.cancel = context.WithCancel(context.Background())
cc.exitIdleCond = sync.NewCond(&cc.mu)
// Apply dial options.
disableGlobalOpts := false disableGlobalOpts := false
for _, opt := range opts { for _, opt := range opts {
if _, ok := opt.(*disableGlobalDialOptions); ok { if _, ok := opt.(*disableGlobalDialOptions); ok {
@ -177,21 +148,9 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
for _, opt := range opts { for _, opt := range opts {
opt.apply(&cc.dopts) opt.apply(&cc.dopts)
} }
chainUnaryClientInterceptors(cc) chainUnaryClientInterceptors(cc)
chainStreamClientInterceptors(cc) chainStreamClientInterceptors(cc)
defer func() {
if err != nil {
cc.Close()
}
}()
// Register ClientConn with channelz.
cc.channelzRegistration(target)
cc.csMgr = newConnectivityStateManager(cc.ctx, cc.channelzID)
if err := cc.validateTransportCredentials(); err != nil { if err := cc.validateTransportCredentials(); err != nil {
return nil, err return nil, err
} }
@ -205,10 +164,80 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
} }
cc.mkp = cc.dopts.copts.KeepaliveParams cc.mkp = cc.dopts.copts.KeepaliveParams
if cc.dopts.copts.UserAgent != "" { // Register ClientConn with channelz.
cc.dopts.copts.UserAgent += " " + grpcUA cc.channelzRegistration(target)
} else {
cc.dopts.copts.UserAgent = grpcUA // TODO: Ideally it should be impossible to error from this function after
// channelz registration. This will require removing some channelz logs
// from the following functions that can error. Errors can be returned to
// the user, and successful logs can be emitted here, after the checks have
// passed and channelz is subsequently registered.
// Determine the resolver to use.
if err := cc.parseTargetAndFindResolver(); err != nil {
channelz.RemoveEntry(cc.channelzID)
return nil, err
}
if err = cc.determineAuthority(); err != nil {
channelz.RemoveEntry(cc.channelzID)
return nil, err
}
cc.csMgr = newConnectivityStateManager(cc.ctx, cc.channelzID)
cc.pickerWrapper = newPickerWrapper(cc.dopts.copts.StatsHandlers)
cc.initIdleStateLocked() // Safe to call without the lock, since nothing else has a reference to cc.
cc.idlenessMgr = idle.NewManager((*idler)(cc), cc.dopts.idleTimeout)
return cc, nil
}
// DialContext creates a client connection to the given target. By default, it's
// a non-blocking dial (the function won't wait for connections to be
// established, and connecting happens in the background). To make it a blocking
// dial, use WithBlock() dial option.
//
// In the non-blocking case, the ctx does not act against the connection. It
// only controls the setup steps.
//
// In the blocking case, ctx can be used to cancel or expire the pending
// connection. Once this function returns, the cancellation and expiration of
// ctx will be noop. Users should call ClientConn.Close to terminate all the
// pending operations after this function returns.
//
// The target name syntax is defined in
// https://github.com/grpc/grpc/blob/master/doc/naming.md.
// e.g. to use dns resolver, a "dns:///" prefix should be applied to the target.
func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) {
cc, err := newClient(target, opts...)
if err != nil {
return nil, err
}
// We start the channel off in idle mode, but kick it out of idle now,
// instead of waiting for the first RPC. Other gRPC implementations do wait
// for the first RPC to kick the channel out of idle. But doing so would be
// a major behavior change for our users who are used to seeing the channel
// active after Dial.
//
// Taking this approach of kicking it out of idle at the end of this method
// allows us to share the code between channel creation and exiting idle
// mode. This will also make it easy for us to switch to starting the
// channel off in idle, i.e. by making newClient exported.
defer func() {
if err != nil {
cc.Close()
}
}()
// This creates the name resolver, load balancer, etc.
if err := cc.idlenessMgr.ExitIdleMode(); err != nil {
return nil, err
}
// Return now for non-blocking dials.
if !cc.dopts.block {
return cc, nil
} }
if cc.dopts.timeout > 0 { if cc.dopts.timeout > 0 {
@ -231,49 +260,6 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
} }
}() }()
if cc.dopts.bs == nil {
cc.dopts.bs = backoff.DefaultExponential
}
// Determine the resolver to use.
if err := cc.parseTargetAndFindResolver(); err != nil {
return nil, err
}
if err = cc.determineAuthority(); err != nil {
return nil, err
}
if cc.dopts.scChan != nil {
// Blocking wait for the initial service config.
select {
case sc, ok := <-cc.dopts.scChan:
if ok {
cc.sc = &sc
cc.safeConfigSelector.UpdateConfigSelector(&defaultConfigSelector{&sc})
}
case <-ctx.Done():
return nil, ctx.Err()
}
}
if cc.dopts.scChan != nil {
go cc.scWatcher()
}
// This creates the name resolver, load balancer, blocking picker etc.
if err := cc.exitIdleMode(); err != nil {
return nil, err
}
// Configure idleness support with configured idle timeout or default idle
// timeout duration. Idleness can be explicitly disabled by the user, by
// setting the dial option to 0.
cc.idlenessMgr = idle.NewManager(idle.ManagerOptions{Enforcer: (*idler)(cc), Timeout: cc.dopts.idleTimeout, Logger: logger})
// Return early for non-blocking dials.
if !cc.dopts.block {
return cc, nil
}
// A blocking dial blocks until the clientConn is ready. // A blocking dial blocks until the clientConn is ready.
for { for {
s := cc.GetState() s := cc.GetState()
@ -320,8 +306,8 @@ func (cc *ClientConn) addTraceEvent(msg string) {
type idler ClientConn type idler ClientConn
func (i *idler) EnterIdleMode() error { func (i *idler) EnterIdleMode() {
return (*ClientConn)(i).enterIdleMode() (*ClientConn)(i).enterIdleMode()
} }
func (i *idler) ExitIdleMode() error { func (i *idler) ExitIdleMode() error {
@ -329,117 +315,71 @@ func (i *idler) ExitIdleMode() error {
} }
// exitIdleMode moves the channel out of idle mode by recreating the name // exitIdleMode moves the channel out of idle mode by recreating the name
// resolver and load balancer. // resolver and load balancer. This should never be called directly; use
func (cc *ClientConn) exitIdleMode() error { // cc.idlenessMgr.ExitIdleMode instead.
func (cc *ClientConn) exitIdleMode() (err error) {
cc.mu.Lock() cc.mu.Lock()
if cc.conns == nil { if cc.conns == nil {
cc.mu.Unlock() cc.mu.Unlock()
return errConnClosing return errConnClosing
} }
if cc.idlenessState != ccIdlenessStateIdle {
channelz.Infof(logger, cc.channelzID, "ClientConn asked to exit idle mode, current mode is %v", cc.idlenessState)
cc.mu.Unlock()
return nil
}
defer func() {
// When Close() and exitIdleMode() race against each other, one of the
// following two can happen:
// - Close() wins the race and runs first. exitIdleMode() runs after, and
// sees that the ClientConn is already closed and hence returns early.
// - exitIdleMode() wins the race and runs first and recreates the balancer
// and releases the lock before recreating the resolver. If Close() runs
// in this window, it will wait for exitIdleMode to complete.
//
// We achieve this synchronization using the below condition variable.
cc.mu.Lock()
cc.idlenessState = ccIdlenessStateActive
cc.exitIdleCond.Signal()
cc.mu.Unlock()
}()
cc.idlenessState = ccIdlenessStateExitingIdle
exitedIdle := false
if cc.blockingpicker == nil {
cc.blockingpicker = newPickerWrapper(cc.dopts.copts.StatsHandlers)
} else {
cc.blockingpicker.exitIdleMode()
exitedIdle = true
}
var credsClone credentials.TransportCredentials
if creds := cc.dopts.copts.TransportCredentials; creds != nil {
credsClone = creds.Clone()
}
if cc.balancerWrapper == nil {
cc.balancerWrapper = newCCBalancerWrapper(cc, balancer.BuildOptions{
DialCreds: credsClone,
CredsBundle: cc.dopts.copts.CredsBundle,
Dialer: cc.dopts.copts.Dialer,
Authority: cc.authority,
CustomUserAgent: cc.dopts.copts.UserAgent,
ChannelzParentID: cc.channelzID,
Target: cc.parsedTarget,
})
} else {
cc.balancerWrapper.exitIdleMode()
}
cc.firstResolveEvent = grpcsync.NewEvent()
cc.mu.Unlock() cc.mu.Unlock()
// This needs to be called without cc.mu because this builds a new resolver // This needs to be called without cc.mu because this builds a new resolver
// which might update state or report error inline which needs to be handled // which might update state or report error inline, which would then need to
// by cc.updateResolverState() which also grabs cc.mu. // acquire cc.mu.
if err := cc.initResolverWrapper(credsClone); err != nil { if err := cc.resolverWrapper.start(); err != nil {
return err return err
} }
if exitedIdle { cc.addTraceEvent("exiting idle mode")
cc.addTraceEvent("exiting idle mode")
}
return nil return nil
} }
// enterIdleMode puts the channel in idle mode, and as part of it shuts down the // initIdleStateLocked initializes common state to how it should be while idle.
// name resolver, load balancer and any subchannels. func (cc *ClientConn) initIdleStateLocked() {
func (cc *ClientConn) enterIdleMode() error { cc.resolverWrapper = newCCResolverWrapper(cc)
cc.mu.Lock() cc.balancerWrapper = newCCBalancerWrapper(cc)
defer cc.mu.Unlock() cc.firstResolveEvent = grpcsync.NewEvent()
if cc.conns == nil {
return ErrClientConnClosing
}
if cc.idlenessState != ccIdlenessStateActive {
channelz.Warningf(logger, cc.channelzID, "ClientConn asked to enter idle mode, current mode is %v", cc.idlenessState)
return nil
}
// cc.conns == nil is a proxy for the ClientConn being closed. So, instead // cc.conns == nil is a proxy for the ClientConn being closed. So, instead
// of setting it to nil here, we recreate the map. This also means that we // of setting it to nil here, we recreate the map. This also means that we
// don't have to do this when exiting idle mode. // don't have to do this when exiting idle mode.
conns := cc.conns
cc.conns = make(map[*addrConn]struct{}) cc.conns = make(map[*addrConn]struct{})
}
// TODO: Currently, we close the resolver wrapper upon entering idle mode // enterIdleMode puts the channel in idle mode, and as part of it shuts down the
// and create a new one upon exiting idle mode. This means that the // name resolver, load balancer, and any subchannels. This should never be
// `cc.resolverWrapper` field would be overwritten everytime we exit idle // called directly; use cc.idlenessMgr.EnterIdleMode instead.
// mode. While this means that we need to hold `cc.mu` when accessing func (cc *ClientConn) enterIdleMode() {
// `cc.resolverWrapper`, it makes the code simpler in the wrapper. We should cc.mu.Lock()
// try to do the same for the balancer and picker wrappers too.
cc.resolverWrapper.close() if cc.conns == nil {
cc.blockingpicker.enterIdleMode() cc.mu.Unlock()
cc.balancerWrapper.enterIdleMode() return
}
conns := cc.conns
rWrapper := cc.resolverWrapper
rWrapper.close()
cc.pickerWrapper.reset()
bWrapper := cc.balancerWrapper
bWrapper.close()
cc.csMgr.updateState(connectivity.Idle) cc.csMgr.updateState(connectivity.Idle)
cc.idlenessState = ccIdlenessStateIdle
cc.addTraceEvent("entering idle mode") cc.addTraceEvent("entering idle mode")
go func() { cc.initIdleStateLocked()
for ac := range conns {
ac.tearDown(errConnIdling)
}
}()
return nil cc.mu.Unlock()
// Block until the name resolver and LB policy are closed.
<-rWrapper.serializer.Done()
<-bWrapper.serializer.Done()
// Close all subchannels after the LB policy is closed.
for ac := range conns {
ac.tearDown(errConnIdling)
}
} }
// validateTransportCredentials performs a series of checks on the configured // validateTransportCredentials performs a series of checks on the configured
@ -649,66 +589,35 @@ type ClientConn struct {
dopts dialOptions // Default and user specified dial options. dopts dialOptions // Default and user specified dial options.
channelzID *channelz.Identifier // Channelz identifier for the channel. channelzID *channelz.Identifier // Channelz identifier for the channel.
resolverBuilder resolver.Builder // See parseTargetAndFindResolver(). resolverBuilder resolver.Builder // See parseTargetAndFindResolver().
balancerWrapper *ccBalancerWrapper // Uses gracefulswitch.balancer underneath. idlenessMgr *idle.Manager
idlenessMgr idle.Manager
// The following provide their own synchronization, and therefore don't // The following provide their own synchronization, and therefore don't
// require cc.mu to be held to access them. // require cc.mu to be held to access them.
csMgr *connectivityStateManager csMgr *connectivityStateManager
blockingpicker *pickerWrapper pickerWrapper *pickerWrapper
safeConfigSelector iresolver.SafeConfigSelector safeConfigSelector iresolver.SafeConfigSelector
czData *channelzData czData *channelzData
retryThrottler atomic.Value // Updated from service config. retryThrottler atomic.Value // Updated from service config.
// firstResolveEvent is used to track whether the name resolver sent us at
// least one update. RPCs block on this event.
firstResolveEvent *grpcsync.Event
// mu protects the following fields. // mu protects the following fields.
// TODO: split mu so the same mutex isn't used for everything. // TODO: split mu so the same mutex isn't used for everything.
mu sync.RWMutex mu sync.RWMutex
resolverWrapper *ccResolverWrapper // Initialized in Dial; cleared in Close. resolverWrapper *ccResolverWrapper // Always recreated whenever entering idle to simplify Close.
balancerWrapper *ccBalancerWrapper // Always recreated whenever entering idle to simplify Close.
sc *ServiceConfig // Latest service config received from the resolver. sc *ServiceConfig // Latest service config received from the resolver.
conns map[*addrConn]struct{} // Set to nil on close. conns map[*addrConn]struct{} // Set to nil on close.
mkp keepalive.ClientParameters // May be updated upon receipt of a GoAway. mkp keepalive.ClientParameters // May be updated upon receipt of a GoAway.
idlenessState ccIdlenessState // Tracks idleness state of the channel. // firstResolveEvent is used to track whether the name resolver sent us at
exitIdleCond *sync.Cond // Signalled when channel exits idle. // least one update. RPCs block on this event. May be accessed without mu
// if we know we cannot be asked to enter idle mode while accessing it (e.g.
// when the idle manager has already been closed, or if we are already
// entering idle mode).
firstResolveEvent *grpcsync.Event
lceMu sync.Mutex // protects lastConnectionError lceMu sync.Mutex // protects lastConnectionError
lastConnectionError error lastConnectionError error
} }
// ccIdlenessState tracks the idleness state of the channel.
//
// Channels start off in `active` and move to `idle` after a period of
// inactivity. When moving back to `active` upon an incoming RPC, they
// transition through `exiting_idle`. This state is useful for synchronization
// with Close().
//
// This state tracking is mostly for self-protection. The idlenessManager is
// expected to keep track of the state as well, and is expected not to call into
// the ClientConn unnecessarily.
type ccIdlenessState int8
const (
ccIdlenessStateActive ccIdlenessState = iota
ccIdlenessStateIdle
ccIdlenessStateExitingIdle
)
func (s ccIdlenessState) String() string {
switch s {
case ccIdlenessStateActive:
return "active"
case ccIdlenessStateIdle:
return "idle"
case ccIdlenessStateExitingIdle:
return "exitingIdle"
default:
return "unknown"
}
}
// WaitForStateChange waits until the connectivity.State of ClientConn changes from sourceState or // WaitForStateChange waits until the connectivity.State of ClientConn changes from sourceState or
// ctx expires. A true value is returned in former case and false in latter. // ctx expires. A true value is returned in former case and false in latter.
// //
@ -748,29 +657,15 @@ func (cc *ClientConn) GetState() connectivity.State {
// Notice: This API is EXPERIMENTAL and may be changed or removed in a later // Notice: This API is EXPERIMENTAL and may be changed or removed in a later
// release. // release.
func (cc *ClientConn) Connect() { func (cc *ClientConn) Connect() {
cc.exitIdleMode() if err := cc.idlenessMgr.ExitIdleMode(); err != nil {
cc.addTraceEvent(err.Error())
return
}
// If the ClientConn was not in idle mode, we need to call ExitIdle on the // If the ClientConn was not in idle mode, we need to call ExitIdle on the
// LB policy so that connections can be created. // LB policy so that connections can be created.
cc.balancerWrapper.exitIdleMode() cc.mu.Lock()
} cc.balancerWrapper.exitIdle()
cc.mu.Unlock()
func (cc *ClientConn) scWatcher() {
for {
select {
case sc, ok := <-cc.dopts.scChan:
if !ok {
return
}
cc.mu.Lock()
// TODO: load balance policy runtime change is ignored.
// We may revisit this decision in the future.
cc.sc = &sc
cc.safeConfigSelector.UpdateConfigSelector(&defaultConfigSelector{&sc})
cc.mu.Unlock()
case <-cc.ctx.Done():
return
}
}
} }
// waitForResolvedAddrs blocks until the resolver has provided addresses or the // waitForResolvedAddrs blocks until the resolver has provided addresses or the
@ -804,11 +699,11 @@ func init() {
internal.SubscribeToConnectivityStateChanges = func(cc *ClientConn, s grpcsync.Subscriber) func() { internal.SubscribeToConnectivityStateChanges = func(cc *ClientConn, s grpcsync.Subscriber) func() {
return cc.csMgr.pubSub.Subscribe(s) return cc.csMgr.pubSub.Subscribe(s)
} }
internal.EnterIdleModeForTesting = func(cc *ClientConn) error { internal.EnterIdleModeForTesting = func(cc *ClientConn) {
return cc.enterIdleMode() cc.idlenessMgr.EnterIdleModeForTesting()
} }
internal.ExitIdleModeForTesting = func(cc *ClientConn) error { internal.ExitIdleModeForTesting = func(cc *ClientConn) error {
return cc.exitIdleMode() return cc.idlenessMgr.ExitIdleMode()
} }
} }
@ -824,9 +719,8 @@ func (cc *ClientConn) maybeApplyDefaultServiceConfig(addrs []resolver.Address) {
} }
} }
func (cc *ClientConn) updateResolverState(s resolver.State, err error) error { func (cc *ClientConn) updateResolverStateAndUnlock(s resolver.State, err error) error {
defer cc.firstResolveEvent.Fire() defer cc.firstResolveEvent.Fire()
cc.mu.Lock()
// Check if the ClientConn is already closed. Some fields (e.g. // Check if the ClientConn is already closed. Some fields (e.g.
// balancerWrapper) are set to nil when closing the ClientConn, and could // balancerWrapper) are set to nil when closing the ClientConn, and could
// cause nil pointer panic if we don't have this check. // cause nil pointer panic if we don't have this check.
@ -872,7 +766,7 @@ func (cc *ClientConn) updateResolverState(s resolver.State, err error) error {
if cc.sc == nil { if cc.sc == nil {
// Apply the failing LB only if we haven't received valid service config // Apply the failing LB only if we haven't received valid service config
// from the name resolver in the past. // from the name resolver in the past.
cc.applyFailingLB(s.ServiceConfig) cc.applyFailingLBLocked(s.ServiceConfig)
cc.mu.Unlock() cc.mu.Unlock()
return ret return ret
} }
@ -894,15 +788,13 @@ func (cc *ClientConn) updateResolverState(s resolver.State, err error) error {
return ret return ret
} }
// applyFailingLB is akin to configuring an LB policy on the channel which // applyFailingLBLocked is akin to configuring an LB policy on the channel which
// always fails RPCs. Here, an actual LB policy is not configured, but an always // always fails RPCs. Here, an actual LB policy is not configured, but an always
// erroring picker is configured, which returns errors with information about // erroring picker is configured, which returns errors with information about
// what was invalid in the received service config. A config selector with no // what was invalid in the received service config. A config selector with no
// service config is configured, and the connectivity state of the channel is // service config is configured, and the connectivity state of the channel is
// set to TransientFailure. // set to TransientFailure.
// func (cc *ClientConn) applyFailingLBLocked(sc *serviceconfig.ParseResult) {
// Caller must hold cc.mu.
func (cc *ClientConn) applyFailingLB(sc *serviceconfig.ParseResult) {
var err error var err error
if sc.Err != nil { if sc.Err != nil {
err = status.Errorf(codes.Unavailable, "error parsing service config: %v", sc.Err) err = status.Errorf(codes.Unavailable, "error parsing service config: %v", sc.Err)
@ -910,14 +802,10 @@ func (cc *ClientConn) applyFailingLB(sc *serviceconfig.ParseResult) {
err = status.Errorf(codes.Unavailable, "illegal service config type: %T", sc.Config) err = status.Errorf(codes.Unavailable, "illegal service config type: %T", sc.Config)
} }
cc.safeConfigSelector.UpdateConfigSelector(&defaultConfigSelector{nil}) cc.safeConfigSelector.UpdateConfigSelector(&defaultConfigSelector{nil})
cc.blockingpicker.updatePicker(base.NewErrPicker(err)) cc.pickerWrapper.updatePicker(base.NewErrPicker(err))
cc.csMgr.updateState(connectivity.TransientFailure) cc.csMgr.updateState(connectivity.TransientFailure)
} }
func (cc *ClientConn) handleSubConnStateChange(sc balancer.SubConn, s connectivity.State, err error) {
cc.balancerWrapper.updateSubConnState(sc, s, err)
}
// Makes a copy of the input addresses slice and clears out the balancer // Makes a copy of the input addresses slice and clears out the balancer
// attributes field. Addresses are passed during subconn creation and address // attributes field. Addresses are passed during subconn creation and address
// update operations. In both cases, we will clear the balancer attributes by // update operations. In both cases, we will clear the balancer attributes by
@ -932,10 +820,14 @@ func copyAddressesWithoutBalancerAttributes(in []resolver.Address) []resolver.Ad
return out return out
} }
// newAddrConn creates an addrConn for addrs and adds it to cc.conns. // newAddrConnLocked creates an addrConn for addrs and adds it to cc.conns.
// //
// Caller needs to make sure len(addrs) > 0. // Caller needs to make sure len(addrs) > 0.
func (cc *ClientConn) newAddrConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (*addrConn, error) { func (cc *ClientConn) newAddrConnLocked(addrs []resolver.Address, opts balancer.NewSubConnOptions) (*addrConn, error) {
if cc.conns == nil {
return nil, ErrClientConnClosing
}
ac := &addrConn{ ac := &addrConn{
state: connectivity.Idle, state: connectivity.Idle,
cc: cc, cc: cc,
@ -947,12 +839,6 @@ func (cc *ClientConn) newAddrConn(addrs []resolver.Address, opts balancer.NewSub
stateChan: make(chan struct{}), stateChan: make(chan struct{}),
} }
ac.ctx, ac.cancel = context.WithCancel(cc.ctx) ac.ctx, ac.cancel = context.WithCancel(cc.ctx)
// Track ac in cc. This needs to be done before any getTransport(...) is called.
cc.mu.Lock()
defer cc.mu.Unlock()
if cc.conns == nil {
return nil, ErrClientConnClosing
}
var err error var err error
ac.channelzID, err = channelz.RegisterSubChannel(ac, cc.channelzID, "") ac.channelzID, err = channelz.RegisterSubChannel(ac, cc.channelzID, "")
@ -968,6 +854,7 @@ func (cc *ClientConn) newAddrConn(addrs []resolver.Address, opts balancer.NewSub
}, },
}) })
// Track ac in cc. This needs to be done before any getTransport(...) is called.
cc.conns[ac] = struct{}{} cc.conns[ac] = struct{}{}
return ac, nil return ac, nil
} }
@ -1174,7 +1061,7 @@ func (cc *ClientConn) healthCheckConfig() *healthCheckConfig {
} }
func (cc *ClientConn) getTransport(ctx context.Context, failfast bool, method string) (transport.ClientTransport, balancer.PickResult, error) { func (cc *ClientConn) getTransport(ctx context.Context, failfast bool, method string) (transport.ClientTransport, balancer.PickResult, error) {
return cc.blockingpicker.pick(ctx, failfast, balancer.PickInfo{ return cc.pickerWrapper.pick(ctx, failfast, balancer.PickInfo{
Ctx: ctx, Ctx: ctx,
FullMethodName: method, FullMethodName: method,
}) })
@ -1216,12 +1103,12 @@ func (cc *ClientConn) applyServiceConfigAndBalancer(sc *ServiceConfig, configSel
func (cc *ClientConn) resolveNow(o resolver.ResolveNowOptions) { func (cc *ClientConn) resolveNow(o resolver.ResolveNowOptions) {
cc.mu.RLock() cc.mu.RLock()
r := cc.resolverWrapper cc.resolverWrapper.resolveNow(o)
cc.mu.RUnlock() cc.mu.RUnlock()
if r == nil { }
return
} func (cc *ClientConn) resolveNowLocked(o resolver.ResolveNowOptions) {
go r.resolveNow(o) cc.resolverWrapper.resolveNow(o)
} }
// ResetConnectBackoff wakes up all subchannels in transient failure and causes // ResetConnectBackoff wakes up all subchannels in transient failure and causes
@ -1253,40 +1140,32 @@ func (cc *ClientConn) Close() error {
<-cc.csMgr.pubSub.Done() <-cc.csMgr.pubSub.Done()
}() }()
// Prevent calls to enter/exit idle immediately, and ensure we are not
// currently entering/exiting idle mode.
cc.idlenessMgr.Close()
cc.mu.Lock() cc.mu.Lock()
if cc.conns == nil { if cc.conns == nil {
cc.mu.Unlock() cc.mu.Unlock()
return ErrClientConnClosing return ErrClientConnClosing
} }
for cc.idlenessState == ccIdlenessStateExitingIdle {
cc.exitIdleCond.Wait()
}
conns := cc.conns conns := cc.conns
cc.conns = nil cc.conns = nil
cc.csMgr.updateState(connectivity.Shutdown) cc.csMgr.updateState(connectivity.Shutdown)
pWrapper := cc.blockingpicker // We can safely unlock and continue to access all fields now as
rWrapper := cc.resolverWrapper // cc.conns==nil, preventing any further operations on cc.
bWrapper := cc.balancerWrapper
idlenessMgr := cc.idlenessMgr
cc.mu.Unlock() cc.mu.Unlock()
cc.resolverWrapper.close()
// The order of closing matters here since the balancer wrapper assumes the // The order of closing matters here since the balancer wrapper assumes the
// picker is closed before it is closed. // picker is closed before it is closed.
if pWrapper != nil { cc.pickerWrapper.close()
pWrapper.close() cc.balancerWrapper.close()
}
if bWrapper != nil { <-cc.resolverWrapper.serializer.Done()
bWrapper.close() <-cc.balancerWrapper.serializer.Done()
}
if rWrapper != nil {
rWrapper.close()
}
if idlenessMgr != nil {
idlenessMgr.Close()
}
for ac := range conns { for ac := range conns {
ac.tearDown(ErrClientConnClosing) ac.tearDown(ErrClientConnClosing)
@ -1307,7 +1186,7 @@ type addrConn struct {
cc *ClientConn cc *ClientConn
dopts dialOptions dopts dialOptions
acbw balancer.SubConn acbw *acBalancerWrapper
scopts balancer.NewSubConnOptions scopts balancer.NewSubConnOptions
// transport is set when there's a viable transport (note: ac state may not be READY as LB channel // transport is set when there's a viable transport (note: ac state may not be READY as LB channel
@ -1345,7 +1224,7 @@ func (ac *addrConn) updateConnectivityState(s connectivity.State, lastErr error)
} else { } else {
channelz.Infof(logger, ac.channelzID, "Subchannel Connectivity change to %v, last error: %s", s, lastErr) channelz.Infof(logger, ac.channelzID, "Subchannel Connectivity change to %v, last error: %s", s, lastErr)
} }
ac.cc.handleSubConnStateChange(ac.acbw, s, lastErr) ac.acbw.updateState(s, lastErr)
} }
// adjustParams updates parameters used to create transports upon // adjustParams updates parameters used to create transports upon
@ -1849,7 +1728,7 @@ func (cc *ClientConn) parseTargetAndFindResolver() error {
if err != nil { if err != nil {
channelz.Infof(logger, cc.channelzID, "dial target %q parse failed: %v", cc.target, err) channelz.Infof(logger, cc.channelzID, "dial target %q parse failed: %v", cc.target, err)
} else { } else {
channelz.Infof(logger, cc.channelzID, "parsed dial target is: %+v", parsedTarget) channelz.Infof(logger, cc.channelzID, "parsed dial target is: %#v", parsedTarget)
rb = cc.getResolver(parsedTarget.URL.Scheme) rb = cc.getResolver(parsedTarget.URL.Scheme)
if rb != nil { if rb != nil {
cc.parsedTarget = parsedTarget cc.parsedTarget = parsedTarget
@ -2007,32 +1886,3 @@ func (cc *ClientConn) determineAuthority() error {
channelz.Infof(logger, cc.channelzID, "Channel authority set to %q", cc.authority) channelz.Infof(logger, cc.channelzID, "Channel authority set to %q", cc.authority)
return nil return nil
} }
// initResolverWrapper creates a ccResolverWrapper, which builds the name
// resolver. This method grabs the lock to assign the newly built resolver
// wrapper to the cc.resolverWrapper field.
func (cc *ClientConn) initResolverWrapper(creds credentials.TransportCredentials) error {
rw, err := newCCResolverWrapper(cc, ccResolverWrapperOpts{
target: cc.parsedTarget,
builder: cc.resolverBuilder,
bOpts: resolver.BuildOptions{
DisableServiceConfig: cc.dopts.disableServiceConfig,
DialCreds: creds,
CredsBundle: cc.dopts.copts.CredsBundle,
Dialer: cc.dopts.copts.Dialer,
},
channelzID: cc.channelzID,
})
if err != nil {
return fmt.Errorf("failed to build resolver: %v", err)
}
// Resolver implementations may report state update or error inline when
// built (or right after), and this is handled in cc.updateResolverState.
// Also, an error from the resolver might lead to a re-resolution request
// from the balancer, which is handled in resolveNow() where
// `cc.resolverWrapper` is accessed. Hence, we need to hold the lock here.
cc.mu.Lock()
cc.resolverWrapper = rw
cc.mu.Unlock()
return nil
}

View File

@ -25,7 +25,13 @@ import (
"strconv" "strconv"
) )
// A Code is an unsigned 32-bit error code as defined in the gRPC spec. // A Code is a status code defined according to the [gRPC documentation].
//
// Only the codes defined as consts in this package are valid codes. Do not use
// other code values. Behavior of other codes is implementation-specific and
// interoperability between implementations is not guaranteed.
//
// [gRPC documentation]: https://github.com/grpc/grpc/blob/master/doc/statuscodes.md
type Code uint32 type Code uint32
const ( const (

View File

@ -44,10 +44,25 @@ func (t TLSInfo) AuthType() string {
return "tls" return "tls"
} }
// cipherSuiteLookup returns the string version of a TLS cipher suite ID.
func cipherSuiteLookup(cipherSuiteID uint16) string {
for _, s := range tls.CipherSuites() {
if s.ID == cipherSuiteID {
return s.Name
}
}
for _, s := range tls.InsecureCipherSuites() {
if s.ID == cipherSuiteID {
return s.Name
}
}
return fmt.Sprintf("unknown ID: %v", cipherSuiteID)
}
// GetSecurityValue returns security info requested by channelz. // GetSecurityValue returns security info requested by channelz.
func (t TLSInfo) GetSecurityValue() ChannelzSecurityValue { func (t TLSInfo) GetSecurityValue() ChannelzSecurityValue {
v := &TLSChannelzSecurityValue{ v := &TLSChannelzSecurityValue{
StandardName: cipherSuiteLookup[t.State.CipherSuite], StandardName: cipherSuiteLookup(t.State.CipherSuite),
} }
// Currently there's no way to get LocalCertificate info from tls package. // Currently there's no way to get LocalCertificate info from tls package.
if len(t.State.PeerCertificates) > 0 { if len(t.State.PeerCertificates) > 0 {
@ -138,10 +153,39 @@ func (c *tlsCreds) OverrideServerName(serverNameOverride string) error {
return nil return nil
} }
// The following cipher suites are forbidden for use with HTTP/2 by
// https://datatracker.ietf.org/doc/html/rfc7540#appendix-A
var tls12ForbiddenCipherSuites = map[uint16]struct{}{
tls.TLS_RSA_WITH_AES_128_CBC_SHA: {},
tls.TLS_RSA_WITH_AES_256_CBC_SHA: {},
tls.TLS_RSA_WITH_AES_128_GCM_SHA256: {},
tls.TLS_RSA_WITH_AES_256_GCM_SHA384: {},
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA: {},
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: {},
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA: {},
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: {},
}
// NewTLS uses c to construct a TransportCredentials based on TLS. // NewTLS uses c to construct a TransportCredentials based on TLS.
func NewTLS(c *tls.Config) TransportCredentials { func NewTLS(c *tls.Config) TransportCredentials {
tc := &tlsCreds{credinternal.CloneTLSConfig(c)} tc := &tlsCreds{credinternal.CloneTLSConfig(c)}
tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos) tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos)
// If the user did not configure a MinVersion and did not configure a
// MaxVersion < 1.2, use MinVersion=1.2, which is required by
// https://datatracker.ietf.org/doc/html/rfc7540#section-9.2
if tc.config.MinVersion == 0 && (tc.config.MaxVersion == 0 || tc.config.MaxVersion >= tls.VersionTLS12) {
tc.config.MinVersion = tls.VersionTLS12
}
// If the user did not configure CipherSuites, use all "secure" cipher
// suites reported by the TLS package, but remove some explicitly forbidden
// by https://datatracker.ietf.org/doc/html/rfc7540#appendix-A
if tc.config.CipherSuites == nil {
for _, cs := range tls.CipherSuites() {
if _, ok := tls12ForbiddenCipherSuites[cs.ID]; !ok {
tc.config.CipherSuites = append(tc.config.CipherSuites, cs.ID)
}
}
}
return tc return tc
} }
@ -205,32 +249,3 @@ type TLSChannelzSecurityValue struct {
LocalCertificate []byte LocalCertificate []byte
RemoteCertificate []byte RemoteCertificate []byte
} }
var cipherSuiteLookup = map[uint16]string{
tls.TLS_RSA_WITH_RC4_128_SHA: "TLS_RSA_WITH_RC4_128_SHA",
tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA: "TLS_RSA_WITH_3DES_EDE_CBC_SHA",
tls.TLS_RSA_WITH_AES_128_CBC_SHA: "TLS_RSA_WITH_AES_128_CBC_SHA",
tls.TLS_RSA_WITH_AES_256_CBC_SHA: "TLS_RSA_WITH_AES_256_CBC_SHA",
tls.TLS_RSA_WITH_AES_128_GCM_SHA256: "TLS_RSA_WITH_AES_128_GCM_SHA256",
tls.TLS_RSA_WITH_AES_256_GCM_SHA384: "TLS_RSA_WITH_AES_256_GCM_SHA384",
tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA: "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA",
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA: "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA",
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA",
tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA: "TLS_ECDHE_RSA_WITH_RC4_128_SHA",
tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA: "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA",
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA: "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA",
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA",
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
tls.TLS_FALLBACK_SCSV: "TLS_FALLBACK_SCSV",
tls.TLS_RSA_WITH_AES_128_CBC_SHA256: "TLS_RSA_WITH_AES_128_CBC_SHA256",
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256: "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256",
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256: "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256",
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305",
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305: "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305",
tls.TLS_AES_128_GCM_SHA256: "TLS_AES_128_GCM_SHA256",
tls.TLS_AES_256_GCM_SHA384: "TLS_AES_256_GCM_SHA384",
tls.TLS_CHACHA20_POLY1305_SHA256: "TLS_CHACHA20_POLY1305_SHA256",
}

View File

@ -46,6 +46,7 @@ func init() {
internal.WithBinaryLogger = withBinaryLogger internal.WithBinaryLogger = withBinaryLogger
internal.JoinDialOptions = newJoinDialOption internal.JoinDialOptions = newJoinDialOption
internal.DisableGlobalDialOptions = newDisableGlobalDialOptions internal.DisableGlobalDialOptions = newDisableGlobalDialOptions
internal.WithRecvBufferPool = withRecvBufferPool
} }
// dialOptions configure a Dial call. dialOptions are set by the DialOption // dialOptions configure a Dial call. dialOptions are set by the DialOption
@ -63,7 +64,6 @@ type dialOptions struct {
block bool block bool
returnLastError bool returnLastError bool
timeout time.Duration timeout time.Duration
scChan <-chan ServiceConfig
authority string authority string
binaryLogger binarylog.Logger binaryLogger binarylog.Logger
copts transport.ConnectOptions copts transport.ConnectOptions
@ -250,19 +250,6 @@ func WithDecompressor(dc Decompressor) DialOption {
}) })
} }
// WithServiceConfig returns a DialOption which has a channel to read the
// service configuration.
//
// Deprecated: service config should be received through name resolver or via
// WithDefaultServiceConfig, as specified at
// https://github.com/grpc/grpc/blob/master/doc/service_config.md. Will be
// removed in a future 1.x release.
func WithServiceConfig(c <-chan ServiceConfig) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.scChan = c
})
}
// WithConnectParams configures the ClientConn to use the provided ConnectParams // WithConnectParams configures the ClientConn to use the provided ConnectParams
// for creating and maintaining connections to servers. // for creating and maintaining connections to servers.
// //
@ -413,6 +400,17 @@ func WithTimeout(d time.Duration) DialOption {
// connections. If FailOnNonTempDialError() is set to true, and an error is // connections. If FailOnNonTempDialError() is set to true, and an error is
// returned by f, gRPC checks the error's Temporary() method to decide if it // returned by f, gRPC checks the error's Temporary() method to decide if it
// should try to reconnect to the network address. // should try to reconnect to the network address.
//
// Note: All supported releases of Go (as of December 2023) override the OS
// defaults for TCP keepalive time and interval to 15s. To enable TCP keepalive
// with OS defaults for keepalive time and interval, use a net.Dialer that sets
// the KeepAlive field to a negative value, and sets the SO_KEEPALIVE socket
// option to true from the Control field. For a concrete example of how to do
// this, see internal.NetDialerWithTCPKeepalive().
//
// For more information, please see [issue 23459] in the Go github repo.
//
// [issue 23459]: https://github.com/golang/go/issues/23459
func WithContextDialer(f func(context.Context, string) (net.Conn, error)) DialOption { func WithContextDialer(f func(context.Context, string) (net.Conn, error)) DialOption {
return newFuncDialOption(func(o *dialOptions) { return newFuncDialOption(func(o *dialOptions) {
o.copts.Dialer = f o.copts.Dialer = f
@ -487,7 +485,7 @@ func FailOnNonTempDialError(f bool) DialOption {
// the RPCs. // the RPCs.
func WithUserAgent(s string) DialOption { func WithUserAgent(s string) DialOption {
return newFuncDialOption(func(o *dialOptions) { return newFuncDialOption(func(o *dialOptions) {
o.copts.UserAgent = s o.copts.UserAgent = s + " " + grpcUA
}) })
} }
@ -637,14 +635,16 @@ func withHealthCheckFunc(f internal.HealthChecker) DialOption {
func defaultDialOptions() dialOptions { func defaultDialOptions() dialOptions {
return dialOptions{ return dialOptions{
healthCheckFunc: internal.HealthCheckFunc,
copts: transport.ConnectOptions{ copts: transport.ConnectOptions{
WriteBufferSize: defaultWriteBufSize,
ReadBufferSize: defaultReadBufSize, ReadBufferSize: defaultReadBufSize,
WriteBufferSize: defaultWriteBufSize,
UseProxy: true, UseProxy: true,
UserAgent: grpcUA,
}, },
recvBufferPool: nopBufferPool{}, bs: internalbackoff.DefaultExponential,
idleTimeout: 30 * time.Minute, healthCheckFunc: internal.HealthCheckFunc,
idleTimeout: 30 * time.Minute,
recvBufferPool: nopBufferPool{},
} }
} }
@ -705,11 +705,13 @@ func WithIdleTimeout(d time.Duration) DialOption {
// options are used: WithStatsHandler, EnableTracing, or binary logging. In such // options are used: WithStatsHandler, EnableTracing, or binary logging. In such
// cases, the shared buffer pool will be ignored. // cases, the shared buffer pool will be ignored.
// //
// # Experimental // Deprecated: use experimental.WithRecvBufferPool instead. Will be deleted in
// // v1.60.0 or later.
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func WithRecvBufferPool(bufferPool SharedBufferPool) DialOption { func WithRecvBufferPool(bufferPool SharedBufferPool) DialOption {
return withRecvBufferPool(bufferPool)
}
func withRecvBufferPool(bufferPool SharedBufferPool) DialOption {
return newFuncDialOption(func(o *dialOptions) { return newFuncDialOption(func(o *dialOptions) {
o.recvBufferPool = bufferPool o.recvBufferPool = bufferPool
}) })

View File

@ -18,7 +18,10 @@
// Package buffer provides an implementation of an unbounded buffer. // Package buffer provides an implementation of an unbounded buffer.
package buffer package buffer
import "sync" import (
"errors"
"sync"
)
// Unbounded is an implementation of an unbounded buffer which does not use // Unbounded is an implementation of an unbounded buffer which does not use
// extra goroutines. This is typically used for passing updates from one entity // extra goroutines. This is typically used for passing updates from one entity
@ -36,6 +39,7 @@ import "sync"
type Unbounded struct { type Unbounded struct {
c chan any c chan any
closed bool closed bool
closing bool
mu sync.Mutex mu sync.Mutex
backlog []any backlog []any
} }
@ -45,32 +49,32 @@ func NewUnbounded() *Unbounded {
return &Unbounded{c: make(chan any, 1)} return &Unbounded{c: make(chan any, 1)}
} }
var errBufferClosed = errors.New("Put called on closed buffer.Unbounded")
// Put adds t to the unbounded buffer. // Put adds t to the unbounded buffer.
func (b *Unbounded) Put(t any) { func (b *Unbounded) Put(t any) error {
b.mu.Lock() b.mu.Lock()
defer b.mu.Unlock() defer b.mu.Unlock()
if b.closed { if b.closing {
return return errBufferClosed
} }
if len(b.backlog) == 0 { if len(b.backlog) == 0 {
select { select {
case b.c <- t: case b.c <- t:
return return nil
default: default:
} }
} }
b.backlog = append(b.backlog, t) b.backlog = append(b.backlog, t)
return nil
} }
// Load sends the earliest buffered data, if any, onto the read channel // Load sends the earliest buffered data, if any, onto the read channel returned
// returned by Get(). Users are expected to call this every time they read a // by Get(). Users are expected to call this every time they successfully read a
// value from the read channel. // value from the read channel.
func (b *Unbounded) Load() { func (b *Unbounded) Load() {
b.mu.Lock() b.mu.Lock()
defer b.mu.Unlock() defer b.mu.Unlock()
if b.closed {
return
}
if len(b.backlog) > 0 { if len(b.backlog) > 0 {
select { select {
case b.c <- b.backlog[0]: case b.c <- b.backlog[0]:
@ -78,6 +82,8 @@ func (b *Unbounded) Load() {
b.backlog = b.backlog[1:] b.backlog = b.backlog[1:]
default: default:
} }
} else if b.closing && !b.closed {
close(b.c)
} }
} }
@ -88,18 +94,23 @@ func (b *Unbounded) Load() {
// send the next buffered value onto the channel if there is any. // send the next buffered value onto the channel if there is any.
// //
// If the unbounded buffer is closed, the read channel returned by this method // If the unbounded buffer is closed, the read channel returned by this method
// is closed. // is closed after all data is drained.
func (b *Unbounded) Get() <-chan any { func (b *Unbounded) Get() <-chan any {
return b.c return b.c
} }
// Close closes the unbounded buffer. // Close closes the unbounded buffer. No subsequent data may be Put(), and the
// channel returned from Get() will be closed after all the data is read and
// Load() is called for the final time.
func (b *Unbounded) Close() { func (b *Unbounded) Close() {
b.mu.Lock() b.mu.Lock()
defer b.mu.Unlock() defer b.mu.Unlock()
if b.closed { if b.closing {
return return
} }
b.closed = true b.closing = true
close(b.c) if len(b.backlog) == 0 {
b.closed = true
close(b.c)
}
} }

View File

@ -31,6 +31,7 @@ import (
"time" "time"
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
) )
const ( const (
@ -58,6 +59,12 @@ func TurnOn() {
} }
} }
func init() {
internal.ChannelzTurnOffForTesting = func() {
atomic.StoreInt32(&curState, 0)
}
}
// IsOn returns whether channelz data collection is on. // IsOn returns whether channelz data collection is on.
func IsOn() bool { func IsOn() bool {
return atomic.LoadInt32(&curState) == 1 return atomic.LoadInt32(&curState) == 1

View File

@ -36,9 +36,6 @@ var (
// "GRPC_RING_HASH_CAP". This does not override the default bounds // "GRPC_RING_HASH_CAP". This does not override the default bounds
// checking which NACKs configs specifying ring sizes > 8*1024*1024 (~8M). // checking which NACKs configs specifying ring sizes > 8*1024*1024 (~8M).
RingHashCap = uint64FromEnv("GRPC_RING_HASH_CAP", 4096, 1, 8*1024*1024) RingHashCap = uint64FromEnv("GRPC_RING_HASH_CAP", 4096, 1, 8*1024*1024)
// PickFirstLBConfig is set if we should support configuration of the
// pick_first LB policy.
PickFirstLBConfig = boolFromEnv("GRPC_EXPERIMENTAL_PICKFIRST_LB_CONFIG", true)
// LeastRequestLB is set if we should support the least_request_experimental // LeastRequestLB is set if we should support the least_request_experimental
// LB policy, which can be enabled by setting the environment variable // LB policy, which can be enabled by setting the environment variable
// "GRPC_EXPERIMENTAL_ENABLE_LEAST_REQUEST" to "true". // "GRPC_EXPERIMENTAL_ENABLE_LEAST_REQUEST" to "true".

View File

@ -50,46 +50,7 @@ var (
// //
// When both bootstrap FileName and FileContent are set, FileName is used. // When both bootstrap FileName and FileContent are set, FileName is used.
XDSBootstrapFileContent = os.Getenv(XDSBootstrapFileContentEnv) XDSBootstrapFileContent = os.Getenv(XDSBootstrapFileContentEnv)
// XDSRingHash indicates whether ring hash support is enabled, which can be
// disabled by setting the environment variable
// "GRPC_XDS_EXPERIMENTAL_ENABLE_RING_HASH" to "false".
XDSRingHash = boolFromEnv("GRPC_XDS_EXPERIMENTAL_ENABLE_RING_HASH", true)
// XDSClientSideSecurity is used to control processing of security
// configuration on the client-side.
//
// Note that there is no env var protection for the server-side because we
// have a brand new API on the server-side and users explicitly need to use
// the new API to get security integration on the server.
XDSClientSideSecurity = boolFromEnv("GRPC_XDS_EXPERIMENTAL_SECURITY_SUPPORT", true)
// XDSAggregateAndDNS indicates whether processing of aggregated cluster and
// DNS cluster is enabled, which can be disabled by setting the environment
// variable "GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER"
// to "false".
XDSAggregateAndDNS = boolFromEnv("GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER", true)
// XDSRBAC indicates whether xDS configured RBAC HTTP Filter is enabled,
// which can be disabled by setting the environment variable
// "GRPC_XDS_EXPERIMENTAL_RBAC" to "false".
XDSRBAC = boolFromEnv("GRPC_XDS_EXPERIMENTAL_RBAC", true)
// XDSOutlierDetection indicates whether outlier detection support is
// enabled, which can be disabled by setting the environment variable
// "GRPC_EXPERIMENTAL_ENABLE_OUTLIER_DETECTION" to "false".
XDSOutlierDetection = boolFromEnv("GRPC_EXPERIMENTAL_ENABLE_OUTLIER_DETECTION", true)
// XDSFederation indicates whether federation support is enabled, which can
// be enabled by setting the environment variable
// "GRPC_EXPERIMENTAL_XDS_FEDERATION" to "true".
XDSFederation = boolFromEnv("GRPC_EXPERIMENTAL_XDS_FEDERATION", true)
// XDSRLS indicates whether processing of Cluster Specifier plugins and
// support for the RLS CLuster Specifier is enabled, which can be disabled by
// setting the environment variable "GRPC_EXPERIMENTAL_XDS_RLS_LB" to
// "false".
XDSRLS = boolFromEnv("GRPC_EXPERIMENTAL_XDS_RLS_LB", true)
// C2PResolverTestOnlyTrafficDirectorURI is the TD URI for testing. // C2PResolverTestOnlyTrafficDirectorURI is the TD URI for testing.
C2PResolverTestOnlyTrafficDirectorURI = os.Getenv("GRPC_TEST_ONLY_GOOGLE_C2P_RESOLVER_TRAFFIC_DIRECTOR_URI") C2PResolverTestOnlyTrafficDirectorURI = os.Getenv("GRPC_TEST_ONLY_GOOGLE_C2P_RESOLVER_TRAFFIC_DIRECTOR_URI")
// XDSCustomLBPolicy indicates whether Custom LB Policies are enabled, which
// can be disabled by setting the environment variable
// "GRPC_EXPERIMENTAL_XDS_CUSTOM_LB_CONFIG" to "false".
XDSCustomLBPolicy = boolFromEnv("GRPC_EXPERIMENTAL_XDS_CUSTOM_LB_CONFIG", true)
) )

28
vendor/google.golang.org/grpc/internal/experimental.go generated vendored Normal file
View File

@ -0,0 +1,28 @@
/*
* Copyright 2023 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package internal
var (
// WithRecvBufferPool is implemented by the grpc package and returns a dial
// option to configure a shared buffer pool for a grpc.ClientConn.
WithRecvBufferPool any // func (grpc.SharedBufferPool) grpc.DialOption
// RecvBufferPool is implemented by the grpc package and returns a server
// option to configure a shared buffer pool for a grpc.Server.
RecvBufferPool any // func (grpc.SharedBufferPool) grpc.ServerOption
)

View File

@ -20,7 +20,6 @@ package grpcsync
import ( import (
"context" "context"
"sync"
"google.golang.org/grpc/internal/buffer" "google.golang.org/grpc/internal/buffer"
) )
@ -38,8 +37,6 @@ type CallbackSerializer struct {
done chan struct{} done chan struct{}
callbacks *buffer.Unbounded callbacks *buffer.Unbounded
closedMu sync.Mutex
closed bool
} }
// NewCallbackSerializer returns a new CallbackSerializer instance. The provided // NewCallbackSerializer returns a new CallbackSerializer instance. The provided
@ -65,56 +62,34 @@ func NewCallbackSerializer(ctx context.Context) *CallbackSerializer {
// callbacks to be executed by the serializer. It is not possible to add // callbacks to be executed by the serializer. It is not possible to add
// callbacks once the context passed to NewCallbackSerializer is cancelled. // callbacks once the context passed to NewCallbackSerializer is cancelled.
func (cs *CallbackSerializer) Schedule(f func(ctx context.Context)) bool { func (cs *CallbackSerializer) Schedule(f func(ctx context.Context)) bool {
cs.closedMu.Lock() return cs.callbacks.Put(f) == nil
defer cs.closedMu.Unlock()
if cs.closed {
return false
}
cs.callbacks.Put(f)
return true
} }
func (cs *CallbackSerializer) run(ctx context.Context) { func (cs *CallbackSerializer) run(ctx context.Context) {
var backlog []func(context.Context)
defer close(cs.done) defer close(cs.done)
// TODO: when Go 1.21 is the oldest supported version, this loop and Close
// can be replaced with:
//
// context.AfterFunc(ctx, cs.callbacks.Close)
for ctx.Err() == nil { for ctx.Err() == nil {
select { select {
case <-ctx.Done(): case <-ctx.Done():
// Do nothing here. Next iteration of the for loop will not happen, // Do nothing here. Next iteration of the for loop will not happen,
// since ctx.Err() would be non-nil. // since ctx.Err() would be non-nil.
case callback, ok := <-cs.callbacks.Get(): case cb := <-cs.callbacks.Get():
if !ok {
return
}
cs.callbacks.Load() cs.callbacks.Load()
callback.(func(ctx context.Context))(ctx) cb.(func(context.Context))(ctx)
} }
} }
// Fetch pending callbacks if any, and execute them before returning from // Close the buffer to prevent new callbacks from being added.
// this method and closing cs.done.
cs.closedMu.Lock()
cs.closed = true
backlog = cs.fetchPendingCallbacks()
cs.callbacks.Close() cs.callbacks.Close()
cs.closedMu.Unlock()
for _, b := range backlog {
b(ctx)
}
}
func (cs *CallbackSerializer) fetchPendingCallbacks() []func(context.Context) { // Run all pending callbacks.
var backlog []func(context.Context) for cb := range cs.callbacks.Get() {
for { cs.callbacks.Load()
select { cb.(func(context.Context))(ctx)
case b := <-cs.callbacks.Get():
backlog = append(backlog, b.(func(context.Context)))
cs.callbacks.Load()
default:
return backlog
}
} }
} }

View File

@ -26,8 +26,6 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"google.golang.org/grpc/grpclog"
) )
// For overriding in unit tests. // For overriding in unit tests.
@ -39,27 +37,12 @@ var timeAfterFunc = func(d time.Duration, f func()) *time.Timer {
// and exit from idle mode. // and exit from idle mode.
type Enforcer interface { type Enforcer interface {
ExitIdleMode() error ExitIdleMode() error
EnterIdleMode() error EnterIdleMode()
} }
// Manager defines the functionality required to track RPC activity on a // Manager implements idleness detection and calls the configured Enforcer to
// channel. // enter/exit idle mode when appropriate. Must be created by NewManager.
type Manager interface { type Manager struct {
OnCallBegin() error
OnCallEnd()
Close()
}
type noopManager struct{}
func (noopManager) OnCallBegin() error { return nil }
func (noopManager) OnCallEnd() {}
func (noopManager) Close() {}
// manager implements the Manager interface. It uses atomic operations to
// synchronize access to shared state and a mutex to guarantee mutual exclusion
// in a critical section.
type manager struct {
// State accessed atomically. // State accessed atomically.
lastCallEndTime int64 // Unix timestamp in nanos; time when the most recent RPC completed. lastCallEndTime int64 // Unix timestamp in nanos; time when the most recent RPC completed.
activeCallsCount int32 // Count of active RPCs; -math.MaxInt32 means channel is idle or is trying to get there. activeCallsCount int32 // Count of active RPCs; -math.MaxInt32 means channel is idle or is trying to get there.
@ -69,8 +52,7 @@ type manager struct {
// Can be accessed without atomics or mutex since these are set at creation // Can be accessed without atomics or mutex since these are set at creation
// time and read-only after that. // time and read-only after that.
enforcer Enforcer // Functionality provided by grpc.ClientConn. enforcer Enforcer // Functionality provided by grpc.ClientConn.
timeout int64 // Idle timeout duration nanos stored as an int64. timeout time.Duration
logger grpclog.LoggerV2
// idleMu is used to guarantee mutual exclusion in two scenarios: // idleMu is used to guarantee mutual exclusion in two scenarios:
// - Opposing intentions: // - Opposing intentions:
@ -88,57 +70,48 @@ type manager struct {
timer *time.Timer timer *time.Timer
} }
// ManagerOptions is a collection of options used by
// NewManager.
type ManagerOptions struct {
Enforcer Enforcer
Timeout time.Duration
Logger grpclog.LoggerV2
}
// NewManager creates a new idleness manager implementation for the // NewManager creates a new idleness manager implementation for the
// given idle timeout. // given idle timeout. It begins in idle mode.
func NewManager(opts ManagerOptions) Manager { func NewManager(enforcer Enforcer, timeout time.Duration) *Manager {
if opts.Timeout == 0 { return &Manager{
return noopManager{} enforcer: enforcer,
timeout: timeout,
actuallyIdle: true,
activeCallsCount: -math.MaxInt32,
} }
m := &manager{
enforcer: opts.Enforcer,
timeout: int64(opts.Timeout),
logger: opts.Logger,
}
m.timer = timeAfterFunc(opts.Timeout, m.handleIdleTimeout)
return m
} }
// resetIdleTimer resets the idle timer to the given duration. This method // resetIdleTimerLocked resets the idle timer to the given duration. Called
// should only be called from the timer callback. // when exiting idle mode or when the timer fires and we need to reset it.
func (m *manager) resetIdleTimer(d time.Duration) { func (m *Manager) resetIdleTimerLocked(d time.Duration) {
m.idleMu.Lock() if m.isClosed() || m.timeout == 0 || m.actuallyIdle {
defer m.idleMu.Unlock()
if m.timer == nil {
// Only close sets timer to nil. We are done.
return return
} }
// It is safe to ignore the return value from Reset() because this method is // It is safe to ignore the return value from Reset() because this method is
// only ever called from the timer callback, which means the timer has // only ever called from the timer callback or when exiting idle mode.
// already fired. if m.timer != nil {
m.timer.Reset(d) m.timer.Stop()
}
m.timer = timeAfterFunc(d, m.handleIdleTimeout)
}
func (m *Manager) resetIdleTimer(d time.Duration) {
m.idleMu.Lock()
defer m.idleMu.Unlock()
m.resetIdleTimerLocked(d)
} }
// handleIdleTimeout is the timer callback that is invoked upon expiry of the // handleIdleTimeout is the timer callback that is invoked upon expiry of the
// configured idle timeout. The channel is considered inactive if there are no // configured idle timeout. The channel is considered inactive if there are no
// ongoing calls and no RPC activity since the last time the timer fired. // ongoing calls and no RPC activity since the last time the timer fired.
func (m *manager) handleIdleTimeout() { func (m *Manager) handleIdleTimeout() {
if m.isClosed() { if m.isClosed() {
return return
} }
if atomic.LoadInt32(&m.activeCallsCount) > 0 { if atomic.LoadInt32(&m.activeCallsCount) > 0 {
m.resetIdleTimer(time.Duration(m.timeout)) m.resetIdleTimer(m.timeout)
return return
} }
@ -148,24 +121,12 @@ func (m *manager) handleIdleTimeout() {
// Set the timer to fire after a duration of idle timeout, calculated // Set the timer to fire after a duration of idle timeout, calculated
// from the time the most recent RPC completed. // from the time the most recent RPC completed.
atomic.StoreInt32(&m.activeSinceLastTimerCheck, 0) atomic.StoreInt32(&m.activeSinceLastTimerCheck, 0)
m.resetIdleTimer(time.Duration(atomic.LoadInt64(&m.lastCallEndTime) + m.timeout - time.Now().UnixNano())) m.resetIdleTimer(time.Duration(atomic.LoadInt64(&m.lastCallEndTime)-time.Now().UnixNano()) + m.timeout)
return return
} }
// This CAS operation is extremely likely to succeed given that there has // Now that we've checked that there has been no activity, attempt to enter
// been no activity since the last time we were here. Setting the // idle mode, which is very likely to succeed.
// activeCallsCount to -math.MaxInt32 indicates to OnCallBegin() that the
// channel is either in idle mode or is trying to get there.
if !atomic.CompareAndSwapInt32(&m.activeCallsCount, 0, -math.MaxInt32) {
// This CAS operation can fail if an RPC started after we checked for
// activity at the top of this method, or one was ongoing from before
// the last time we were here. In both case, reset the timer and return.
m.resetIdleTimer(time.Duration(m.timeout))
return
}
// Now that we've set the active calls count to -math.MaxInt32, it's time to
// actually move to idle mode.
if m.tryEnterIdleMode() { if m.tryEnterIdleMode() {
// Successfully entered idle mode. No timer needed until we exit idle. // Successfully entered idle mode. No timer needed until we exit idle.
return return
@ -174,8 +135,7 @@ func (m *manager) handleIdleTimeout() {
// Failed to enter idle mode due to a concurrent RPC that kept the channel // Failed to enter idle mode due to a concurrent RPC that kept the channel
// active, or because of an error from the channel. Undo the attempt to // active, or because of an error from the channel. Undo the attempt to
// enter idle, and reset the timer to try again later. // enter idle, and reset the timer to try again later.
atomic.AddInt32(&m.activeCallsCount, math.MaxInt32) m.resetIdleTimer(m.timeout)
m.resetIdleTimer(time.Duration(m.timeout))
} }
// tryEnterIdleMode instructs the channel to enter idle mode. But before // tryEnterIdleMode instructs the channel to enter idle mode. But before
@ -185,36 +145,49 @@ func (m *manager) handleIdleTimeout() {
// Return value indicates whether or not the channel moved to idle mode. // Return value indicates whether or not the channel moved to idle mode.
// //
// Holds idleMu which ensures mutual exclusion with exitIdleMode. // Holds idleMu which ensures mutual exclusion with exitIdleMode.
func (m *manager) tryEnterIdleMode() bool { func (m *Manager) tryEnterIdleMode() bool {
// Setting the activeCallsCount to -math.MaxInt32 indicates to OnCallBegin()
// that the channel is either in idle mode or is trying to get there.
if !atomic.CompareAndSwapInt32(&m.activeCallsCount, 0, -math.MaxInt32) {
// This CAS operation can fail if an RPC started after we checked for
// activity in the timer handler, or one was ongoing from before the
// last time the timer fired, or if a test is attempting to enter idle
// mode without checking. In all cases, abort going into idle mode.
return false
}
// N.B. if we fail to enter idle mode after this, we must re-add
// math.MaxInt32 to m.activeCallsCount.
m.idleMu.Lock() m.idleMu.Lock()
defer m.idleMu.Unlock() defer m.idleMu.Unlock()
if atomic.LoadInt32(&m.activeCallsCount) != -math.MaxInt32 { if atomic.LoadInt32(&m.activeCallsCount) != -math.MaxInt32 {
// We raced and lost to a new RPC. Very rare, but stop entering idle. // We raced and lost to a new RPC. Very rare, but stop entering idle.
atomic.AddInt32(&m.activeCallsCount, math.MaxInt32)
return false return false
} }
if atomic.LoadInt32(&m.activeSinceLastTimerCheck) == 1 { if atomic.LoadInt32(&m.activeSinceLastTimerCheck) == 1 {
// An very short RPC could have come in (and also finished) after we // A very short RPC could have come in (and also finished) after we
// checked for calls count and activity in handleIdleTimeout(), but // checked for calls count and activity in handleIdleTimeout(), but
// before the CAS operation. So, we need to check for activity again. // before the CAS operation. So, we need to check for activity again.
atomic.AddInt32(&m.activeCallsCount, math.MaxInt32)
return false return false
} }
// No new RPCs have come in since we last set the active calls count value // No new RPCs have come in since we set the active calls count value to
// -math.MaxInt32 in the timer callback. And since we have the lock, it is // -math.MaxInt32. And since we have the lock, it is safe to enter idle mode
// safe to enter idle mode now. // unconditionally now.
if err := m.enforcer.EnterIdleMode(); err != nil { m.enforcer.EnterIdleMode()
m.logger.Errorf("Failed to enter idle mode: %v", err)
return false
}
// Successfully entered idle mode.
m.actuallyIdle = true m.actuallyIdle = true
return true return true
} }
func (m *Manager) EnterIdleModeForTesting() {
m.tryEnterIdleMode()
}
// OnCallBegin is invoked at the start of every RPC. // OnCallBegin is invoked at the start of every RPC.
func (m *manager) OnCallBegin() error { func (m *Manager) OnCallBegin() error {
if m.isClosed() { if m.isClosed() {
return nil return nil
} }
@ -227,7 +200,7 @@ func (m *manager) OnCallBegin() error {
// Channel is either in idle mode or is in the process of moving to idle // Channel is either in idle mode or is in the process of moving to idle
// mode. Attempt to exit idle mode to allow this RPC. // mode. Attempt to exit idle mode to allow this RPC.
if err := m.exitIdleMode(); err != nil { if err := m.ExitIdleMode(); err != nil {
// Undo the increment to calls count, and return an error causing the // Undo the increment to calls count, and return an error causing the
// RPC to fail. // RPC to fail.
atomic.AddInt32(&m.activeCallsCount, -1) atomic.AddInt32(&m.activeCallsCount, -1)
@ -238,28 +211,30 @@ func (m *manager) OnCallBegin() error {
return nil return nil
} }
// exitIdleMode instructs the channel to exit idle mode. // ExitIdleMode instructs m to call the enforcer's ExitIdleMode and update m's
// // internal state.
// Holds idleMu which ensures mutual exclusion with tryEnterIdleMode. func (m *Manager) ExitIdleMode() error {
func (m *manager) exitIdleMode() error { // Holds idleMu which ensures mutual exclusion with tryEnterIdleMode.
m.idleMu.Lock() m.idleMu.Lock()
defer m.idleMu.Unlock() defer m.idleMu.Unlock()
if !m.actuallyIdle { if m.isClosed() || !m.actuallyIdle {
// This can happen in two scenarios: // This can happen in three scenarios:
// - handleIdleTimeout() set the calls count to -math.MaxInt32 and called // - handleIdleTimeout() set the calls count to -math.MaxInt32 and called
// tryEnterIdleMode(). But before the latter could grab the lock, an RPC // tryEnterIdleMode(). But before the latter could grab the lock, an RPC
// came in and OnCallBegin() noticed that the calls count is negative. // came in and OnCallBegin() noticed that the calls count is negative.
// - Channel is in idle mode, and multiple new RPCs come in at the same // - Channel is in idle mode, and multiple new RPCs come in at the same
// time, all of them notice a negative calls count in OnCallBegin and get // time, all of them notice a negative calls count in OnCallBegin and get
// here. The first one to get the lock would got the channel to exit idle. // here. The first one to get the lock would got the channel to exit idle.
// - Channel is not in idle mode, and the user calls Connect which calls
// m.ExitIdleMode.
// //
// Either way, nothing to do here. // In any case, there is nothing to do here.
return nil return nil
} }
if err := m.enforcer.ExitIdleMode(); err != nil { if err := m.enforcer.ExitIdleMode(); err != nil {
return fmt.Errorf("channel failed to exit idle mode: %v", err) return fmt.Errorf("failed to exit idle mode: %w", err)
} }
// Undo the idle entry process. This also respects any new RPC attempts. // Undo the idle entry process. This also respects any new RPC attempts.
@ -267,12 +242,12 @@ func (m *manager) exitIdleMode() error {
m.actuallyIdle = false m.actuallyIdle = false
// Start a new timer to fire after the configured idle timeout. // Start a new timer to fire after the configured idle timeout.
m.timer = timeAfterFunc(time.Duration(m.timeout), m.handleIdleTimeout) m.resetIdleTimerLocked(m.timeout)
return nil return nil
} }
// OnCallEnd is invoked at the end of every RPC. // OnCallEnd is invoked at the end of every RPC.
func (m *manager) OnCallEnd() { func (m *Manager) OnCallEnd() {
if m.isClosed() { if m.isClosed() {
return return
} }
@ -287,15 +262,17 @@ func (m *manager) OnCallEnd() {
atomic.AddInt32(&m.activeCallsCount, -1) atomic.AddInt32(&m.activeCallsCount, -1)
} }
func (m *manager) isClosed() bool { func (m *Manager) isClosed() bool {
return atomic.LoadInt32(&m.closed) == 1 return atomic.LoadInt32(&m.closed) == 1
} }
func (m *manager) Close() { func (m *Manager) Close() {
atomic.StoreInt32(&m.closed, 1) atomic.StoreInt32(&m.closed, 1)
m.idleMu.Lock() m.idleMu.Lock()
m.timer.Stop() if m.timer != nil {
m.timer = nil m.timer.Stop()
m.timer = nil
}
m.idleMu.Unlock() m.idleMu.Unlock()
} }

View File

@ -73,6 +73,11 @@ var (
// xDS-enabled server invokes this method on a grpc.Server when a particular // xDS-enabled server invokes this method on a grpc.Server when a particular
// listener moves to "not-serving" mode. // listener moves to "not-serving" mode.
DrainServerTransports any // func(*grpc.Server, string) DrainServerTransports any // func(*grpc.Server, string)
// IsRegisteredMethod returns whether the passed in method is registered as
// a method on the server.
IsRegisteredMethod any // func(*grpc.Server, string) bool
// ServerFromContext returns the server from the context.
ServerFromContext any // func(context.Context) *grpc.Server
// AddGlobalServerOptions adds an array of ServerOption that will be // AddGlobalServerOptions adds an array of ServerOption that will be
// effective globally for newly created servers. The priority will be: 1. // effective globally for newly created servers. The priority will be: 1.
// user-provided; 2. this method; 3. default values. // user-provided; 2. this method; 3. default values.
@ -177,10 +182,12 @@ var (
GRPCResolverSchemeExtraMetadata string = "xds" GRPCResolverSchemeExtraMetadata string = "xds"
// EnterIdleModeForTesting gets the ClientConn to enter IDLE mode. // EnterIdleModeForTesting gets the ClientConn to enter IDLE mode.
EnterIdleModeForTesting any // func(*grpc.ClientConn) error EnterIdleModeForTesting any // func(*grpc.ClientConn)
// ExitIdleModeForTesting gets the ClientConn to exit IDLE mode. // ExitIdleModeForTesting gets the ClientConn to exit IDLE mode.
ExitIdleModeForTesting any // func(*grpc.ClientConn) error ExitIdleModeForTesting any // func(*grpc.ClientConn) error
ChannelzTurnOffForTesting func()
) )
// HealthChecker defines the signature of the client-side LB channel health checking function. // HealthChecker defines the signature of the client-side LB channel health checking function.

View File

@ -23,7 +23,6 @@ package dns
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net" "net"
"os" "os"
@ -37,6 +36,7 @@ import (
"google.golang.org/grpc/internal/backoff" "google.golang.org/grpc/internal/backoff"
"google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/envconfig"
"google.golang.org/grpc/internal/grpcrand" "google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/internal/resolver/dns/internal"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig" "google.golang.org/grpc/serviceconfig"
) )
@ -47,15 +47,11 @@ var EnableSRVLookups = false
var logger = grpclog.Component("dns") var logger = grpclog.Component("dns")
// Globals to stub out in tests. TODO: Perhaps these two can be combined into a
// single variable for testing the resolver?
var (
newTimer = time.NewTimer
newTimerDNSResRate = time.NewTimer
)
func init() { func init() {
resolver.Register(NewBuilder()) resolver.Register(NewBuilder())
internal.TimeAfterFunc = time.After
internal.NewNetResolver = newNetResolver
internal.AddressDialer = addressDialer
} }
const ( const (
@ -70,23 +66,6 @@ const (
txtAttribute = "grpc_config=" txtAttribute = "grpc_config="
) )
var (
errMissingAddr = errors.New("dns resolver: missing address")
// Addresses ending with a colon that is supposed to be the separator
// between host and port is not allowed. E.g. "::" is a valid address as
// it is an IPv6 address (host only) and "[::]:" is invalid as it ends with
// a colon as the host and port separator
errEndsWithColon = errors.New("dns resolver: missing port after port-separator colon")
)
var (
defaultResolver netResolver = net.DefaultResolver
// To prevent excessive re-resolution, we enforce a rate limit on DNS
// resolution requests.
minDNSResRate = 30 * time.Second
)
var addressDialer = func(address string) func(context.Context, string, string) (net.Conn, error) { var addressDialer = func(address string) func(context.Context, string, string) (net.Conn, error) {
return func(ctx context.Context, network, _ string) (net.Conn, error) { return func(ctx context.Context, network, _ string) (net.Conn, error) {
var dialer net.Dialer var dialer net.Dialer
@ -94,7 +73,11 @@ var addressDialer = func(address string) func(context.Context, string, string) (
} }
} }
var newNetResolver = func(authority string) (netResolver, error) { var newNetResolver = func(authority string) (internal.NetResolver, error) {
if authority == "" {
return net.DefaultResolver, nil
}
host, port, err := parseTarget(authority, defaultDNSSvrPort) host, port, err := parseTarget(authority, defaultDNSSvrPort)
if err != nil { if err != nil {
return nil, err return nil, err
@ -104,7 +87,7 @@ var newNetResolver = func(authority string) (netResolver, error) {
return &net.Resolver{ return &net.Resolver{
PreferGo: true, PreferGo: true,
Dial: addressDialer(authorityWithPort), Dial: internal.AddressDialer(authorityWithPort),
}, nil }, nil
} }
@ -142,13 +125,9 @@ func (b *dnsBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts
disableServiceConfig: opts.DisableServiceConfig, disableServiceConfig: opts.DisableServiceConfig,
} }
if target.URL.Host == "" { d.resolver, err = internal.NewNetResolver(target.URL.Host)
d.resolver = defaultResolver if err != nil {
} else { return nil, err
d.resolver, err = newNetResolver(target.URL.Host)
if err != nil {
return nil, err
}
} }
d.wg.Add(1) d.wg.Add(1)
@ -161,12 +140,6 @@ func (b *dnsBuilder) Scheme() string {
return "dns" return "dns"
} }
type netResolver interface {
LookupHost(ctx context.Context, host string) (addrs []string, err error)
LookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*net.SRV, err error)
LookupTXT(ctx context.Context, name string) (txts []string, err error)
}
// deadResolver is a resolver that does nothing. // deadResolver is a resolver that does nothing.
type deadResolver struct{} type deadResolver struct{}
@ -178,7 +151,7 @@ func (deadResolver) Close() {}
type dnsResolver struct { type dnsResolver struct {
host string host string
port string port string
resolver netResolver resolver internal.NetResolver
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
cc resolver.ClientConn cc resolver.ClientConn
@ -223,29 +196,27 @@ func (d *dnsResolver) watcher() {
err = d.cc.UpdateState(*state) err = d.cc.UpdateState(*state)
} }
var timer *time.Timer var waitTime time.Duration
if err == nil { if err == nil {
// Success resolving, wait for the next ResolveNow. However, also wait 30 // Success resolving, wait for the next ResolveNow. However, also wait 30
// seconds at the very least to prevent constantly re-resolving. // seconds at the very least to prevent constantly re-resolving.
backoffIndex = 1 backoffIndex = 1
timer = newTimerDNSResRate(minDNSResRate) waitTime = internal.MinResolutionRate
select { select {
case <-d.ctx.Done(): case <-d.ctx.Done():
timer.Stop()
return return
case <-d.rn: case <-d.rn:
} }
} else { } else {
// Poll on an error found in DNS Resolver or an error received from // Poll on an error found in DNS Resolver or an error received from
// ClientConn. // ClientConn.
timer = newTimer(backoff.DefaultExponential.Backoff(backoffIndex)) waitTime = backoff.DefaultExponential.Backoff(backoffIndex)
backoffIndex++ backoffIndex++
} }
select { select {
case <-d.ctx.Done(): case <-d.ctx.Done():
timer.Stop()
return return
case <-timer.C: case <-internal.TimeAfterFunc(waitTime):
} }
} }
} }
@ -387,7 +358,7 @@ func formatIP(addr string) (addrIP string, ok bool) {
// target: ":80" defaultPort: "443" returns host: "localhost", port: "80" // target: ":80" defaultPort: "443" returns host: "localhost", port: "80"
func parseTarget(target, defaultPort string) (host, port string, err error) { func parseTarget(target, defaultPort string) (host, port string, err error) {
if target == "" { if target == "" {
return "", "", errMissingAddr return "", "", internal.ErrMissingAddr
} }
if ip := net.ParseIP(target); ip != nil { if ip := net.ParseIP(target); ip != nil {
// target is an IPv4 or IPv6(without brackets) address // target is an IPv4 or IPv6(without brackets) address
@ -397,7 +368,7 @@ func parseTarget(target, defaultPort string) (host, port string, err error) {
if port == "" { if port == "" {
// If the port field is empty (target ends with colon), e.g. "[::1]:", // If the port field is empty (target ends with colon), e.g. "[::1]:",
// this is an error. // this is an error.
return "", "", errEndsWithColon return "", "", internal.ErrEndsWithColon
} }
// target has port, i.e ipv4-host:port, [ipv6-host]:port, host-name:port // target has port, i.e ipv4-host:port, [ipv6-host]:port, host-name:port
if host == "" { if host == "" {

View File

@ -0,0 +1,70 @@
/*
*
* Copyright 2023 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package internal contains functionality internal to the dns resolver package.
package internal
import (
"context"
"errors"
"net"
"time"
)
// NetResolver groups the methods on net.Resolver that are used by the DNS
// resolver implementation. This allows the default net.Resolver instance to be
// overidden from tests.
type NetResolver interface {
LookupHost(ctx context.Context, host string) (addrs []string, err error)
LookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*net.SRV, err error)
LookupTXT(ctx context.Context, name string) (txts []string, err error)
}
var (
// ErrMissingAddr is the error returned when building a DNS resolver when
// the provided target name is empty.
ErrMissingAddr = errors.New("dns resolver: missing address")
// ErrEndsWithColon is the error returned when building a DNS resolver when
// the provided target name ends with a colon that is supposed to be the
// separator between host and port. E.g. "::" is a valid address as it is
// an IPv6 address (host only) and "[::]:" is invalid as it ends with a
// colon as the host and port separator
ErrEndsWithColon = errors.New("dns resolver: missing port after port-separator colon")
)
// The following vars are overridden from tests.
var (
// MinResolutionRate is the minimum rate at which re-resolutions are
// allowed. This helps to prevent excessive re-resolution.
MinResolutionRate = 30 * time.Second
// TimeAfterFunc is used by the DNS resolver to wait for the given duration
// to elapse. In non-test code, this is implemented by time.After. In test
// code, this can be used to control the amount of time the resolver is
// blocked waiting for the duration to elapse.
TimeAfterFunc func(time.Duration) <-chan time.Time
// NewNetResolver returns the net.Resolver instance for the given target.
NewNetResolver func(string) (NetResolver, error)
// AddressDialer is the dialer used to dial the DNS server. It accepts the
// Host portion of the URL corresponding to the user's dial target and
// returns a dial function.
AddressDialer func(address string) func(context.Context, string, string) (net.Conn, error)
)

View File

@ -0,0 +1,29 @@
//go:build !unix
/*
* Copyright 2023 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package internal
import (
"net"
)
// NetDialerWithTCPKeepalive returns a vanilla net.Dialer on non-unix platforms.
func NetDialerWithTCPKeepalive() *net.Dialer {
return &net.Dialer{}
}

View File

@ -0,0 +1,54 @@
//go:build unix
/*
* Copyright 2023 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package internal
import (
"net"
"syscall"
"time"
"golang.org/x/sys/unix"
)
// NetDialerWithTCPKeepalive returns a net.Dialer that enables TCP keepalives on
// the underlying connection with OS default values for keepalive parameters.
//
// TODO: Once https://github.com/golang/go/issues/62254 lands, and the
// appropriate Go version becomes less than our least supported Go version, we
// should look into using the new API to make things more straightforward.
func NetDialerWithTCPKeepalive() *net.Dialer {
return &net.Dialer{
// Setting a negative value here prevents the Go stdlib from overriding
// the values of TCP keepalive time and interval. It also prevents the
// Go stdlib from enabling TCP keepalives by default.
KeepAlive: time.Duration(-1),
// This method is called after the underlying network socket is created,
// but before dialing the socket (or calling its connect() method). The
// combination of unconditionally enabling TCP keepalives here, and
// disabling the overriding of TCP keepalive parameters by setting the
// KeepAlive field to a negative value above, results in OS defaults for
// the TCP keealive interval and time parameters.
Control: func(_, _ string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_KEEPALIVE, 1)
})
},
}
}

View File

@ -75,11 +75,25 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []s
return nil, errors.New(msg) return nil, errors.New(msg)
} }
var localAddr net.Addr
if la := r.Context().Value(http.LocalAddrContextKey); la != nil {
localAddr, _ = la.(net.Addr)
}
var authInfo credentials.AuthInfo
if r.TLS != nil {
authInfo = credentials.TLSInfo{State: *r.TLS, CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}}
}
p := peer.Peer{
Addr: strAddr(r.RemoteAddr),
LocalAddr: localAddr,
AuthInfo: authInfo,
}
st := &serverHandlerTransport{ st := &serverHandlerTransport{
rw: w, rw: w,
req: r, req: r,
closedCh: make(chan struct{}), closedCh: make(chan struct{}),
writes: make(chan func()), writes: make(chan func()),
peer: p,
contentType: contentType, contentType: contentType,
contentSubtype: contentSubtype, contentSubtype: contentSubtype,
stats: stats, stats: stats,
@ -134,6 +148,8 @@ type serverHandlerTransport struct {
headerMD metadata.MD headerMD metadata.MD
peer peer.Peer
closeOnce sync.Once closeOnce sync.Once
closedCh chan struct{} // closed on Close closedCh chan struct{} // closed on Close
@ -165,7 +181,13 @@ func (ht *serverHandlerTransport) Close(err error) {
}) })
} }
func (ht *serverHandlerTransport) RemoteAddr() net.Addr { return strAddr(ht.req.RemoteAddr) } func (ht *serverHandlerTransport) Peer() *peer.Peer {
return &peer.Peer{
Addr: ht.peer.Addr,
LocalAddr: ht.peer.LocalAddr,
AuthInfo: ht.peer.AuthInfo,
}
}
// strAddr is a net.Addr backed by either a TCP "ip:port" string, or // strAddr is a net.Addr backed by either a TCP "ip:port" string, or
// the empty string if unknown. // the empty string if unknown.
@ -347,10 +369,8 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
return err return err
} }
func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) { func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream func(*Stream)) {
// With this transport type there will be exactly 1 stream: this HTTP request. // With this transport type there will be exactly 1 stream: this HTTP request.
ctx := ht.req.Context()
var cancel context.CancelFunc var cancel context.CancelFunc
if ht.timeoutSet { if ht.timeoutSet {
ctx, cancel = context.WithTimeout(ctx, ht.timeout) ctx, cancel = context.WithTimeout(ctx, ht.timeout)
@ -370,34 +390,19 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) {
ht.Close(errors.New("request is done processing")) ht.Close(errors.New("request is done processing"))
}() }()
req := ht.req
s := &Stream{
id: 0, // irrelevant
requestRead: func(int) {},
cancel: cancel,
buf: newRecvBuffer(),
st: ht,
method: req.URL.Path,
recvCompress: req.Header.Get("grpc-encoding"),
contentSubtype: ht.contentSubtype,
}
pr := &peer.Peer{
Addr: ht.RemoteAddr(),
}
if req.TLS != nil {
pr.AuthInfo = credentials.TLSInfo{State: *req.TLS, CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}}
}
ctx = metadata.NewIncomingContext(ctx, ht.headerMD) ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
s.ctx = peer.NewContext(ctx, pr) req := ht.req
for _, sh := range ht.stats { s := &Stream{
s.ctx = sh.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) id: 0, // irrelevant
inHeader := &stats.InHeader{ ctx: ctx,
FullMethod: s.method, requestRead: func(int) {},
RemoteAddr: ht.RemoteAddr(), cancel: cancel,
Compression: s.recvCompress, buf: newRecvBuffer(),
} st: ht,
sh.HandleRPC(s.ctx, inHeader) method: req.URL.Path,
recvCompress: req.Header.Get("grpc-encoding"),
contentSubtype: ht.contentSubtype,
headerWireLength: 0, // won't have access to header wire length until golang/go#18997.
} }
s.trReader = &transportReader{ s.trReader = &transportReader{
reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf, freeBuffer: func(*bytes.Buffer) {}}, reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf, freeBuffer: func(*bytes.Buffer) {}},

View File

@ -36,6 +36,7 @@ import (
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/channelz"
icredentials "google.golang.org/grpc/internal/credentials" icredentials "google.golang.org/grpc/internal/credentials"
"google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/grpclog"
@ -43,7 +44,7 @@ import (
"google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/internal/grpcutil"
imetadata "google.golang.org/grpc/internal/metadata" imetadata "google.golang.org/grpc/internal/metadata"
istatus "google.golang.org/grpc/internal/status" istatus "google.golang.org/grpc/internal/status"
"google.golang.org/grpc/internal/syscall" isyscall "google.golang.org/grpc/internal/syscall"
"google.golang.org/grpc/internal/transport/networktype" "google.golang.org/grpc/internal/transport/networktype"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
@ -176,7 +177,7 @@ func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error
if networkType == "tcp" && useProxy { if networkType == "tcp" && useProxy {
return proxyDial(ctx, address, grpcUA) return proxyDial(ctx, address, grpcUA)
} }
return (&net.Dialer{}).DialContext(ctx, networkType, address) return internal.NetDialerWithTCPKeepalive().DialContext(ctx, networkType, address)
} }
func isTemporary(err error) bool { func isTemporary(err error) bool {
@ -262,7 +263,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
} }
keepaliveEnabled := false keepaliveEnabled := false
if kp.Time != infinity { if kp.Time != infinity {
if err = syscall.SetTCPUserTimeout(conn, kp.Timeout); err != nil { if err = isyscall.SetTCPUserTimeout(conn, kp.Timeout); err != nil {
return nil, connectionErrorf(false, err, "transport: failed to set TCP_USER_TIMEOUT: %v", err) return nil, connectionErrorf(false, err, "transport: failed to set TCP_USER_TIMEOUT: %v", err)
} }
keepaliveEnabled = true keepaliveEnabled = true
@ -493,8 +494,9 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
func (t *http2Client) getPeer() *peer.Peer { func (t *http2Client) getPeer() *peer.Peer {
return &peer.Peer{ return &peer.Peer{
Addr: t.remoteAddr, Addr: t.remoteAddr,
AuthInfo: t.authInfo, // Can be nil AuthInfo: t.authInfo, // Can be nil
LocalAddr: t.localAddr,
} }
} }

View File

@ -68,18 +68,15 @@ var serverConnectionCounter uint64
// http2Server implements the ServerTransport interface with HTTP2. // http2Server implements the ServerTransport interface with HTTP2.
type http2Server struct { type http2Server struct {
lastRead int64 // Keep this field 64-bit aligned. Accessed atomically. lastRead int64 // Keep this field 64-bit aligned. Accessed atomically.
ctx context.Context done chan struct{}
done chan struct{} conn net.Conn
conn net.Conn loopy *loopyWriter
loopy *loopyWriter readerDone chan struct{} // sync point to enable testing.
readerDone chan struct{} // sync point to enable testing. loopyWriterDone chan struct{}
writerDone chan struct{} // sync point to enable testing. peer peer.Peer
remoteAddr net.Addr inTapHandle tap.ServerInHandle
localAddr net.Addr framer *framer
authInfo credentials.AuthInfo // auth info about the connection
inTapHandle tap.ServerInHandle
framer *framer
// The max number of concurrent streams. // The max number of concurrent streams.
maxStreams uint32 maxStreams uint32
// controlBuf delivers all the control related tasks (e.g., window // controlBuf delivers all the control related tasks (e.g., window
@ -243,16 +240,18 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
} }
done := make(chan struct{}) done := make(chan struct{})
peer := peer.Peer{
Addr: conn.RemoteAddr(),
LocalAddr: conn.LocalAddr(),
AuthInfo: authInfo,
}
t := &http2Server{ t := &http2Server{
ctx: setConnection(context.Background(), rawConn),
done: done, done: done,
conn: conn, conn: conn,
remoteAddr: conn.RemoteAddr(), peer: peer,
localAddr: conn.LocalAddr(),
authInfo: authInfo,
framer: framer, framer: framer,
readerDone: make(chan struct{}), readerDone: make(chan struct{}),
writerDone: make(chan struct{}), loopyWriterDone: make(chan struct{}),
maxStreams: config.MaxStreams, maxStreams: config.MaxStreams,
inTapHandle: config.InTapHandle, inTapHandle: config.InTapHandle,
fc: &trInFlow{limit: uint32(icwz)}, fc: &trInFlow{limit: uint32(icwz)},
@ -267,8 +266,6 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
bufferPool: newBufferPool(), bufferPool: newBufferPool(),
} }
t.logger = prefixLoggerForServerTransport(t) t.logger = prefixLoggerForServerTransport(t)
// Add peer information to the http2server context.
t.ctx = peer.NewContext(t.ctx, t.getPeer())
t.controlBuf = newControlBuffer(t.done) t.controlBuf = newControlBuffer(t.done)
if dynamicWindow { if dynamicWindow {
@ -277,15 +274,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
updateFlowControl: t.updateFlowControl, updateFlowControl: t.updateFlowControl,
} }
} }
for _, sh := range t.stats { t.channelzID, err = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.peer.Addr, t.peer.LocalAddr))
t.ctx = sh.TagConn(t.ctx, &stats.ConnTagInfo{
RemoteAddr: t.remoteAddr,
LocalAddr: t.localAddr,
})
connBegin := &stats.ConnBegin{}
sh.HandleConn(t.ctx, connBegin)
}
t.channelzID, err = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.remoteAddr, t.localAddr))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -334,7 +323,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger) t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger)
t.loopy.ssGoAwayHandler = t.outgoingGoAwayHandler t.loopy.ssGoAwayHandler = t.outgoingGoAwayHandler
t.loopy.run() t.loopy.run()
close(t.writerDone) close(t.loopyWriterDone)
}() }()
go t.keepalive() go t.keepalive()
return t, nil return t, nil
@ -342,7 +331,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
// operateHeaders takes action on the decoded headers. Returns an error if fatal // operateHeaders takes action on the decoded headers. Returns an error if fatal
// error encountered and transport needs to close, otherwise returns nil. // error encountered and transport needs to close, otherwise returns nil.
func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) error { func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeadersFrame, handle func(*Stream)) error {
// Acquire max stream ID lock for entire duration // Acquire max stream ID lock for entire duration
t.maxStreamMu.Lock() t.maxStreamMu.Lock()
defer t.maxStreamMu.Unlock() defer t.maxStreamMu.Unlock()
@ -369,10 +358,11 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
buf := newRecvBuffer() buf := newRecvBuffer()
s := &Stream{ s := &Stream{
id: streamID, id: streamID,
st: t, st: t,
buf: buf, buf: buf,
fc: &inFlow{limit: uint32(t.initialWindowSize)}, fc: &inFlow{limit: uint32(t.initialWindowSize)},
headerWireLength: int(frame.Header().Length),
} }
var ( var (
// if false, content-type was missing or invalid // if false, content-type was missing or invalid
@ -511,9 +501,9 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
s.state = streamReadDone s.state = streamReadDone
} }
if timeoutSet { if timeoutSet {
s.ctx, s.cancel = context.WithTimeout(t.ctx, timeout) s.ctx, s.cancel = context.WithTimeout(ctx, timeout)
} else { } else {
s.ctx, s.cancel = context.WithCancel(t.ctx) s.ctx, s.cancel = context.WithCancel(ctx)
} }
// Attach the received metadata to the context. // Attach the received metadata to the context.
@ -592,18 +582,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
s.requestRead = func(n int) { s.requestRead = func(n int) {
t.adjustWindow(s, uint32(n)) t.adjustWindow(s, uint32(n))
} }
for _, sh := range t.stats {
s.ctx = sh.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
inHeader := &stats.InHeader{
FullMethod: s.method,
RemoteAddr: t.remoteAddr,
LocalAddr: t.localAddr,
Compression: s.recvCompress,
WireLength: int(frame.Header().Length),
Header: mdata.Copy(),
}
sh.HandleRPC(s.ctx, inHeader)
}
s.ctxDone = s.ctx.Done() s.ctxDone = s.ctx.Done()
s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone) s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone)
s.trReader = &transportReader{ s.trReader = &transportReader{
@ -629,8 +607,11 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
// HandleStreams receives incoming streams using the given handler. This is // HandleStreams receives incoming streams using the given handler. This is
// typically run in a separate goroutine. // typically run in a separate goroutine.
// traceCtx attaches trace to ctx and returns the new context. // traceCtx attaches trace to ctx and returns the new context.
func (t *http2Server) HandleStreams(handle func(*Stream)) { func (t *http2Server) HandleStreams(ctx context.Context, handle func(*Stream)) {
defer close(t.readerDone) defer func() {
<-t.loopyWriterDone
close(t.readerDone)
}()
for { for {
t.controlBuf.throttle() t.controlBuf.throttle()
frame, err := t.framer.fr.ReadFrame() frame, err := t.framer.fr.ReadFrame()
@ -664,7 +645,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
} }
switch frame := frame.(type) { switch frame := frame.(type) {
case *http2.MetaHeadersFrame: case *http2.MetaHeadersFrame:
if err := t.operateHeaders(frame, handle); err != nil { if err := t.operateHeaders(ctx, frame, handle); err != nil {
t.Close(err) t.Close(err)
break break
} }
@ -1242,10 +1223,6 @@ func (t *http2Server) Close(err error) {
for _, s := range streams { for _, s := range streams {
s.cancel() s.cancel()
} }
for _, sh := range t.stats {
connEnd := &stats.ConnEnd{}
sh.HandleConn(t.ctx, connEnd)
}
} }
// deleteStream deletes the stream s from transport's active streams. // deleteStream deletes the stream s from transport's active streams.
@ -1311,10 +1288,6 @@ func (t *http2Server) closeStream(s *Stream, rst bool, rstCode http2.ErrCode, eo
}) })
} }
func (t *http2Server) RemoteAddr() net.Addr {
return t.remoteAddr
}
func (t *http2Server) Drain(debugData string) { func (t *http2Server) Drain(debugData string) {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
@ -1397,11 +1370,11 @@ func (t *http2Server) ChannelzMetric() *channelz.SocketInternalMetric {
LastMessageReceivedTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastMsgRecvTime)), LastMessageReceivedTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastMsgRecvTime)),
LocalFlowControlWindow: int64(t.fc.getSize()), LocalFlowControlWindow: int64(t.fc.getSize()),
SocketOptions: channelz.GetSocketOption(t.conn), SocketOptions: channelz.GetSocketOption(t.conn),
LocalAddr: t.localAddr, LocalAddr: t.peer.LocalAddr,
RemoteAddr: t.remoteAddr, RemoteAddr: t.peer.Addr,
// RemoteName : // RemoteName :
} }
if au, ok := t.authInfo.(credentials.ChannelzSecurityInfo); ok { if au, ok := t.peer.AuthInfo.(credentials.ChannelzSecurityInfo); ok {
s.Security = au.GetSecurityValue() s.Security = au.GetSecurityValue()
} }
s.RemoteFlowControlWindow = t.getOutFlowWindow() s.RemoteFlowControlWindow = t.getOutFlowWindow()
@ -1433,10 +1406,12 @@ func (t *http2Server) getOutFlowWindow() int64 {
} }
} }
func (t *http2Server) getPeer() *peer.Peer { // Peer returns the peer of the transport.
func (t *http2Server) Peer() *peer.Peer {
return &peer.Peer{ return &peer.Peer{
Addr: t.remoteAddr, Addr: t.peer.Addr,
AuthInfo: t.authInfo, // Can be nil LocalAddr: t.peer.LocalAddr,
AuthInfo: t.peer.AuthInfo, // Can be nil
} }
} }
@ -1461,6 +1436,6 @@ func GetConnection(ctx context.Context) net.Conn {
// SetConnection adds the connection to the context to be able to get // SetConnection adds the connection to the context to be able to get
// information about the destination ip and port for an incoming RPC. This also // information about the destination ip and port for an incoming RPC. This also
// allows any unary or streaming interceptors to see the connection. // allows any unary or streaming interceptors to see the connection.
func setConnection(ctx context.Context, conn net.Conn) context.Context { func SetConnection(ctx context.Context, conn net.Conn) context.Context {
return context.WithValue(ctx, connectionKey{}, conn) return context.WithValue(ctx, connectionKey{}, conn)
} }

View File

@ -28,6 +28,8 @@ import (
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"net/url" "net/url"
"google.golang.org/grpc/internal"
) )
const proxyAuthHeaderKey = "Proxy-Authorization" const proxyAuthHeaderKey = "Proxy-Authorization"
@ -112,7 +114,7 @@ func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr stri
// proxyDial dials, connecting to a proxy first if necessary. Checks if a proxy // proxyDial dials, connecting to a proxy first if necessary. Checks if a proxy
// is necessary, dials, does the HTTP CONNECT handshake, and returns the // is necessary, dials, does the HTTP CONNECT handshake, and returns the
// connection. // connection.
func proxyDial(ctx context.Context, addr string, grpcUA string) (conn net.Conn, err error) { func proxyDial(ctx context.Context, addr string, grpcUA string) (net.Conn, error) {
newAddr := addr newAddr := addr
proxyURL, err := mapAddress(addr) proxyURL, err := mapAddress(addr)
if err != nil { if err != nil {
@ -122,15 +124,15 @@ func proxyDial(ctx context.Context, addr string, grpcUA string) (conn net.Conn,
newAddr = proxyURL.Host newAddr = proxyURL.Host
} }
conn, err = (&net.Dialer{}).DialContext(ctx, "tcp", newAddr) conn, err := internal.NetDialerWithTCPKeepalive().DialContext(ctx, "tcp", newAddr)
if err != nil { if err != nil {
return return nil, err
} }
if proxyURL != nil { if proxyURL == nil {
// proxy is disabled if proxyURL is nil. // proxy is disabled if proxyURL is nil.
conn, err = doHTTPConnectHandshake(ctx, conn, addr, proxyURL, grpcUA) return conn, err
} }
return return doHTTPConnectHandshake(ctx, conn, addr, proxyURL, grpcUA)
} }
func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) error { func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) error {

View File

@ -37,6 +37,7 @@ import (
"google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
@ -265,7 +266,8 @@ type Stream struct {
// headerValid indicates whether a valid header was received. Only // headerValid indicates whether a valid header was received. Only
// meaningful after headerChan is closed (always call waitOnHeader() before // meaningful after headerChan is closed (always call waitOnHeader() before
// reading its value). Not valid on server side. // reading its value). Not valid on server side.
headerValid bool headerValid bool
headerWireLength int // Only set on server side.
// hdrMu protects header and trailer metadata on the server-side. // hdrMu protects header and trailer metadata on the server-side.
hdrMu sync.Mutex hdrMu sync.Mutex
@ -425,6 +427,12 @@ func (s *Stream) Context() context.Context {
return s.ctx return s.ctx
} }
// SetContext sets the context of the stream. This will be deleted once the
// stats handler callouts all move to gRPC layer.
func (s *Stream) SetContext(ctx context.Context) {
s.ctx = ctx
}
// Method returns the method for the stream. // Method returns the method for the stream.
func (s *Stream) Method() string { func (s *Stream) Method() string {
return s.method return s.method
@ -437,6 +445,12 @@ func (s *Stream) Status() *status.Status {
return s.status return s.status
} }
// HeaderWireLength returns the size of the headers of the stream as received
// from the wire. Valid only on the server.
func (s *Stream) HeaderWireLength() int {
return s.headerWireLength
}
// SetHeader sets the header metadata. This can be called multiple times. // SetHeader sets the header metadata. This can be called multiple times.
// Server side only. // Server side only.
// This should not be called in parallel to other data writes. // This should not be called in parallel to other data writes.
@ -698,7 +712,7 @@ type ClientTransport interface {
// Write methods for a given Stream will be called serially. // Write methods for a given Stream will be called serially.
type ServerTransport interface { type ServerTransport interface {
// HandleStreams receives incoming streams using the given handler. // HandleStreams receives incoming streams using the given handler.
HandleStreams(func(*Stream)) HandleStreams(context.Context, func(*Stream))
// WriteHeader sends the header metadata for the given stream. // WriteHeader sends the header metadata for the given stream.
// WriteHeader may not be called on all streams. // WriteHeader may not be called on all streams.
@ -717,8 +731,8 @@ type ServerTransport interface {
// handlers will be terminated asynchronously. // handlers will be terminated asynchronously.
Close(err error) Close(err error)
// RemoteAddr returns the remote network address. // Peer returns the peer of the server transport.
RemoteAddr() net.Addr Peer() *peer.Peer
// Drain notifies the client this ServerTransport stops accepting new RPCs. // Drain notifies the client this ServerTransport stops accepting new RPCs.
Drain(debugData string) Drain(debugData string)

View File

@ -153,14 +153,16 @@ func Join(mds ...MD) MD {
type mdIncomingKey struct{} type mdIncomingKey struct{}
type mdOutgoingKey struct{} type mdOutgoingKey struct{}
// NewIncomingContext creates a new context with incoming md attached. // NewIncomingContext creates a new context with incoming md attached. md must
// not be modified after calling this function.
func NewIncomingContext(ctx context.Context, md MD) context.Context { func NewIncomingContext(ctx context.Context, md MD) context.Context {
return context.WithValue(ctx, mdIncomingKey{}, md) return context.WithValue(ctx, mdIncomingKey{}, md)
} }
// NewOutgoingContext creates a new context with outgoing md attached. If used // NewOutgoingContext creates a new context with outgoing md attached. If used
// in conjunction with AppendToOutgoingContext, NewOutgoingContext will // in conjunction with AppendToOutgoingContext, NewOutgoingContext will
// overwrite any previously-appended metadata. // overwrite any previously-appended metadata. md must not be modified after
// calling this function.
func NewOutgoingContext(ctx context.Context, md MD) context.Context { func NewOutgoingContext(ctx context.Context, md MD) context.Context {
return context.WithValue(ctx, mdOutgoingKey{}, rawMD{md: md}) return context.WithValue(ctx, mdOutgoingKey{}, rawMD{md: md})
} }
@ -203,7 +205,8 @@ func FromIncomingContext(ctx context.Context) (MD, bool) {
} }
// ValueFromIncomingContext returns the metadata value corresponding to the metadata // ValueFromIncomingContext returns the metadata value corresponding to the metadata
// key from the incoming metadata if it exists. Key must be lower-case. // key from the incoming metadata if it exists. Keys are matched in a case insensitive
// manner.
// //
// # Experimental // # Experimental
// //
@ -219,17 +222,16 @@ func ValueFromIncomingContext(ctx context.Context, key string) []string {
return copyOf(v) return copyOf(v)
} }
for k, v := range md { for k, v := range md {
// We need to manually convert all keys to lower case, because MD is a // Case insenitive comparison: MD is a map, and there's no guarantee
// map, and there's no guarantee that the MD attached to the context is // that the MD attached to the context is created using our helper
// created using our helper functions. // functions.
if strings.ToLower(k) == key { if strings.EqualFold(k, key) {
return copyOf(v) return copyOf(v)
} }
} }
return nil return nil
} }
// the returned slice must not be modified in place
func copyOf(v []string) []string { func copyOf(v []string) []string {
vals := make([]string, len(v)) vals := make([]string, len(v))
copy(vals, v) copy(vals, v)

View File

@ -32,6 +32,8 @@ import (
type Peer struct { type Peer struct {
// Addr is the peer address. // Addr is the peer address.
Addr net.Addr Addr net.Addr
// LocalAddr is the local address.
LocalAddr net.Addr
// AuthInfo is the authentication information of the transport. // AuthInfo is the authentication information of the transport.
// It is nil if there is no transport security being used. // It is nil if there is no transport security being used.
AuthInfo credentials.AuthInfo AuthInfo credentials.AuthInfo

View File

@ -37,7 +37,6 @@ import (
type pickerWrapper struct { type pickerWrapper struct {
mu sync.Mutex mu sync.Mutex
done bool done bool
idle bool
blockingCh chan struct{} blockingCh chan struct{}
picker balancer.Picker picker balancer.Picker
statsHandlers []stats.Handler // to record blocking picker calls statsHandlers []stats.Handler // to record blocking picker calls
@ -53,11 +52,7 @@ func newPickerWrapper(statsHandlers []stats.Handler) *pickerWrapper {
// updatePicker is called by UpdateBalancerState. It unblocks all blocked pick. // updatePicker is called by UpdateBalancerState. It unblocks all blocked pick.
func (pw *pickerWrapper) updatePicker(p balancer.Picker) { func (pw *pickerWrapper) updatePicker(p balancer.Picker) {
pw.mu.Lock() pw.mu.Lock()
if pw.done || pw.idle { if pw.done {
// There is a small window where a picker update from the LB policy can
// race with the channel going to idle mode. If the picker is idle here,
// it is because the channel asked it to do so, and therefore it is sage
// to ignore the update from the LB policy.
pw.mu.Unlock() pw.mu.Unlock()
return return
} }
@ -210,23 +205,15 @@ func (pw *pickerWrapper) close() {
close(pw.blockingCh) close(pw.blockingCh)
} }
func (pw *pickerWrapper) enterIdleMode() { // reset clears the pickerWrapper and prepares it for being used again when idle
pw.mu.Lock() // mode is exited.
defer pw.mu.Unlock() func (pw *pickerWrapper) reset() {
if pw.done {
return
}
pw.idle = true
}
func (pw *pickerWrapper) exitIdleMode() {
pw.mu.Lock() pw.mu.Lock()
defer pw.mu.Unlock() defer pw.mu.Unlock()
if pw.done { if pw.done {
return return
} }
pw.blockingCh = make(chan struct{}) pw.blockingCh = make(chan struct{})
pw.idle = false
} }
// dropError is a wrapper error that indicates the LB policy wishes to drop the // dropError is a wrapper error that indicates the LB policy wishes to drop the

View File

@ -25,7 +25,6 @@ import (
"google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity" "google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal/envconfig"
internalgrpclog "google.golang.org/grpc/internal/grpclog" internalgrpclog "google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/grpcrand" "google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/internal/pretty" "google.golang.org/grpc/internal/pretty"
@ -65,19 +64,6 @@ type pfConfig struct {
} }
func (*pickfirstBuilder) ParseConfig(js json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { func (*pickfirstBuilder) ParseConfig(js json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
if !envconfig.PickFirstLBConfig {
// Prior to supporting loadbalancing configuration, the pick_first LB
// policy did not implement the balancer.ConfigParser interface. This
// meant that if a non-empty configuration was passed to it, the service
// config unmarshaling code would throw a warning log, but would
// continue using the pick_first LB policy. The code below ensures the
// same behavior is retained if the env var is not set.
if string(js) != "{}" {
logger.Warningf("Ignoring non-empty balancer configuration %q for the pick_first LB policy", string(js))
}
return nil, nil
}
var cfg pfConfig var cfg pfConfig
if err := json.Unmarshal(js, &cfg); err != nil { if err := json.Unmarshal(js, &cfg); err != nil {
return nil, fmt.Errorf("pickfirst: unable to unmarshal LB policy config: %s, error: %v", string(js), err) return nil, fmt.Errorf("pickfirst: unable to unmarshal LB policy config: %s, error: %v", string(js), err)

View File

@ -0,0 +1,36 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package dns implements a dns resolver to be installed as the default resolver
// in grpc.
//
// Deprecated: this package is imported by grpc and should not need to be
// imported directly by users.
package dns
import (
"google.golang.org/grpc/internal/resolver/dns"
"google.golang.org/grpc/resolver"
)
// NewBuilder creates a dnsBuilder which is used to factory DNS resolvers.
//
// Deprecated: import grpc and use resolver.Get("dns") instead.
func NewBuilder() resolver.Builder {
return dns.NewBuilder()
}

View File

@ -78,12 +78,12 @@ func (r *Resolver) InitialState(s resolver.State) {
func (r *Resolver) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) { func (r *Resolver) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) {
r.BuildCallback(target, cc, opts) r.BuildCallback(target, cc, opts)
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock()
r.CC = cc r.CC = cc
if r.lastSeenState != nil { if r.lastSeenState != nil {
err := r.CC.UpdateState(*r.lastSeenState) err := r.CC.UpdateState(*r.lastSeenState)
go r.UpdateStateCallback(err) go r.UpdateStateCallback(err)
} }
r.mu.Unlock()
return r, nil return r, nil
} }
@ -105,15 +105,22 @@ func (r *Resolver) Close() {
// UpdateState calls CC.UpdateState. // UpdateState calls CC.UpdateState.
func (r *Resolver) UpdateState(s resolver.State) { func (r *Resolver) UpdateState(s resolver.State) {
r.mu.Lock() r.mu.Lock()
err := r.CC.UpdateState(s) defer r.mu.Unlock()
var err error
if r.CC == nil {
panic("cannot update state as grpc.Dial with resolver has not been called")
}
err = r.CC.UpdateState(s)
r.lastSeenState = &s r.lastSeenState = &s
r.mu.Unlock()
r.UpdateStateCallback(err) r.UpdateStateCallback(err)
} }
// ReportError calls CC.ReportError. // ReportError calls CC.ReportError.
func (r *Resolver) ReportError(err error) { func (r *Resolver) ReportError(err error) {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock()
if r.CC == nil {
panic("cannot report error as grpc.Dial with resolver has not been called")
}
r.CC.ReportError(err) r.CC.ReportError(err)
r.mu.Unlock()
} }

View File

@ -136,3 +136,116 @@ func (a *AddressMap) Values() []any {
} }
return ret return ret
} }
type endpointNode struct {
addrs map[string]struct{}
}
// Equal returns whether the unordered set of addrs are the same between the
// endpoint nodes.
func (en *endpointNode) Equal(en2 *endpointNode) bool {
if len(en.addrs) != len(en2.addrs) {
return false
}
for addr := range en.addrs {
if _, ok := en2.addrs[addr]; !ok {
return false
}
}
return true
}
func toEndpointNode(endpoint Endpoint) endpointNode {
en := make(map[string]struct{})
for _, addr := range endpoint.Addresses {
en[addr.Addr] = struct{}{}
}
return endpointNode{
addrs: en,
}
}
// EndpointMap is a map of endpoints to arbitrary values keyed on only the
// unordered set of address strings within an endpoint. This map is not thread
// safe, thus it is unsafe to access concurrently. Must be created via
// NewEndpointMap; do not construct directly.
type EndpointMap struct {
endpoints map[*endpointNode]any
}
// NewEndpointMap creates a new EndpointMap.
func NewEndpointMap() *EndpointMap {
return &EndpointMap{
endpoints: make(map[*endpointNode]any),
}
}
// Get returns the value for the address in the map, if present.
func (em *EndpointMap) Get(e Endpoint) (value any, ok bool) {
en := toEndpointNode(e)
if endpoint := em.find(en); endpoint != nil {
return em.endpoints[endpoint], true
}
return nil, false
}
// Set updates or adds the value to the address in the map.
func (em *EndpointMap) Set(e Endpoint, value any) {
en := toEndpointNode(e)
if endpoint := em.find(en); endpoint != nil {
em.endpoints[endpoint] = value
return
}
em.endpoints[&en] = value
}
// Len returns the number of entries in the map.
func (em *EndpointMap) Len() int {
return len(em.endpoints)
}
// Keys returns a slice of all current map keys, as endpoints specifying the
// addresses present in the endpoint keys, in which uniqueness is determined by
// the unordered set of addresses. Thus, endpoint information returned is not
// the full endpoint data (drops duplicated addresses and attributes) but can be
// used for EndpointMap accesses.
func (em *EndpointMap) Keys() []Endpoint {
ret := make([]Endpoint, 0, len(em.endpoints))
for en := range em.endpoints {
var endpoint Endpoint
for addr := range en.addrs {
endpoint.Addresses = append(endpoint.Addresses, Address{Addr: addr})
}
ret = append(ret, endpoint)
}
return ret
}
// Values returns a slice of all current map values.
func (em *EndpointMap) Values() []any {
ret := make([]any, 0, len(em.endpoints))
for _, val := range em.endpoints {
ret = append(ret, val)
}
return ret
}
// find returns a pointer to the endpoint node in em if the endpoint node is
// already present. If not found, nil is returned. The comparisons are done on
// the unordered set of addresses within an endpoint.
func (em EndpointMap) find(e endpointNode) *endpointNode {
for endpoint := range em.endpoints {
if e.Equal(endpoint) {
return endpoint
}
}
return nil
}
// Delete removes the specified endpoint from the map.
func (em *EndpointMap) Delete(e Endpoint) {
en := toEndpointNode(e)
if entry := em.find(en); entry != nil {
delete(em.endpoints, entry)
}
}

View File

@ -240,11 +240,6 @@ type ClientConn interface {
// //
// Deprecated: Use UpdateState instead. // Deprecated: Use UpdateState instead.
NewAddress(addresses []Address) NewAddress(addresses []Address)
// NewServiceConfig is called by resolver to notify ClientConn a new
// service config. The service config should be provided as a json string.
//
// Deprecated: Use UpdateState instead.
NewServiceConfig(serviceConfig string)
// ParseServiceConfig parses the provided service config and returns an // ParseServiceConfig parses the provided service config and returns an
// object that provides the parsed config. // object that provides the parsed config.
ParseServiceConfig(serviceConfigJSON string) *serviceconfig.ParseResult ParseServiceConfig(serviceConfigJSON string) *serviceconfig.ParseResult
@ -286,6 +281,11 @@ func (t Target) Endpoint() string {
return strings.TrimPrefix(endpoint, "/") return strings.TrimPrefix(endpoint, "/")
} }
// String returns a string representation of Target.
func (t Target) String() string {
return t.URL.String()
}
// Builder creates a resolver that will be used to watch name resolution updates. // Builder creates a resolver that will be used to watch name resolution updates.
type Builder interface { type Builder interface {
// Build creates a new resolver for the given target. // Build creates a new resolver for the given target.

View File

@ -1,247 +0,0 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"context"
"strings"
"sync"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/pretty"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig"
)
// resolverStateUpdater wraps the single method used by ccResolverWrapper to
// report a state update from the actual resolver implementation.
type resolverStateUpdater interface {
updateResolverState(s resolver.State, err error) error
}
// ccResolverWrapper is a wrapper on top of cc for resolvers.
// It implements resolver.ClientConn interface.
type ccResolverWrapper struct {
// The following fields are initialized when the wrapper is created and are
// read-only afterwards, and therefore can be accessed without a mutex.
cc resolverStateUpdater
channelzID *channelz.Identifier
ignoreServiceConfig bool
opts ccResolverWrapperOpts
serializer *grpcsync.CallbackSerializer // To serialize all incoming calls.
serializerCancel context.CancelFunc // To close the serializer, accessed only from close().
// All incoming (resolver --> gRPC) calls are guaranteed to execute in a
// mutually exclusive manner as they are scheduled on the serializer.
// Fields accessed *only* in these serializer callbacks, can therefore be
// accessed without a mutex.
curState resolver.State
// mu guards access to the below fields.
mu sync.Mutex
closed bool
resolver resolver.Resolver // Accessed only from outgoing calls.
}
// ccResolverWrapperOpts wraps the arguments to be passed when creating a new
// ccResolverWrapper.
type ccResolverWrapperOpts struct {
target resolver.Target // User specified dial target to resolve.
builder resolver.Builder // Resolver builder to use.
bOpts resolver.BuildOptions // Resolver build options to use.
channelzID *channelz.Identifier // Channelz identifier for the channel.
}
// newCCResolverWrapper uses the resolver.Builder to build a Resolver and
// returns a ccResolverWrapper object which wraps the newly built resolver.
func newCCResolverWrapper(cc resolverStateUpdater, opts ccResolverWrapperOpts) (*ccResolverWrapper, error) {
ctx, cancel := context.WithCancel(context.Background())
ccr := &ccResolverWrapper{
cc: cc,
channelzID: opts.channelzID,
ignoreServiceConfig: opts.bOpts.DisableServiceConfig,
opts: opts,
serializer: grpcsync.NewCallbackSerializer(ctx),
serializerCancel: cancel,
}
// Cannot hold the lock at build time because the resolver can send an
// update or error inline and these incoming calls grab the lock to schedule
// a callback in the serializer.
r, err := opts.builder.Build(opts.target, ccr, opts.bOpts)
if err != nil {
cancel()
return nil, err
}
// Any error reported by the resolver at build time that leads to a
// re-resolution request from the balancer is dropped by grpc until we
// return from this function. So, we don't have to handle pending resolveNow
// requests here.
ccr.mu.Lock()
ccr.resolver = r
ccr.mu.Unlock()
return ccr, nil
}
func (ccr *ccResolverWrapper) resolveNow(o resolver.ResolveNowOptions) {
ccr.mu.Lock()
defer ccr.mu.Unlock()
// ccr.resolver field is set only after the call to Build() returns. But in
// the process of building, the resolver may send an error update which when
// propagated to the balancer may result in a re-resolution request.
if ccr.closed || ccr.resolver == nil {
return
}
ccr.resolver.ResolveNow(o)
}
func (ccr *ccResolverWrapper) close() {
ccr.mu.Lock()
if ccr.closed {
ccr.mu.Unlock()
return
}
channelz.Info(logger, ccr.channelzID, "Closing the name resolver")
// Close the serializer to ensure that no more calls from the resolver are
// handled, before actually closing the resolver.
ccr.serializerCancel()
ccr.closed = true
r := ccr.resolver
ccr.mu.Unlock()
// Give enqueued callbacks a chance to finish.
<-ccr.serializer.Done()
// Spawn a goroutine to close the resolver (since it may block trying to
// cleanup all allocated resources) and return early.
go r.Close()
}
// serializerScheduleLocked is a convenience method to schedule a function to be
// run on the serializer while holding ccr.mu.
func (ccr *ccResolverWrapper) serializerScheduleLocked(f func(context.Context)) {
ccr.mu.Lock()
ccr.serializer.Schedule(f)
ccr.mu.Unlock()
}
// UpdateState is called by resolver implementations to report new state to gRPC
// which includes addresses and service config.
func (ccr *ccResolverWrapper) UpdateState(s resolver.State) error {
errCh := make(chan error, 1)
if s.Endpoints == nil {
s.Endpoints = make([]resolver.Endpoint, 0, len(s.Addresses))
for _, a := range s.Addresses {
ep := resolver.Endpoint{Addresses: []resolver.Address{a}, Attributes: a.BalancerAttributes}
ep.Addresses[0].BalancerAttributes = nil
s.Endpoints = append(s.Endpoints, ep)
}
}
ok := ccr.serializer.Schedule(func(context.Context) {
ccr.addChannelzTraceEvent(s)
ccr.curState = s
if err := ccr.cc.updateResolverState(ccr.curState, nil); err == balancer.ErrBadResolverState {
errCh <- balancer.ErrBadResolverState
return
}
errCh <- nil
})
if !ok {
// The only time when Schedule() fail to add the callback to the
// serializer is when the serializer is closed, and this happens only
// when the resolver wrapper is closed.
return nil
}
return <-errCh
}
// ReportError is called by resolver implementations to report errors
// encountered during name resolution to gRPC.
func (ccr *ccResolverWrapper) ReportError(err error) {
ccr.serializerScheduleLocked(func(_ context.Context) {
channelz.Warningf(logger, ccr.channelzID, "ccResolverWrapper: reporting error to cc: %v", err)
ccr.cc.updateResolverState(resolver.State{}, err)
})
}
// NewAddress is called by the resolver implementation to send addresses to
// gRPC.
func (ccr *ccResolverWrapper) NewAddress(addrs []resolver.Address) {
ccr.serializerScheduleLocked(func(_ context.Context) {
ccr.addChannelzTraceEvent(resolver.State{Addresses: addrs, ServiceConfig: ccr.curState.ServiceConfig})
ccr.curState.Addresses = addrs
ccr.cc.updateResolverState(ccr.curState, nil)
})
}
// NewServiceConfig is called by the resolver implementation to send service
// configs to gRPC.
func (ccr *ccResolverWrapper) NewServiceConfig(sc string) {
ccr.serializerScheduleLocked(func(_ context.Context) {
channelz.Infof(logger, ccr.channelzID, "ccResolverWrapper: got new service config: %s", sc)
if ccr.ignoreServiceConfig {
channelz.Info(logger, ccr.channelzID, "Service config lookups disabled; ignoring config")
return
}
scpr := parseServiceConfig(sc)
if scpr.Err != nil {
channelz.Warningf(logger, ccr.channelzID, "ccResolverWrapper: error parsing service config: %v", scpr.Err)
return
}
ccr.addChannelzTraceEvent(resolver.State{Addresses: ccr.curState.Addresses, ServiceConfig: scpr})
ccr.curState.ServiceConfig = scpr
ccr.cc.updateResolverState(ccr.curState, nil)
})
}
// ParseServiceConfig is called by resolver implementations to parse a JSON
// representation of the service config.
func (ccr *ccResolverWrapper) ParseServiceConfig(scJSON string) *serviceconfig.ParseResult {
return parseServiceConfig(scJSON)
}
// addChannelzTraceEvent adds a channelz trace event containing the new
// state received from resolver implementations.
func (ccr *ccResolverWrapper) addChannelzTraceEvent(s resolver.State) {
var updates []string
var oldSC, newSC *ServiceConfig
var oldOK, newOK bool
if ccr.curState.ServiceConfig != nil {
oldSC, oldOK = ccr.curState.ServiceConfig.Config.(*ServiceConfig)
}
if s.ServiceConfig != nil {
newSC, newOK = s.ServiceConfig.Config.(*ServiceConfig)
}
if oldOK != newOK || (oldOK && newOK && oldSC.rawJSONString != newSC.rawJSONString) {
updates = append(updates, "service config updated")
}
if len(ccr.curState.Addresses) > 0 && len(s.Addresses) == 0 {
updates = append(updates, "resolver returned an empty address list")
} else if len(ccr.curState.Addresses) == 0 && len(s.Addresses) > 0 {
updates = append(updates, "resolver returned new addresses")
}
channelz.Infof(logger, ccr.channelzID, "Resolver state updated: %s (%v)", pretty.ToJSON(s), strings.Join(updates, "; "))
}

197
vendor/google.golang.org/grpc/resolver_wrapper.go generated vendored Normal file
View File

@ -0,0 +1,197 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"context"
"strings"
"sync"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/pretty"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig"
)
// ccResolverWrapper is a wrapper on top of cc for resolvers.
// It implements resolver.ClientConn interface.
type ccResolverWrapper struct {
// The following fields are initialized when the wrapper is created and are
// read-only afterwards, and therefore can be accessed without a mutex.
cc *ClientConn
ignoreServiceConfig bool
serializer *grpcsync.CallbackSerializer
serializerCancel context.CancelFunc
resolver resolver.Resolver // only accessed within the serializer
// The following fields are protected by mu. Caller must take cc.mu before
// taking mu.
mu sync.Mutex
curState resolver.State
closed bool
}
// newCCResolverWrapper initializes the ccResolverWrapper. It can only be used
// after calling start, which builds the resolver.
func newCCResolverWrapper(cc *ClientConn) *ccResolverWrapper {
ctx, cancel := context.WithCancel(cc.ctx)
return &ccResolverWrapper{
cc: cc,
ignoreServiceConfig: cc.dopts.disableServiceConfig,
serializer: grpcsync.NewCallbackSerializer(ctx),
serializerCancel: cancel,
}
}
// start builds the name resolver using the resolver.Builder in cc and returns
// any error encountered. It must always be the first operation performed on
// any newly created ccResolverWrapper, except that close may be called instead.
func (ccr *ccResolverWrapper) start() error {
errCh := make(chan error)
ccr.serializer.Schedule(func(ctx context.Context) {
if ctx.Err() != nil {
return
}
opts := resolver.BuildOptions{
DisableServiceConfig: ccr.cc.dopts.disableServiceConfig,
DialCreds: ccr.cc.dopts.copts.TransportCredentials,
CredsBundle: ccr.cc.dopts.copts.CredsBundle,
Dialer: ccr.cc.dopts.copts.Dialer,
}
var err error
ccr.resolver, err = ccr.cc.resolverBuilder.Build(ccr.cc.parsedTarget, ccr, opts)
errCh <- err
})
return <-errCh
}
func (ccr *ccResolverWrapper) resolveNow(o resolver.ResolveNowOptions) {
ccr.serializer.Schedule(func(ctx context.Context) {
if ctx.Err() != nil || ccr.resolver == nil {
return
}
ccr.resolver.ResolveNow(o)
})
}
// close initiates async shutdown of the wrapper. To determine the wrapper has
// finished shutting down, the channel should block on ccr.serializer.Done()
// without cc.mu held.
func (ccr *ccResolverWrapper) close() {
channelz.Info(logger, ccr.cc.channelzID, "Closing the name resolver")
ccr.mu.Lock()
ccr.closed = true
ccr.mu.Unlock()
ccr.serializer.Schedule(func(context.Context) {
if ccr.resolver == nil {
return
}
ccr.resolver.Close()
ccr.resolver = nil
})
ccr.serializerCancel()
}
// UpdateState is called by resolver implementations to report new state to gRPC
// which includes addresses and service config.
func (ccr *ccResolverWrapper) UpdateState(s resolver.State) error {
ccr.cc.mu.Lock()
ccr.mu.Lock()
if ccr.closed {
ccr.mu.Unlock()
ccr.cc.mu.Unlock()
return nil
}
if s.Endpoints == nil {
s.Endpoints = make([]resolver.Endpoint, 0, len(s.Addresses))
for _, a := range s.Addresses {
ep := resolver.Endpoint{Addresses: []resolver.Address{a}, Attributes: a.BalancerAttributes}
ep.Addresses[0].BalancerAttributes = nil
s.Endpoints = append(s.Endpoints, ep)
}
}
ccr.addChannelzTraceEvent(s)
ccr.curState = s
ccr.mu.Unlock()
return ccr.cc.updateResolverStateAndUnlock(s, nil)
}
// ReportError is called by resolver implementations to report errors
// encountered during name resolution to gRPC.
func (ccr *ccResolverWrapper) ReportError(err error) {
ccr.cc.mu.Lock()
ccr.mu.Lock()
if ccr.closed {
ccr.mu.Unlock()
ccr.cc.mu.Unlock()
return
}
ccr.mu.Unlock()
channelz.Warningf(logger, ccr.cc.channelzID, "ccResolverWrapper: reporting error to cc: %v", err)
ccr.cc.updateResolverStateAndUnlock(resolver.State{}, err)
}
// NewAddress is called by the resolver implementation to send addresses to
// gRPC.
func (ccr *ccResolverWrapper) NewAddress(addrs []resolver.Address) {
ccr.cc.mu.Lock()
ccr.mu.Lock()
if ccr.closed {
ccr.mu.Unlock()
ccr.cc.mu.Unlock()
return
}
s := resolver.State{Addresses: addrs, ServiceConfig: ccr.curState.ServiceConfig}
ccr.addChannelzTraceEvent(s)
ccr.curState = s
ccr.mu.Unlock()
ccr.cc.updateResolverStateAndUnlock(s, nil)
}
// ParseServiceConfig is called by resolver implementations to parse a JSON
// representation of the service config.
func (ccr *ccResolverWrapper) ParseServiceConfig(scJSON string) *serviceconfig.ParseResult {
return parseServiceConfig(scJSON)
}
// addChannelzTraceEvent adds a channelz trace event containing the new
// state received from resolver implementations.
func (ccr *ccResolverWrapper) addChannelzTraceEvent(s resolver.State) {
var updates []string
var oldSC, newSC *ServiceConfig
var oldOK, newOK bool
if ccr.curState.ServiceConfig != nil {
oldSC, oldOK = ccr.curState.ServiceConfig.Config.(*ServiceConfig)
}
if s.ServiceConfig != nil {
newSC, newOK = s.ServiceConfig.Config.(*ServiceConfig)
}
if oldOK != newOK || (oldOK && newOK && oldSC.rawJSONString != newSC.rawJSONString) {
updates = append(updates, "service config updated")
}
if len(ccr.curState.Addresses) > 0 && len(s.Addresses) == 0 {
updates = append(updates, "resolver returned an empty address list")
} else if len(ccr.curState.Addresses) == 0 && len(s.Addresses) > 0 {
updates = append(updates, "resolver returned new addresses")
}
channelz.Infof(logger, ccr.cc.channelzID, "Resolver state updated: %s (%v)", pretty.ToJSON(s), strings.Join(updates, "; "))
}

View File

@ -70,6 +70,10 @@ func init() {
internal.GetServerCredentials = func(srv *Server) credentials.TransportCredentials { internal.GetServerCredentials = func(srv *Server) credentials.TransportCredentials {
return srv.opts.creds return srv.opts.creds
} }
internal.IsRegisteredMethod = func(srv *Server, method string) bool {
return srv.isRegisteredMethod(method)
}
internal.ServerFromContext = serverFromContext
internal.DrainServerTransports = func(srv *Server, addr string) { internal.DrainServerTransports = func(srv *Server, addr string) {
srv.drainServerTransports(addr) srv.drainServerTransports(addr)
} }
@ -81,6 +85,7 @@ func init() {
} }
internal.BinaryLogger = binaryLogger internal.BinaryLogger = binaryLogger
internal.JoinServerOptions = newJoinServerOption internal.JoinServerOptions = newJoinServerOption
internal.RecvBufferPool = recvBufferPool
} }
var statusOK = status.New(codes.OK, "") var statusOK = status.New(codes.OK, "")
@ -139,7 +144,8 @@ type Server struct {
channelzID *channelz.Identifier channelzID *channelz.Identifier
czData *channelzData czData *channelzData
serverWorkerChannel chan func() serverWorkerChannel chan func()
serverWorkerChannelClose func()
} }
type serverOptions struct { type serverOptions struct {
@ -578,11 +584,13 @@ func NumStreamWorkers(numServerWorkers uint32) ServerOption {
// options are used: StatsHandler, EnableTracing, or binary logging. In such // options are used: StatsHandler, EnableTracing, or binary logging. In such
// cases, the shared buffer pool will be ignored. // cases, the shared buffer pool will be ignored.
// //
// # Experimental // Deprecated: use experimental.WithRecvBufferPool instead. Will be deleted in
// // v1.60.0 or later.
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func RecvBufferPool(bufferPool SharedBufferPool) ServerOption { func RecvBufferPool(bufferPool SharedBufferPool) ServerOption {
return recvBufferPool(bufferPool)
}
func recvBufferPool(bufferPool SharedBufferPool) ServerOption {
return newFuncServerOption(func(o *serverOptions) { return newFuncServerOption(func(o *serverOptions) {
o.recvBufferPool = bufferPool o.recvBufferPool = bufferPool
}) })
@ -616,15 +624,14 @@ func (s *Server) serverWorker() {
// connections to reduce the time spent overall on runtime.morestack. // connections to reduce the time spent overall on runtime.morestack.
func (s *Server) initServerWorkers() { func (s *Server) initServerWorkers() {
s.serverWorkerChannel = make(chan func()) s.serverWorkerChannel = make(chan func())
s.serverWorkerChannelClose = grpcsync.OnceFunc(func() {
close(s.serverWorkerChannel)
})
for i := uint32(0); i < s.opts.numServerWorkers; i++ { for i := uint32(0); i < s.opts.numServerWorkers; i++ {
go s.serverWorker() go s.serverWorker()
} }
} }
func (s *Server) stopServerWorkers() {
close(s.serverWorkerChannel)
}
// NewServer creates a gRPC server which has no service registered and has not // NewServer creates a gRPC server which has no service registered and has not
// started to accept requests yet. // started to accept requests yet.
func NewServer(opt ...ServerOption) *Server { func NewServer(opt ...ServerOption) *Server {
@ -806,6 +813,18 @@ func (l *listenSocket) Close() error {
// Serve returns when lis.Accept fails with fatal errors. lis will be closed when // Serve returns when lis.Accept fails with fatal errors. lis will be closed when
// this method returns. // this method returns.
// Serve will return a non-nil error unless Stop or GracefulStop is called. // Serve will return a non-nil error unless Stop or GracefulStop is called.
//
// Note: All supported releases of Go (as of December 2023) override the OS
// defaults for TCP keepalive time and interval to 15s. To enable TCP keepalive
// with OS defaults for keepalive time and interval, callers need to do the
// following two things:
// - pass a net.Listener created by calling the Listen method on a
// net.ListenConfig with the `KeepAlive` field set to a negative value. This
// will result in the Go standard library not overriding OS defaults for TCP
// keepalive interval and time. But this will also result in the Go standard
// library not enabling TCP keepalives by default.
// - override the Accept method on the passed in net.Listener and set the
// SO_KEEPALIVE socket option to enable TCP keepalives, with OS defaults.
func (s *Server) Serve(lis net.Listener) error { func (s *Server) Serve(lis net.Listener) error {
s.mu.Lock() s.mu.Lock()
s.printf("serving") s.printf("serving")
@ -917,7 +936,7 @@ func (s *Server) handleRawConn(lisAddr string, rawConn net.Conn) {
return return
} }
go func() { go func() {
s.serveStreams(st) s.serveStreams(context.Background(), st, rawConn)
s.removeConn(lisAddr, st) s.removeConn(lisAddr, st)
}() }()
} }
@ -971,18 +990,29 @@ func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport {
return st return st
} }
func (s *Server) serveStreams(st transport.ServerTransport) { func (s *Server) serveStreams(ctx context.Context, st transport.ServerTransport, rawConn net.Conn) {
defer st.Close(errors.New("finished serving streams for the server transport")) ctx = transport.SetConnection(ctx, rawConn)
var wg sync.WaitGroup ctx = peer.NewContext(ctx, st.Peer())
for _, sh := range s.opts.statsHandlers {
ctx = sh.TagConn(ctx, &stats.ConnTagInfo{
RemoteAddr: st.Peer().Addr,
LocalAddr: st.Peer().LocalAddr,
})
sh.HandleConn(ctx, &stats.ConnBegin{})
}
defer func() {
st.Close(errors.New("finished serving streams for the server transport"))
for _, sh := range s.opts.statsHandlers {
sh.HandleConn(ctx, &stats.ConnEnd{})
}
}()
streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams) streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams)
st.HandleStreams(func(stream *transport.Stream) { st.HandleStreams(ctx, func(stream *transport.Stream) {
wg.Add(1)
streamQuota.acquire() streamQuota.acquire()
f := func() { f := func() {
defer streamQuota.release() defer streamQuota.release()
defer wg.Done()
s.handleStream(st, stream) s.handleStream(st, stream)
} }
@ -996,7 +1026,6 @@ func (s *Server) serveStreams(st transport.ServerTransport) {
} }
go f() go f()
}) })
wg.Wait()
} }
var _ http.Handler = (*Server)(nil) var _ http.Handler = (*Server)(nil)
@ -1040,7 +1069,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
defer s.removeConn(listenerAddressForServeHTTP, st) defer s.removeConn(listenerAddressForServeHTTP, st)
s.serveStreams(st) s.serveStreams(r.Context(), st, nil)
} }
func (s *Server) addConn(addr string, st transport.ServerTransport) bool { func (s *Server) addConn(addr string, st transport.ServerTransport) bool {
@ -1689,6 +1718,7 @@ func (s *Server) processStreamingRPC(ctx context.Context, t transport.ServerTran
func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream) { func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream) {
ctx := stream.Context() ctx := stream.Context()
ctx = contextWithServer(ctx, s)
var ti *traceInfo var ti *traceInfo
if EnableTracing { if EnableTracing {
tr := trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method()) tr := trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method())
@ -1697,7 +1727,7 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
tr: tr, tr: tr,
firstLine: firstLine{ firstLine: firstLine{
client: false, client: false,
remoteAddr: t.RemoteAddr(), remoteAddr: t.Peer().Addr,
}, },
} }
if dl, ok := ctx.Deadline(); ok { if dl, ok := ctx.Deadline(); ok {
@ -1731,6 +1761,22 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
service := sm[:pos] service := sm[:pos]
method := sm[pos+1:] method := sm[pos+1:]
md, _ := metadata.FromIncomingContext(ctx)
for _, sh := range s.opts.statsHandlers {
ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: stream.Method()})
sh.HandleRPC(ctx, &stats.InHeader{
FullMethod: stream.Method(),
RemoteAddr: t.Peer().Addr,
LocalAddr: t.Peer().LocalAddr,
Compression: stream.RecvCompress(),
WireLength: stream.HeaderWireLength(),
Header: md,
})
}
// To have calls in stream callouts work. Will delete once all stats handler
// calls come from the gRPC layer.
stream.SetContext(ctx)
srv, knownService := s.services[service] srv, knownService := s.services[service]
if knownService { if knownService {
if md, ok := srv.methods[method]; ok { if md, ok := srv.methods[method]; ok {
@ -1820,62 +1866,68 @@ func ServerTransportStreamFromContext(ctx context.Context) ServerTransportStream
// pending RPCs on the client side will get notified by connection // pending RPCs on the client side will get notified by connection
// errors. // errors.
func (s *Server) Stop() { func (s *Server) Stop() {
s.quit.Fire() s.stop(false)
defer func() {
s.serveWG.Wait()
s.done.Fire()
}()
s.channelzRemoveOnce.Do(func() { channelz.RemoveEntry(s.channelzID) })
s.mu.Lock()
listeners := s.lis
s.lis = nil
conns := s.conns
s.conns = nil
// interrupt GracefulStop if Stop and GracefulStop are called concurrently.
s.cv.Broadcast()
s.mu.Unlock()
for lis := range listeners {
lis.Close()
}
for _, cs := range conns {
for st := range cs {
st.Close(errors.New("Server.Stop called"))
}
}
if s.opts.numServerWorkers > 0 {
s.stopServerWorkers()
}
s.mu.Lock()
if s.events != nil {
s.events.Finish()
s.events = nil
}
s.mu.Unlock()
} }
// GracefulStop stops the gRPC server gracefully. It stops the server from // GracefulStop stops the gRPC server gracefully. It stops the server from
// accepting new connections and RPCs and blocks until all the pending RPCs are // accepting new connections and RPCs and blocks until all the pending RPCs are
// finished. // finished.
func (s *Server) GracefulStop() { func (s *Server) GracefulStop() {
s.stop(true)
}
func (s *Server) stop(graceful bool) {
s.quit.Fire() s.quit.Fire()
defer s.done.Fire() defer s.done.Fire()
s.channelzRemoveOnce.Do(func() { channelz.RemoveEntry(s.channelzID) }) s.channelzRemoveOnce.Do(func() { channelz.RemoveEntry(s.channelzID) })
s.mu.Lock() s.mu.Lock()
if s.conns == nil { s.closeListenersLocked()
s.mu.Unlock() // Wait for serving threads to be ready to exit. Only then can we be sure no
return // new conns will be created.
s.mu.Unlock()
s.serveWG.Wait()
s.mu.Lock()
defer s.mu.Unlock()
if graceful {
s.drainAllServerTransportsLocked()
} else {
s.closeServerTransportsLocked()
} }
for lis := range s.lis { for len(s.conns) != 0 {
lis.Close() s.cv.Wait()
} }
s.lis = nil s.conns = nil
if s.opts.numServerWorkers > 0 {
// Closing the channel (only once, via grpcsync.OnceFunc) after all the
// connections have been closed above ensures that there are no
// goroutines executing the callback passed to st.HandleStreams (where
// the channel is written to).
s.serverWorkerChannelClose()
}
if s.events != nil {
s.events.Finish()
s.events = nil
}
}
// s.mu must be held by the caller.
func (s *Server) closeServerTransportsLocked() {
for _, conns := range s.conns {
for st := range conns {
st.Close(errors.New("Server.Stop called"))
}
}
}
// s.mu must be held by the caller.
func (s *Server) drainAllServerTransportsLocked() {
if !s.drain { if !s.drain {
for _, conns := range s.conns { for _, conns := range s.conns {
for st := range conns { for st := range conns {
@ -1884,22 +1936,14 @@ func (s *Server) GracefulStop() {
} }
s.drain = true s.drain = true
} }
}
// Wait for serving threads to be ready to exit. Only then can we be sure no // s.mu must be held by the caller.
// new conns will be created. func (s *Server) closeListenersLocked() {
s.mu.Unlock() for lis := range s.lis {
s.serveWG.Wait() lis.Close()
s.mu.Lock()
for len(s.conns) != 0 {
s.cv.Wait()
} }
s.conns = nil s.lis = nil
if s.events != nil {
s.events.Finish()
s.events = nil
}
s.mu.Unlock()
} }
// contentSubtype must be lowercase // contentSubtype must be lowercase
@ -1913,11 +1957,50 @@ func (s *Server) getCodec(contentSubtype string) baseCodec {
} }
codec := encoding.GetCodec(contentSubtype) codec := encoding.GetCodec(contentSubtype)
if codec == nil { if codec == nil {
logger.Warningf("Unsupported codec %q. Defaulting to %q for now. This will start to fail in future releases.", contentSubtype, proto.Name)
return encoding.GetCodec(proto.Name) return encoding.GetCodec(proto.Name)
} }
return codec return codec
} }
type serverKey struct{}
// serverFromContext gets the Server from the context.
func serverFromContext(ctx context.Context) *Server {
s, _ := ctx.Value(serverKey{}).(*Server)
return s
}
// contextWithServer sets the Server in the context.
func contextWithServer(ctx context.Context, server *Server) context.Context {
return context.WithValue(ctx, serverKey{}, server)
}
// isRegisteredMethod returns whether the passed in method is registered as a
// method on the server. /service/method and service/method will match if the
// service and method are registered on the server.
func (s *Server) isRegisteredMethod(serviceMethod string) bool {
if serviceMethod != "" && serviceMethod[0] == '/' {
serviceMethod = serviceMethod[1:]
}
pos := strings.LastIndex(serviceMethod, "/")
if pos == -1 { // Invalid method name syntax.
return false
}
service := serviceMethod[:pos]
method := serviceMethod[pos+1:]
srv, knownService := s.services[service]
if knownService {
if _, ok := srv.methods[method]; ok {
return true
}
if _, ok := srv.streams[method]; ok {
return true
}
}
return false
}
// SetHeader sets the header metadata to be sent from the server to the client. // SetHeader sets the header metadata to be sent from the server to the client.
// The context provided must be the context passed to the server's handler. // The context provided must be the context passed to the server's handler.
// //

View File

@ -19,4 +19,4 @@
package grpc package grpc
// Version is the current grpc version. // Version is the current grpc version.
const Version = "1.59.0" const Version = "1.60.1"

166
vendor/google.golang.org/grpc/vet.sh generated vendored
View File

@ -35,7 +35,6 @@ if [[ "$1" = "-install" ]]; then
# Install the pinned versions as defined in module tools. # Install the pinned versions as defined in module tools.
pushd ./test/tools pushd ./test/tools
go install \ go install \
golang.org/x/lint/golint \
golang.org/x/tools/cmd/goimports \ golang.org/x/tools/cmd/goimports \
honnef.co/go/tools/cmd/staticcheck \ honnef.co/go/tools/cmd/staticcheck \
github.com/client9/misspell/cmd/misspell github.com/client9/misspell/cmd/misspell
@ -77,12 +76,16 @@ fi
not grep 'func Test[^(]' *_test.go not grep 'func Test[^(]' *_test.go
not grep 'func Test[^(]' test/*.go not grep 'func Test[^(]' test/*.go
# - Check for typos in test function names
git grep 'func (s) ' -- "*_test.go" | not grep -v 'func (s) Test'
git grep 'func [A-Z]' -- "*_test.go" | not grep -v 'func Test\|Benchmark\|Example'
# - Do not import x/net/context. # - Do not import x/net/context.
not git grep -l 'x/net/context' -- "*.go" not git grep -l 'x/net/context' -- "*.go"
# - Do not import math/rand for real library code. Use internal/grpcrand for # - Do not import math/rand for real library code. Use internal/grpcrand for
# thread safety. # thread safety.
git grep -l '"math/rand"' -- "*.go" 2>&1 | not grep -v '^examples\|^stress\|grpcrand\|^benchmark\|wrr_test' git grep -l '"math/rand"' -- "*.go" 2>&1 | not grep -v '^examples\|^interop/stress\|grpcrand\|^benchmark\|wrr_test'
# - Do not use "interface{}"; use "any" instead. # - Do not use "interface{}"; use "any" instead.
git grep -l 'interface{}' -- "*.go" 2>&1 | not grep -v '\.pb\.go\|protoc-gen-go-grpc' git grep -l 'interface{}' -- "*.go" 2>&1 | not grep -v '\.pb\.go\|protoc-gen-go-grpc'
@ -94,15 +97,14 @@ git grep -l -e 'grpclog.I' --or -e 'grpclog.W' --or -e 'grpclog.E' --or -e 'grpc
not git grep "\(import \|^\s*\)\"github.com/golang/protobuf/ptypes/" -- "*.go" not git grep "\(import \|^\s*\)\"github.com/golang/protobuf/ptypes/" -- "*.go"
# - Ensure all usages of grpc_testing package are renamed when importing. # - Ensure all usages of grpc_testing package are renamed when importing.
not git grep "\(import \|^\s*\)\"google.golang.org/grpc/interop/grpc_testing" -- "*.go" not git grep "\(import \|^\s*\)\"google.golang.org/grpc/interop/grpc_testing" -- "*.go"
# - Ensure all xds proto imports are renamed to *pb or *grpc. # - Ensure all xds proto imports are renamed to *pb or *grpc.
git grep '"github.com/envoyproxy/go-control-plane/envoy' -- '*.go' ':(exclude)*.pb.go' | not grep -v 'pb "\|grpc "' git grep '"github.com/envoyproxy/go-control-plane/envoy' -- '*.go' ':(exclude)*.pb.go' | not grep -v 'pb "\|grpc "'
misspell -error . misspell -error .
# - gofmt, goimports, golint (with exceptions for generated code), go vet, # - gofmt, goimports, go vet, go mod tidy.
# go mod tidy.
# Perform these checks on each module inside gRPC. # Perform these checks on each module inside gRPC.
for MOD_FILE in $(find . -name 'go.mod'); do for MOD_FILE in $(find . -name 'go.mod'); do
MOD_DIR=$(dirname ${MOD_FILE}) MOD_DIR=$(dirname ${MOD_FILE})
@ -110,7 +112,6 @@ for MOD_FILE in $(find . -name 'go.mod'); do
go vet -all ./... | fail_on_output go vet -all ./... | fail_on_output
gofmt -s -d -l . 2>&1 | fail_on_output gofmt -s -d -l . 2>&1 | fail_on_output
goimports -l . 2>&1 | not grep -vE "\.pb\.go" goimports -l . 2>&1 | not grep -vE "\.pb\.go"
golint ./... 2>&1 | not grep -vE "/grpc_testing_not_regenerate/.*\.pb\.go:"
go mod tidy -compat=1.19 go mod tidy -compat=1.19
git status --porcelain 2>&1 | fail_on_output || \ git status --porcelain 2>&1 | fail_on_output || \
@ -119,94 +120,73 @@ for MOD_FILE in $(find . -name 'go.mod'); do
done done
# - Collection of static analysis checks # - Collection of static analysis checks
#
# TODO(dfawley): don't use deprecated functions in examples or first-party
# plugins.
# TODO(dfawley): enable ST1019 (duplicate imports) but allow for protobufs.
SC_OUT="$(mktemp)" SC_OUT="$(mktemp)"
staticcheck -go 1.19 -checks 'inherit,-ST1015,-ST1019,-SA1019' ./... > "${SC_OUT}" || true staticcheck -go 1.19 -checks 'all' ./... > "${SC_OUT}" || true
# Error if anything other than deprecation warnings are printed.
not grep -v "is deprecated:.*SA1019" "${SC_OUT}" # Error for anything other than checks that need exclusions.
# Only ignore the following deprecated types/fields/functions. grep -v "(ST1000)" "${SC_OUT}" | grep -v "(SA1019)" | grep -v "(ST1003)" | not grep -v "(ST1019)\|\(other import of\)"
not grep -Fv '.CredsBundle
.HeaderMap # Exclude underscore checks for generated code.
.Metadata is deprecated: use Attributes grep "(ST1003)" "${SC_OUT}" | not grep -v '\(.pb.go:\)\|\(code_string_test.go:\)'
.NewAddress
.NewServiceConfig # Error for duplicate imports not including grpc protos.
.Type is deprecated: use Attributes grep "(ST1019)\|\(other import of\)" "${SC_OUT}" | not grep -Fv 'XXXXX PleaseIgnoreUnused
BuildVersion is deprecated channelz/grpc_channelz_v1"
balancer.ErrTransientFailure go-control-plane/envoy
balancer.Picker grpclb/grpc_lb_v1"
extDesc.Filename is deprecated health/grpc_health_v1"
github.com/golang/protobuf/jsonpb is deprecated interop/grpc_testing"
grpc.CallCustomCodec orca/v3"
grpc.Code proto/grpc_gcp"
grpc.Compressor proto/grpc_lookup_v1"
grpc.CustomCodec reflection/grpc_reflection_v1"
grpc.Decompressor reflection/grpc_reflection_v1alpha"
grpc.MaxMsgSize XXXXX PleaseIgnoreUnused'
grpc.MethodConfig
grpc.NewGZIPCompressor # Error for any package comments not in generated code.
grpc.NewGZIPDecompressor grep "(ST1000)" "${SC_OUT}" | not grep -v "\.pb\.go:"
grpc.RPCCompressor
grpc.RPCDecompressor # Only ignore the following deprecated types/fields/functions and exclude
grpc.ServiceConfig # generated code.
grpc.WithCompressor grep "(SA1019)" "${SC_OUT}" | not grep -Fv 'XXXXX PleaseIgnoreUnused
grpc.WithDecompressor XXXXX Protobuf related deprecation errors:
grpc.WithDialer "github.com/golang/protobuf
grpc.WithMaxMsgSize .pb.go:
grpc.WithServiceConfig : ptypes.
grpc.WithTimeout proto.RegisterType
http.CloseNotifier XXXXX gRPC internal usage deprecation errors:
info.SecurityVersion "google.golang.org/grpc
proto is deprecated : grpc.
proto.InternalMessageInfo is deprecated : v1alpha.
proto.EnumName is deprecated : v1alphareflectionpb.
proto.ErrInternalBadWireType is deprecated BalancerAttributes is deprecated:
proto.FileDescriptor is deprecated CredsBundle is deprecated:
proto.Marshaler is deprecated Metadata is deprecated: use Attributes instead.
proto.MessageType is deprecated NewSubConn is deprecated:
proto.RegisterEnum is deprecated OverrideServerName is deprecated:
proto.RegisterFile is deprecated RemoveSubConn is deprecated:
proto.RegisterType is deprecated SecurityVersion is deprecated:
proto.RegisterExtension is deprecated
proto.RegisteredExtension is deprecated
proto.RegisteredExtensions is deprecated
proto.RegisterMapType is deprecated
proto.Unmarshaler is deprecated
Target is deprecated: Use the Target field in the BuildOptions instead. Target is deprecated: Use the Target field in the BuildOptions instead.
xxx_messageInfo_ UpdateAddresses is deprecated:
' "${SC_OUT}" UpdateSubConnState is deprecated:
balancer.ErrTransientFailure is deprecated:
# - special golint on package comments. grpc/reflection/v1alpha/reflection.proto
lint_package_comment_per_package() { XXXXX xDS deprecated fields we support
# Number of files in this go package. .ExactMatch
fileCount=$(go list -f '{{len .GoFiles}}' $1) .PrefixMatch
if [ ${fileCount} -eq 0 ]; then .SafeRegexMatch
return 0 .SuffixMatch
fi GetContainsMatch
# Number of package errors generated by golint. GetExactMatch
lintPackageCommentErrorsCount=$(golint --min_confidence 0 $1 | grep -c "should have a package comment") GetMatchSubjectAltNames
# golint complains about every file that's missing the package comment. If the GetPrefixMatch
# number of files for this package is greater than the number of errors, there's GetSafeRegexMatch
# at least one file with package comment, good. Otherwise, fail. GetSuffixMatch
if [ ${fileCount} -le ${lintPackageCommentErrorsCount} ]; then GetTlsCertificateCertificateProviderInstance
echo "Package $1 (with ${fileCount} files) is missing package comment" GetValidationContextCertificateProviderInstance
return 1 XXXXX TODO: Remove the below deprecation usages:
fi CloseNotifier
} Roots.Subjects
lint_package_comment() { XXXXX PleaseIgnoreUnused'
set +ex
count=0
for i in $(go list ./...); do
lint_package_comment_per_package "$i"
((count += $?))
done
set -ex
return $count
}
lint_package_comment
echo SUCCESS echo SUCCESS

8
vendor/modules.txt vendored
View File

@ -676,7 +676,7 @@ golang.org/x/net/internal/timeseries
golang.org/x/net/proxy golang.org/x/net/proxy
golang.org/x/net/trace golang.org/x/net/trace
golang.org/x/net/websocket golang.org/x/net/websocket
# golang.org/x/oauth2 v0.12.0 # golang.org/x/oauth2 v0.13.0
## explicit; go 1.18 ## explicit; go 1.18
golang.org/x/oauth2 golang.org/x/oauth2
golang.org/x/oauth2/internal golang.org/x/oauth2/internal
@ -735,7 +735,7 @@ golang.org/x/tools/internal/typeparams
# gomodules.xyz/jsonpatch/v2 v2.4.0 => github.com/gomodules/jsonpatch/v2 v2.2.0 # gomodules.xyz/jsonpatch/v2 v2.4.0 => github.com/gomodules/jsonpatch/v2 v2.2.0
## explicit; go 1.12 ## explicit; go 1.12
gomodules.xyz/jsonpatch/v2 gomodules.xyz/jsonpatch/v2
# google.golang.org/appengine v1.6.7 # google.golang.org/appengine v1.6.8
## explicit; go 1.11 ## explicit; go 1.11
google.golang.org/appengine/internal google.golang.org/appengine/internal
google.golang.org/appengine/internal/base google.golang.org/appengine/internal/base
@ -757,7 +757,7 @@ google.golang.org/genproto/googleapis/api/httpbody
## explicit; go 1.19 ## explicit; go 1.19
google.golang.org/genproto/googleapis/rpc/errdetails google.golang.org/genproto/googleapis/rpc/errdetails
google.golang.org/genproto/googleapis/rpc/status google.golang.org/genproto/googleapis/rpc/status
# google.golang.org/grpc v1.59.0 # google.golang.org/grpc v1.60.1
## explicit; go 1.19 ## explicit; go 1.19
google.golang.org/grpc google.golang.org/grpc
google.golang.org/grpc/attributes google.golang.org/grpc/attributes
@ -795,6 +795,7 @@ google.golang.org/grpc/internal/metadata
google.golang.org/grpc/internal/pretty google.golang.org/grpc/internal/pretty
google.golang.org/grpc/internal/resolver google.golang.org/grpc/internal/resolver
google.golang.org/grpc/internal/resolver/dns google.golang.org/grpc/internal/resolver/dns
google.golang.org/grpc/internal/resolver/dns/internal
google.golang.org/grpc/internal/resolver/passthrough google.golang.org/grpc/internal/resolver/passthrough
google.golang.org/grpc/internal/resolver/unix google.golang.org/grpc/internal/resolver/unix
google.golang.org/grpc/internal/serviceconfig google.golang.org/grpc/internal/serviceconfig
@ -806,6 +807,7 @@ google.golang.org/grpc/keepalive
google.golang.org/grpc/metadata google.golang.org/grpc/metadata
google.golang.org/grpc/peer google.golang.org/grpc/peer
google.golang.org/grpc/resolver google.golang.org/grpc/resolver
google.golang.org/grpc/resolver/dns
google.golang.org/grpc/resolver/manual google.golang.org/grpc/resolver/manual
google.golang.org/grpc/serviceconfig google.golang.org/grpc/serviceconfig
google.golang.org/grpc/stats google.golang.org/grpc/stats