Vendor cleanup

Signed-off-by: Madhu Rajanna <mrajanna@redhat.com>
This commit is contained in:
Madhu Rajanna
2019-01-16 18:11:54 +05:30
parent 661818bd79
commit 0f836c62fa
16816 changed files with 20 additions and 4611100 deletions

View File

@ -1,14 +0,0 @@
Please answer these questions before submitting your issue.
### What version of gRPC are you using?
### What version of Go are you using (`go version`)?
### What operating system (Linux, Windows, …) and version?
### What did you do?
If possible, provide a recipe for reproducing the error.
### What did you expect to see?
### What did you see instead?

View File

@ -1,2 +0,0 @@
daysUntilLock: 180
lockComment: false

View File

@ -1,37 +0,0 @@
language: go
matrix:
include:
- go: 1.11.x
env: VET=1 GO111MODULE=on
- go: 1.11.x
env: RACE=1 GO111MODULE=on
- go: 1.11.x
env: RUN386=1
- go: 1.11.x
env: GRPC_GO_RETRY=on
- go: 1.10.x
- go: 1.9.x
- go: 1.9.x
env: GAE=1
go_import_path: google.golang.org/grpc
before_install:
- if [[ "${GO111MODULE}" = "on" ]]; then mkdir "${HOME}/go"; export GOPATH="${HOME}/go"; fi
- if [[ -n "${RUN386}" ]]; then export GOARCH=386; fi
- if [[ "${TRAVIS_EVENT_TYPE}" = "cron" && -z "${RUN386}" ]]; then RACE=1; fi
- if [[ "${TRAVIS_EVENT_TYPE}" != "cron" ]]; then VET_SKIP_PROTO=1; fi
install:
- try3() { eval "$*" || eval "$*" || eval "$*"; }
- try3 'if [[ "${GO111MODULE}" = "on" ]]; then go mod download; else make testdeps; fi'
- if [[ "${GAE}" = 1 ]]; then source ./install_gae.sh; make testappenginedeps; fi
- if [[ "${VET}" = 1 ]]; then ./vet.sh -install; fi
script:
- set -e
- if [[ "${VET}" = 1 ]]; then ./vet.sh; fi
- if [[ "${GAE}" = 1 ]]; then make testappengine; exit 0; fi
- if [[ "${RACE}" = 1 ]]; then make testrace; exit 0; fi
- make test

View File

@ -1,36 +0,0 @@
# How to contribute
We definitely welcome your patches and contributions to gRPC!
If you are new to GitHub, please start by reading the [Pull Request howto](https://help.github.com/articles/about-pull-requests/).
## Legal requirements
In order to protect both you and ourselves, you will need to sign the
[Contributor License Agreement](https://identity.linuxfoundation.org/projects/cncf).
## Guidelines for Pull Requests
How to get your contributions merged smoothly and quickly.
- Create **small PRs** that are narrowly focused on **addressing a single concern**. We often receive PRs that try to fix several things at once; if only one fix is considered acceptable, nothing gets merged, and both the author's and the reviewers' time is wasted. Create separate PRs to address different concerns and everyone will be happy.
- For speculative changes, consider opening an issue and discussing it first. If you are suggesting a behavioral or API change, consider starting with a [gRFC proposal](https://github.com/grpc/proposal).
- Provide a good **PR description** as a record of **what** change is being made and **why** it was made. Link to a github issue if it exists.
- Don't fix code style and formatting unless you are already changing that line to address an issue. PRs with irrelevant changes won't be merged. If you do want to fix formatting or style, do that in a separate PR.
- Unless your PR is trivial, you should expect there will be reviewer comments that you'll need to address before merging. We expect you to be reasonably responsive to those comments, otherwise the PR will be closed after 2-3 weeks of inactivity.
- Maintain **clean commit history** and use **meaningful commit messages**. PRs with messy commit history are difficult to review and won't be merged. Use `rebase -i upstream/master` to curate your commit history and/or to bring in latest changes from master (but avoid rebasing in the middle of a code review).
- Keep your PR up to date with upstream/master (if there are merge conflicts, we can't really merge your change).
- **All tests need to be passing** before your change can be merged. We recommend you **run tests locally** before creating your PR to catch breakages early on.
- `make all` to test everything, OR
- `make vet` to catch vet errors
- `make test` to run the tests
- `make testrace` to run tests in race mode
- Exceptions to the rules can be made if there's a compelling reason for doing so.

View File

@ -1,80 +0,0 @@
# Compression
The preferred method for configuring message compression on both clients and
servers is to use
[`encoding.RegisterCompressor`](https://godoc.org/google.golang.org/grpc/encoding#RegisterCompressor)
to register an implementation of a compression algorithm. See
`grpc/encoding/gzip/gzip.go` for an example of how to implement one.
Once a compressor has been registered on the client-side, RPCs may be sent using
it via the
[`UseCompressor`](https://godoc.org/google.golang.org/grpc#UseCompressor)
`CallOption`. Remember that `CallOption`s may be turned into defaults for all
calls from a `ClientConn` by using the
[`WithDefaultCallOptions`](https://godoc.org/google.golang.org/grpc#WithDefaultCallOptions)
`DialOption`. If `UseCompressor` is used and the corresponding compressor has
not been installed, an `Internal` error will be returned to the application
before the RPC is sent.
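As a concrete illustration of the preferred client-side setup, here is a minimal sketch (the address, the `WithInsecure` transport setting, and the choice of the registered `gzip` compressor are illustrative assumptions):
```go
package main

import (
	"log"

	"google.golang.org/grpc"
	// Importing the gzip package registers the "gzip" compressor.
	_ "google.golang.org/grpc/encoding/gzip"
)

func main() {
	// Make "gzip" the default compressor for every RPC on this connection.
	conn, err := grpc.Dial("localhost:50051",
		grpc.WithInsecure(), // placeholder transport settings
		grpc.WithDefaultCallOptions(grpc.UseCompressor("gzip")),
	)
	if err != nil {
		log.Fatalf("failed to dial: %v", err)
	}
	defer conn.Close()
	// Generated client stubs created from conn will now compress outgoing messages.
}
```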
Server-side, registered compressors are used automatically to decompress request
messages and compress the responses. Servers currently always respond using the
same compression method specified by the client. If the corresponding
compressor has not been registered, an `Unimplemented` status will be returned
to the client.
## Deprecated API
There is a deprecated API for setting compression as well. It is not
recommended for use. However, if you were previously using it, the following
section may be helpful in understanding how it works in combination with the new
API.
### Client-Side
There are two legacy functions and one new function to configure compression:
```go
func WithCompressor(grpc.Compressor) DialOption {}
func WithDecompressor(grpc.Decompressor) DialOption {}
func UseCompressor(name string) CallOption {}
```
For outgoing requests, the following rules are applied in order:
1. If `UseCompressor` is used, messages will be compressed using the compressor
named.
* If the named compressor is not registered, an `Internal` error is returned
to the application before the RPC is sent.
* If `UseCompressor("identity")` is used, no compressor will be used, but "identity"
will be sent in the header to the server.
1. If `WithCompressor` is used, messages will be compressed using that
compressor implementation.
1. Otherwise, outbound messages will be uncompressed.
For incoming responses, the following rules are applied in order:
1. If `WithDecompressor` is used and it matches the message's encoding, it will
be used.
1. If a registered compressor matches the response's encoding, it will be used.
1. Otherwise, the stream will be closed and an `Unimplemented` status error will
be returned to the application.
### Server-Side
There are two legacy functions to configure compression:
```go
func RPCCompressor(grpc.Compressor) ServerOption {}
func RPCDecompressor(grpc.Decompressor) ServerOption {}
```
For incoming requests, the following rules are applied in order:
1. If `RPCDecompressor` is used and that decompressor matches the request's
encoding: it will be used.
1. If a registered compressor matches the request's encoding, it will be used.
1. Otherwise, an `Unimplemented` status will be returned to the client.
For outgoing responses, the following rules are applied in order:
1. If `RPCCompressor` is used, that compressor will be used to compress all
response messages.
1. If compression was used for the incoming request and a registered compressor
supports it, that same compression method will be used for the outgoing
response.
1. Otherwise, no compression will be used for the outgoing response.

View File

@ -1,33 +0,0 @@
# Concurrency
In general, gRPC-go provides a concurrency-friendly API. What follows are some
guidelines.
## Clients
A [ClientConn][client-conn] can safely be accessed concurrently. Using
[helloworld][helloworld] as an example, one could share the `ClientConn` across
multiple goroutines to create multiple `GreeterClient` stubs. In this case, RPCs
would be sent in parallel.
## Streams
When using streams, one must take care to avoid calling either `SendMsg` or
`RecvMsg` multiple times against the same [Stream][stream] from different
goroutines. It is, however, safe to have one goroutine calling `SendMsg` and
another goroutine calling `RecvMsg` on the same stream at the same time. But it
is not safe to call `SendMsg` on the same stream in different goroutines, or to
call `RecvMsg` on the same stream in different goroutines.
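For illustration, here is a minimal sketch of the safe pattern. The `ChatMsg` type and `ChatStream` interface below are hypothetical stand-ins for generated bidirectional-stream code:
```go
package main

import (
	"io"
	"log"
)

// ChatMsg and ChatStream stand in for types produced by the protobuf code generator.
type ChatMsg struct{ Text string }

type ChatStream interface {
	Send(*ChatMsg) error
	Recv() (*ChatMsg, error)
	CloseSend() error
}

// chat sends from one goroutine and receives from another; neither Send nor
// Recv is ever called concurrently with itself, which is the safe pattern.
func chat(stream ChatStream, outgoing []*ChatMsg) {
	go func() {
		for _, m := range outgoing {
			if err := stream.Send(m); err != nil {
				log.Printf("send error: %v", err)
				return
			}
		}
		stream.CloseSend()
	}()
	for {
		in, err := stream.Recv()
		if err == io.EOF {
			return
		}
		if err != nil {
			log.Printf("recv error: %v", err)
			return
		}
		log.Printf("received: %q", in.Text)
	}
}

func main() {}
```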
## Servers
Each RPC handler attached to a registered server will be invoked in its own
goroutine. For example, [SayHello][say-hello] will be invoked in its own
goroutine. The same is true for service handlers for streaming RPCs, as seen
in the route guide example [here][route-guide-stream].
[helloworld]: https://github.com/grpc/grpc-go/blob/master/examples/helloworld/greeter_client/main.go#L43
[client-conn]: https://godoc.org/google.golang.org/grpc#ClientConn
[stream]: https://godoc.org/google.golang.org/grpc#Stream
[say-hello]: https://github.com/grpc/grpc-go/blob/master/examples/helloworld/greeter_server/main.go#L41
[route-guide-stream]: https://github.com/grpc/grpc-go/blob/master/examples/route_guide/server/server.go#L126

View File

@ -1,146 +0,0 @@
# Encoding
The gRPC API for sending and receiving is based upon *messages*. However,
messages cannot be transmitted directly over a network; they must first be
converted into *bytes*. This document describes how gRPC-Go converts messages
into bytes and vice-versa for the purposes of network transmission.
## Codecs (Serialization and Deserialization)
A `Codec` contains code to serialize a message into a byte slice (`Marshal`) and
deserialize a byte slice back into a message (`Unmarshal`). `Codec`s are
registered by name into a global registry maintained in the `encoding` package.
### Implementing a `Codec`
A typical `Codec` will be implemented in its own package with an `init` function
that registers itself, and is imported anonymously. For example:
```go
package proto
import "google.golang.org/grpc/encoding"
func init() {
encoding.RegisterCodec(protoCodec{})
}
// ... implementation of protoCodec ...
```
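A sketch of what the implementation itself might look like is below. This is illustrative rather than the library's canonical codec, and it assumes `fmt` and `github.com/golang/protobuf/proto` are imported alongside the `encoding` package shown above:
```go
type protoCodec struct{}

// Marshal serializes a proto.Message into bytes.
func (protoCodec) Marshal(v interface{}) ([]byte, error) {
	msg, ok := v.(proto.Message)
	if !ok {
		return nil, fmt.Errorf("cannot marshal %T: not a proto.Message", v)
	}
	return proto.Marshal(msg)
}

// Unmarshal deserializes bytes into a proto.Message.
func (protoCodec) Unmarshal(data []byte, v interface{}) error {
	msg, ok := v.(proto.Message)
	if !ok {
		return fmt.Errorf("cannot unmarshal into %T: not a proto.Message", v)
	}
	return proto.Unmarshal(data, msg)
}

// Name is the codec name carried in the content-type header
// (application/grpc+<name>).
func (protoCodec) Name() string {
	return "proto"
}
```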
For an example, gRPC's implementation of the `proto` codec can be found in
[`encoding/proto`](https://godoc.org/google.golang.org/grpc/encoding/proto).
### Using a `Codec`
By default, gRPC registers and uses the "proto" codec, so it is not necessary to
do this in your own code to send and receive proto messages. To use another
`Codec` from a client or server:
```go
package myclient
import _ "path/to/another/codec"
```
`Codec`s, by definition, must be symmetric, so the same desired `Codec` should
be registered in both client and server binaries.
On the client-side, to specify a `Codec` to use for message transmission, the
`CallOption` `CallContentSubtype` should be used as follows:
```go
response, err := myclient.MyCall(ctx, request, grpc.CallContentSubtype("mycodec"))
```
As a reminder, all `CallOption`s may be converted into `DialOption`s that become
the default for all RPCs sent through a client using `grpc.WithDefaultCallOptions`:
```go
myclient, err := grpc.Dial(target, grpc.WithDefaultCallOptions(grpc.CallContentSubtype("mycodec")))
```
When specified in either of these ways, messages will be encoded using this
codec and sent along with headers indicating the codec (`content-type` set to
`application/grpc+<codec name>`).
On the server-side, using a `Codec` is as simple as registering it into the
global registry (i.e. `import`ing it). If a message is encoded with the content
sub-type supported by a registered `Codec`, it will be used automatically for
decoding the request and encoding the response. Otherwise, for
backward-compatibility reasons, gRPC will attempt to use the "proto" codec. In
an upcoming change (tracked in [this
issue](https://github.com/grpc/grpc-go/issues/1824)), such requests will be
rejected with status code `Unimplemented` instead.
## Compressors (Compression and Decompression)
Sometimes, the resulting serialization of a message is not space-efficient, and
it may be beneficial to compress this byte stream before transmitting it over
the network. To facilitate this operation, gRPC supports a mechanism for
performing compression and decompression.
A `Compressor` contains code to compress and decompress by wrapping `io.Writer`s
and `io.Reader`s, respectively. (The forms of `Compress` and `Decompress` were
chosen to most closely match Go's standard package
[implementations](https://golang.org/pkg/compress/) of compressors.) Like
`Codec`s, `Compressor`s are registered by name into a global registry maintained
in the `encoding` package.
### Implementing a `Compressor`
A typical `Compressor` will be implemented in its own package with an `init`
function that registers itself, and is imported anonymously. For example:
```go
package gzip
import "google.golang.org/grpc/encoding"
func init() {
encoding.RegisterCompressor(compressor{})
}
// ... implementation of compressor ...
```
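A sketch of the wrapping itself might look like the following. This is illustrative only (the real `encoding/gzip` package additionally pools writers and supports compression levels) and assumes `compress/gzip` and `io` are imported:
```go
type compressor struct{}

// Compress returns a WriteCloser that gzip-compresses everything written to it
// before passing it on to w.
func (compressor) Compress(w io.Writer) (io.WriteCloser, error) {
	return gzip.NewWriter(w), nil
}

// Decompress returns a Reader that yields the decompressed contents of r.
func (compressor) Decompress(r io.Reader) (io.Reader, error) {
	return gzip.NewReader(r)
}

// Name is the registered compressor name sent on the wire.
func (compressor) Name() string {
	return "gzip"
}
```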
An implementation of a `gzip` compressor can be found in
[`encoding/gzip`](https://godoc.org/google.golang.org/grpc/encoding/gzip).
### Using a `Compressor`
By default, gRPC does not register or use any compressors. To use a
`Compressor` from a client or server:
```go
package myclient
import _ "google.golang.org/grpc/encoding/gzip"
```
`Compressor`s, by definition, must be symmetric, so the same desired
`Compressor` should be registered in both client and server binaries.
On the client-side, to specify a `Compressor` to use for message transmission,
the `CallOption` `UseCompressor` should be used as follows:
```go
response, err := myclient.MyCall(ctx, request, grpc.UseCompressor("gzip"))
```
As a reminder, all `CallOption`s may be converted into `DialOption`s that become
the default for all RPCs sent through a client using `grpc.WithDefaultCallOptions`:
```go
myclient, err := grpc.Dial(target, grpc.WithDefaultCallOptions(grpc.UseCompressor("gzip")))
```
When specified in either of these ways, messages will be compressed using this
compressor and sent along with headers indicating the compressor
(`grpc-encoding` set to `<compressor name>`).
On the server-side, using a `Compressor` is as simple as registering it into the
global registry (i.e. `import`ing it). If a message is compressed with the
content coding supported by a registered `Compressor`, it will be used
automatically for decompressing the request and compressing the response.
Otherwise, the request will be rejected with status code `Unimplemented`.

View File

@ -1,182 +0,0 @@
# Mocking Service for gRPC
[Example code unary RPC](https://github.com/grpc/grpc-go/tree/master/examples/helloworld/mock_helloworld)
[Example code streaming RPC](https://github.com/grpc/grpc-go/tree/master/examples/route_guide/mock_routeguide)
## Why?
To test client-side logic without the overhead of connecting to a real server. Mocking enables users to write lightweight unit tests that check client-side functionality without invoking RPC calls to a server.
## Idea: Mock the client stub that connects to the server.
We use Gomock to mock the client interface (in the generated code) and programmatically set its methods to expect and return pre-determined values. This enables users to write tests around the client logic and use this mocked stub while making RPC calls.
## How to use Gomock?
Documentation on Gomock can be found [here](https://github.com/golang/mock).
A quick reading of the documentation should enable users to follow the code below.
Consider a gRPC service based on following proto file:
```proto
//helloworld.proto
package helloworld;
message HelloRequest {
string name = 1;
}
message HelloReply {
string name = 1;
}
service Greeter {
rpc SayHello (HelloRequest) returns (HelloReply) {}
}
```
The generated file `helloworld.pb.go` will have a client interface for each service defined in the proto file. This interface will have methods corresponding to each RPC inside that service.
```Go
type GreeterClient interface {
SayHello(ctx context.Context, in *HelloRequest, opts ...grpc.CallOption) (*HelloReply, error)
}
```
The generated code also contains a struct that implements this interface.
```Go
type greeterClient struct {
cc *grpc.ClientConn
}
func (c *greeterClient) SayHello(ctx context.Context, in *HelloRequest, opts ...grpc.CallOption) (*HelloReply, error){
// ...
// gRPC specific code here
// ...
}
```
Along with this, the generated code has a function to create an instance of this struct.
```Go
func NewGreeterClient(cc *grpc.ClientConn) GreeterClient
```
User code uses this function to create an instance of the struct `greeterClient`, which can then be used to make RPC calls to the server.
We will mock this `GreeterClient` interface and use an instance of that mock to make RPC calls. These calls, instead of going to the server, will return pre-determined values.
To create a mock we'll use [mockgen](https://github.com/golang/mock#running-mockgen).
From the directory `examples/helloworld/` run `mockgen google.golang.org/grpc/examples/helloworld/helloworld GreeterClient > mock_helloworld/hw_mock.go`
Notice that in the above command we specify GreeterClient as the interface to be mocked.
The user test code can import the package generated by mockgen along with library package gomock to write unit tests around client-side logic.
```Go
import "github.com/golang/mock/gomock"
import hwmock "google.golang.org/grpc/examples/helloworld/mock_helloworld"
```
An instance of the mocked interface can be created as:
```Go
mockGreeterClient := hwmock.NewMockGreeterClient(ctrl)
```
This mocked object can be programmed to expect calls to its methods and return pre-determined values. For instance, we can program `mockGreeterClient` to expect a call to its method `SayHello` and return a `HelloReply` with the message "Mocked RPC".
```Go
mockGreeterClient.EXPECT().SayHello(
gomock.Any(), // expect any value for first parameter
gomock.Any(), // expect any value for second parameter
).Return(&helloworld.HelloReply{Message: "Mocked RPC"}, nil)
```
`gomock.Any()` indicates that the parameter can have any value or type. We can indicate specific values for built-in types with `gomock.Eq()`.
However, if the test code needs the parameter to be a specific proto message, we can replace `gomock.Any()` with an instance of a struct that implements the `gomock.Matcher` interface.
```Go
type rpcMsg struct {
msg proto.Message
}
func (r *rpcMsg) Matches(msg interface{}) bool {
m, ok := msg.(proto.Message)
if !ok {
return false
}
return proto.Equal(m, r.msg)
}
func (r *rpcMsg) String() string {
return fmt.Sprintf("is %s", r.msg)
}
...
req := &helloworld.HelloRequest{Name: "unit_test"}
mockGreeterClient.EXPECT().SayHello(
gomock.Any(),
&rpcMsg{msg: req},
).Return(&helloworld.HelloReply{Message: "Mocked Interface"}, nil)
```
## Mock streaming RPCs:
For our example we consider the case of bi-directional streaming RPCs. Concretely, we'll write a test for the RouteChat function from the route guide example to demonstrate how to write mocks for streams.
RouteChat is a bi-directional streaming RPC, which means calling RouteChat returns a stream that can __Send__ and __Recv__ messages to and from the server, respectively. We'll start by creating a mock of this stream interface returned by RouteChat and then we'll mock the client interface and set expectation on the method RouteChat to return our mocked stream.
### Generating mocking code:
Like before we'll use [mockgen](https://github.com/golang/mock#running-mockgen). From the `examples/route_guide` directory run: `mockgen google.golang.org/grpc/examples/route_guide/routeguide RouteGuideClient,RouteGuide_RouteChatClient > mock_route_guide/rg_mock.go`
Notice that we are mocking both client(`RouteGuideClient`) and stream(`RouteGuide_RouteChatClient`) interfaces here.
This will create a file `rg_mock.go` under directory `mock_route_guide`. This file contains all the mocking code we need to write our test.
In our test code, like before, we import this mocking code along with the generated code:
```go
import (
rgmock "google.golang.org/grpc/examples/route_guide/mock_routeguide"
rgpb "google.golang.org/grpc/examples/route_guide/routeguide"
)
```
Now consider a test that takes the RouteGuide client object as a parameter, makes a RouteChat RPC call, and sends a message on the resulting stream. Furthermore, this test expects the same message to be received on the stream.
```go
var msg = ...
// Creates a RouteChat call and sends msg on it.
// Checks if the received message was equal to msg.
func testRouteChat(client rgpb.RouteGuideClient) error {
...
}
```
We can inject our mock in here by simply passing it as an argument to the method.
Creating mock for stream interface:
```go
stream := rgmock.NewMockRouteGuide_RouteChatClient(ctrl)
```
Setting Expectations:
```go
stream.EXPECT().Send(gomock.Any()).Return(nil)
stream.EXPECT().Recv().Return(msg, nil)
```
Creating mock for client interface:
```go
rgclient := rgmock.NewMockRouteGuideClient(ctrl)
```
Setting Expectations:
```go
rgclient.EXPECT().RouteChat(gomock.Any()).Return(stream, nil)
```
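Putting it together, a test along the following lines could be written. This is a minimal sketch: the `RouteNote` contents and the `testRouteChat` helper are assumed as described above, and the `testing` and `gomock` imports are elided as in the other snippets:
```Go
func TestRouteChat(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	msg := &rgpb.RouteNote{Message: "test message"}

	// Mock the stream: accept one Send and hand msg back on Recv.
	stream := rgmock.NewMockRouteGuide_RouteChatClient(ctrl)
	stream.EXPECT().Send(gomock.Any()).Return(nil)
	stream.EXPECT().Recv().Return(msg, nil)

	// Mock the client: RouteChat returns the mocked stream.
	rgclient := rgmock.NewMockRouteGuideClient(ctrl)
	rgclient.EXPECT().RouteChat(gomock.Any()).Return(stream, nil)

	if err := testRouteChat(rgclient); err != nil {
		t.Fatalf("testRouteChat failed: %v", err)
	}
}
```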

View File

@ -1,78 +0,0 @@
# Authentication
As outlined in the [gRPC authentication guide](https://grpc.io/docs/guides/auth.html), there are a number of different mechanisms for asserting identity between a client and a server. We'll present some code samples here demonstrating how to provide TLS-based encryption and identity assertions, as well as how to pass OAuth2 tokens to services that support them.
# Enabling TLS on a gRPC client
```Go
conn, err := grpc.Dial(serverAddr, grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(nil, "")))
```
# Enabling TLS on a gRPC server
```Go
creds, err := credentials.NewServerTLSFromFile(certFile, keyFile)
if err != nil {
log.Fatalf("Failed to generate credentials %v", err)
}
lis, err := net.Listen("tcp", ":0")
server := grpc.NewServer(grpc.Creds(creds))
...
server.Serve(lis)
```
# OAuth2
For an example of how to configure client and server to use OAuth2 tokens, see
[here](https://github.com/grpc/grpc-go/blob/master/examples/oauth/).
## Validating a token on the server
Clients may use
[metadata.MD](https://godoc.org/google.golang.org/grpc/metadata#MD)
to store tokens and other authentication-related data. To gain access to the
`metadata.MD` object, a server may use
[metadata.FromIncomingContext](https://godoc.org/google.golang.org/grpc/metadata#FromIncomingContext).
With a reference to `metadata.MD` on the server, one simply needs to look up the
`authorization` key. Note that all keys stored within `metadata.MD` are normalized
to lowercase. See [here](https://godoc.org/google.golang.org/grpc/metadata#New).
It is possible to configure token validation for all RPCs using an interceptor.
A server may configure either a
[grpc.UnaryInterceptor](https://godoc.org/google.golang.org/grpc#UnaryInterceptor)
or a
[grpc.StreamInterceptor](https://godoc.org/google.golang.org/grpc#StreamInterceptor).
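For example, here is a minimal sketch of a unary interceptor that checks the `authorization` metadata. The `validToken` helper is hypothetical, and the `metadata`, `codes`, and `status` packages are assumed to be imported:
```Go
func authUnaryInterceptor(ctx context.Context, req interface{},
	info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		return nil, status.Error(codes.Unauthenticated, "missing metadata")
	}
	// Keys in metadata.MD are normalized to lowercase.
	if vals := md["authorization"]; len(vals) == 0 || !validToken(vals[0]) {
		return nil, status.Error(codes.Unauthenticated, "invalid token")
	}
	return handler(ctx, req)
}

// Installed when constructing the server:
//   s := grpc.NewServer(grpc.UnaryInterceptor(authUnaryInterceptor))
```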
## Adding a token to all outgoing client RPCs
To send an OAuth2 token with each RPC, a client may configure the
`grpc.DialOption`
[grpc.WithPerRPCCredentials](https://godoc.org/google.golang.org/grpc#WithPerRPCCredentials).
Alternatively, a client may also use the `grpc.CallOption`
[grpc.PerRPCCredentials](https://godoc.org/google.golang.org/grpc#PerRPCCredentials)
on each invocation of an RPC.
To create a `credentials.PerRPCCredentials`, use
[oauth.NewOauthAccess](https://godoc.org/google.golang.org/grpc/credentials/oauth#NewOauthAccess).
Note, the OAuth2 implementation of `grpc.PerRPCCredentials` requires a client to use
[grpc.WithTransportCredentials](https://godoc.org/google.golang.org/grpc#WithTransportCredentials)
to prevent any insecure transmission of tokens.
# Authenticating with Google
## Google Compute Engine (GCE)
```Go
conn, err := grpc.Dial(serverAddr, grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(nil, "")), grpc.WithPerRPCCredentials(oauth.NewComputeEngine()))
```
## JWT
```Go
jwtCreds, err := oauth.NewServiceAccountFromFile(*serviceAccountKeyFile, *oauthScope)
if err != nil {
log.Fatalf("Failed to create JWT credentials: %v", err)
}
conn, err := grpc.Dial(serverAddr, grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(nil, "")), grpc.WithPerRPCCredentials(jwtCreds))
```

View File

@ -1,227 +0,0 @@
# Metadata
gRPC supports sending metadata between client and server.
This doc shows how to send and receive metadata in gRPC-go.
## Background
Four kinds of service method:
- [Unary RPC](https://grpc.io/docs/guides/concepts.html#unary-rpc)
- [Server streaming RPC](https://grpc.io/docs/guides/concepts.html#server-streaming-rpc)
- [Client streaming RPC](https://grpc.io/docs/guides/concepts.html#client-streaming-rpc)
- [Bidirectional streaming RPC](https://grpc.io/docs/guides/concepts.html#bidirectional-streaming-rpc)
And concept of [metadata](https://grpc.io/docs/guides/concepts.html#metadata).
## Constructing metadata
Metadata can be created using the [metadata](https://godoc.org/google.golang.org/grpc/metadata) package.
The type `MD` is actually a map from string to a list of strings:
```go
type MD map[string][]string
```
Metadata can be read like a normal map.
Note that the value type of this map is `[]string`,
so that users can attach multiple values using a single key.
### Creating a new metadata
Metadata can be created from a `map[string]string` using the function `New`:
```go
md := metadata.New(map[string]string{"key1": "val1", "key2": "val2"})
```
Another way is to use `Pairs`.
Values with the same key will be merged into a list:
```go
md := metadata.Pairs(
"key1", "val1",
"key1", "val1-2", // "key1" will have map value []string{"val1", "val1-2"}
"key2", "val2",
)
```
__Note:__ all the keys will be automatically converted to lowercase,
so "key1" and "kEy1" will be the same key and their values will be merged into the same list.
This happens for both `New` and `Pairs`.
### Storing binary data in metadata
In metadata, keys are always strings. But values can be strings or binary data.
To store a binary value in metadata, simply add a "-bin" suffix to the key.
Values with "-bin"-suffixed keys will be encoded when creating the metadata:
```go
md := metadata.Pairs(
"key", "string value",
"key-bin", string([]byte{96, 102}), // this binary data will be encoded (base64) before sending
// and will be decoded after being transferred.
)
```
## Retrieving metadata from context
Metadata can be retrieved from context using `FromIncomingContext`:
```go
func (s *server) SomeRPC(ctx context.Context, in *pb.SomeRequest) (*pb.SomeResponse, error) {
md, ok := metadata.FromIncomingContext(ctx)
// do something with metadata
}
```
## Sending and receiving metadata - client side
[//]: # "TODO: uncomment next line after example source added"
[//]: # "Real metadata sending and receiving examples are available [here](TODO:example_dir)."
### Sending metadata
There are two ways to send metadata to the server. The recommended way is to append kv pairs to the context using
`AppendToOutgoingContext`. This can be used with or without existing metadata on the context. When there is no prior
metadata, metadata is added; when metadata already exists on the context, kv pairs are merged in.
```go
// create a new context with some metadata
ctx := metadata.AppendToOutgoingContext(ctx, "k1", "v1", "k1", "v2", "k2", "v3")
// later, add some more metadata to the context (e.g. in an interceptor)
ctx := metadata.AppendToOutgoingContext(ctx, "k3", "v4")
// make unary RPC
response, err := client.SomeRPC(ctx, someRequest)
// or make streaming RPC
stream, err := client.SomeStreamingRPC(ctx)
```
Alternatively, metadata may be attached to the context using `NewOutgoingContext`. However, this
replaces any existing metadata in the context, so care must be taken to preserve the existing
metadata if desired. This is slower than using `AppendToOutgoingContext`. An example of this
is below:
```go
// create a new context with some metadata
md := metadata.Pairs("k1", "v1", "k1", "v2", "k2", "v3")
ctx := metadata.NewOutgoingContext(context.Background(), md)
// later, add some more metadata to the context (e.g. in an interceptor)
md, _ := metadata.FromOutgoingContext(ctx)
newMD := metadata.Pairs("k3", "v3")
ctx = metadata.NewOutgoingContext(ctx, metadata.Join(md, newMD))
// make unary RPC
response, err := client.SomeRPC(ctx, someRequest)
// or make streaming RPC
stream, err := client.SomeStreamingRPC(ctx)
```
### Receiving metadata
Metadata that a client can receive includes header and trailer.
#### Unary call
Header and trailer sent along with a unary call can be retrieved using the [Header](https://godoc.org/google.golang.org/grpc#Header) and [Trailer](https://godoc.org/google.golang.org/grpc#Trailer) functions in [CallOption](https://godoc.org/google.golang.org/grpc#CallOption):
```go
var header, trailer metadata.MD // variable to store header and trailer
r, err := client.SomeRPC(
ctx,
someRequest,
grpc.Header(&header), // will retrieve header
grpc.Trailer(&trailer), // will retrieve trailer
)
// do something with header and trailer
```
#### Streaming call
For streaming calls including:
- Server streaming RPC
- Client streaming RPC
- Bidirectional streaming RPC
Header and trailer can be retrieved from the returned stream using the `Header` and `Trailer` methods of the [ClientStream](https://godoc.org/google.golang.org/grpc#ClientStream) interface:
```go
stream, err := client.SomeStreamingRPC(ctx)
// retrieve header
header, err := stream.Header()
// retrieve trailer
trailer := stream.Trailer()
```
## Sending and receiving metadata - server side
[//]: # "TODO: uncomment next line after example source added"
[//]: # "Real metadata sending and receiving examples are available [here](TODO:example_dir)."
### Receiving metadata
To read metadata sent by the client, the server needs to retrieve it from the RPC context.
If it is a unary call, the RPC handler's context can be used.
For streaming calls, the server needs to get the context from the stream.
#### Unary call
```go
func (s *server) SomeRPC(ctx context.Context, in *pb.someRequest) (*pb.someResponse, error) {
md, ok := metadata.FromIncomingContext(ctx)
// do something with metadata
}
```
#### Streaming call
```go
func (s *server) SomeStreamingRPC(stream pb.Service_SomeStreamingRPCServer) error {
md, ok := metadata.FromIncomingContext(stream.Context()) // get context from stream
// do something with metadata
}
```
### Sending metadata
#### Unary call
To send header and trailer to the client in a unary call, the server can call the [SendHeader](https://godoc.org/google.golang.org/grpc#SendHeader) and [SetTrailer](https://godoc.org/google.golang.org/grpc#SetTrailer) functions in package [grpc](https://godoc.org/google.golang.org/grpc).
These two functions take a context as the first parameter.
It should be the RPC handler's context or one derived from it:
```go
func (s *server) SomeRPC(ctx context.Context, in *pb.someRequest) (*pb.someResponse, error) {
// create and send header
header := metadata.Pairs("header-key", "val")
grpc.SendHeader(ctx, header)
// create and set trailer
trailer := metadata.Pairs("trailer-key", "val")
grpc.SetTrailer(ctx, trailer)
}
```
#### Streaming call
For streaming calls, header and trailer can be sent using the `SendHeader` and `SetTrailer` methods of the [ServerStream](https://godoc.org/google.golang.org/grpc#ServerStream) interface:
```go
func (s *server) SomeStreamingRPC(stream pb.Service_SomeStreamingRPCServer) error {
// create and send header
header := metadata.Pairs("header-key", "val")
stream.SendHeader(header)
// create and set trailer
trailer := metadata.Pairs("trailer-key", "val")
stream.SetTrailer(trailer)
}
```

View File

@ -1,46 +0,0 @@
# Keepalive
gRPC sends HTTP/2 pings on the transport to detect whether the connection is down. If
the ping is not acknowledged by the other side within a certain period, the
connection will be closed. Note that pings are only necessary when there's no
activity on the connection.
For the available keepalive configuration options, see
https://godoc.org/google.golang.org/grpc/keepalive.
## What should I set?
It should be sufficient for most users to set [client
parameters](https://godoc.org/google.golang.org/grpc/keepalive) as a [dial
option](https://godoc.org/google.golang.org/grpc#WithKeepaliveParams).
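For instance, a minimal sketch with illustrative values (the `keepalive` and `time` packages are assumed to be imported; `serverAddr` and the `WithInsecure` transport setting are placeholders):
```go
conn, err := grpc.Dial(serverAddr,
	grpc.WithInsecure(), // placeholder transport settings
	grpc.WithKeepaliveParams(keepalive.ClientParameters{
		Time:                10 * time.Second, // send a ping after 10s of inactivity
		Timeout:             2 * time.Second,  // wait 2s for the ping ack before closing
		PermitWithoutStream: true,             // send pings even without active RPCs
	}),
)
```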
## What will happen?
(The behavior described here is specific to gRPC-go; it might be slightly
different in other languages.)
When there's no activity on a connection (note that an ongoing stream results in
__no activity__ when there's no message being sent), after `Time`, a ping will
be sent by the client and the server will send a ping ack when it gets the ping.
The client will then wait for `Timeout` and check whether there was any activity on the
connection during this period (a ping ack counts as activity). If there was none, the
connection is considered dead and will be closed.
## What about server side?
The server has `Time` and `Timeout` settings similar to the client's. It can also
configure connection max-age. See [server
parameters](https://godoc.org/google.golang.org/grpc/keepalive#ServerParameters)
for details.
### Enforcement policy
[Enforcement
policy](https://godoc.org/google.golang.org/grpc/keepalive#EnforcementPolicy) is
a special setting on the server side to protect the server from malicious or
misbehaving clients.
The server sends a GOAWAY with ENHANCE_YOUR_CALM and closes the connection when bad
behavior is detected (a configuration sketch follows the list):
- Client sends too frequent pings
- Client sends pings when there's no stream and this is disallowed by server
config
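A minimal sketch of server-side keepalive and enforcement settings with illustrative values (the `keepalive` and `time` packages are assumed to be imported):
```go
srv := grpc.NewServer(
	grpc.KeepaliveParams(keepalive.ServerParameters{
		MaxConnectionAge: 30 * time.Minute, // recycle connections periodically
		Time:             20 * time.Second, // ping the client after 20s of inactivity
		Timeout:          5 * time.Second,  // wait 5s for the ping ack before closing
	}),
	grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
		MinTime:             10 * time.Second, // reject pings arriving more often than this
		PermitWithoutStream: false,            // disallow pings when there are no active streams
	}),
)
```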

View File

@ -1,49 +0,0 @@
# Log Levels
This document describes the different log levels supported by the grpc-go
library, and under what conditions they should be used.
### Info
Info messages are for informational purposes and may aid in the debugging of
applications or the gRPC library.
Examples:
- The name resolver received an update.
- The balancer updated its picker.
- Significant gRPC state is changing.
At verbosity of 0 (the default), any single info message should not be output
more than once every 5 minutes under normal operation.
### Warning
Warning messages indicate problems that are non-fatal for the application, but
could lead to unexpected behavior or subsequent errors.
Examples:
- Resolver could not resolve target name.
- Error received while connecting to a server.
- Lost or corrupt connection with remote endpoint.
### Error
Error messages represent errors in the usage of gRPC that cannot be returned to
the application as errors, or internal gRPC-Go errors that are recoverable.
Internal errors are detected during gRPC tests and will result in test failures.
Examples:
- Invalid arguments passed to a function that cannot return an error.
- An internal error that cannot be returned or would be inappropriate to return
to the user.
### Fatal
Fatal errors are severe internal errors that are unrecoverable. These lead
directly to panics, and are avoided as much as possible.
Example:
- Internal invariant was violated.
- User attempted an action that cannot return an error gracefully, but would
lead to an invalid state if performed.

View File

@ -1,15 +0,0 @@
# Proxy
HTTP CONNECT proxies are supported by default in gRPC. The proxy address can be
specified by the environment variables HTTP_PROXY, HTTPS_PROXY and NO_PROXY (or
the lowercase versions thereof).
## Custom proxy
Currently, proxy support is implemented in the default dialer. It does one more
handshake (a CONNECT handshake in the case of an HTTP CONNECT proxy) on the
connection before giving it to gRPC.
If the default proxy doesn't work for you, replace the default dialer with your
custom proxy dialer. This can be done using
[`WithDialer`](https://godoc.org/google.golang.org/grpc#WithDialer).
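A minimal sketch of what that replacement might look like (the `dialThroughProxy` helper is hypothetical; the `net` and `time` packages and the `WithInsecure` transport setting are illustrative assumptions):
```go
conn, err := grpc.Dial(target,
	grpc.WithInsecure(), // placeholder transport settings
	grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) {
		// Perform your proxy's handshake here and return the ready connection,
		// which gRPC will then use for the HTTP/2 transport.
		return dialThroughProxy(addr, timeout)
	}),
)
```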

View File

@ -1,68 +0,0 @@
# RPC Errors
All service method handlers should return `nil` or errors created from the
`status.Status` type. Clients have direct access to these errors.
Upon encountering an error, a gRPC server method handler should create a
`status.Status`. In typical usage, one would use [status.New][new-status]
passing in an appropriate [codes.Code][code] as well as a description of the
error to produce a `status.Status`. Calling [status.Err][status-err] converts
the `status.Status` type into an `error`. As a convenience method, there is also
[status.Error][status-error] which obviates the conversion step. Compare:
```
st := status.New(codes.NotFound, "some description")
err := st.Err()
// vs.
err := status.Error(codes.NotFound, "some description")
```
## Adding additional details to errors
In some cases, it may be necessary to add details for a particular error on the
server side. The [status.WithDetails][with-details] method exists for this
purpose. Clients may then read those details by first converting the plain
`error` type back to a [status.Status][status] and then using
[status.Details][details].
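A minimal sketch of both sides is below. The `errdetails.QuotaFailure` message (from `google.golang.org/genproto/googleapis/rpc/errdetails`) is used purely as an illustrative detail type, and `rpcErr` stands for the error returned by the RPC call:
```
// Server side: attach a typed detail to the status before returning it.
st := status.New(codes.ResourceExhausted, "request limit exceeded")
st, err := st.WithDetails(&errdetails.QuotaFailure{
	Violations: []*errdetails.QuotaFailure_Violation{{
		Subject:     "name:world",
		Description: "Limit one greeting per person",
	}},
})
if err != nil {
	// WithDetails only fails if the detail message cannot be marshalled.
}
return nil, st.Err()

// Client side: convert the error back to a status and inspect its details.
if st, ok := status.FromError(rpcErr); ok {
	for _, d := range st.Details() {
		if qf, ok := d.(*errdetails.QuotaFailure); ok {
			log.Printf("quota failure: %v", qf)
		}
	}
}
```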
## Example
The [example][example] demonstrates the API discussed above and shows how to add
information about rate limits to the error message using `status.Status`.
To run the example, first start the server:
```
$ go run examples/rpc_errors/server/main.go
```
In a separate session, run the client:
```
$ go run examples/rpc_errors/client/main.go
```
On the first run of the client, all is well:
```
2018/03/12 19:39:33 Greeting: Hello world
```
Upon running the client a second time, the client exceeds the rate limit and
receives an error with details:
```
2018/03/19 16:42:01 Quota failure: violations:<subject:"name:world" description:"Limit one greeting per person" >
exit status 1
```
[status]: https://godoc.org/google.golang.org/grpc/status#Status
[new-status]: https://godoc.org/google.golang.org/grpc/status#New
[code]: https://godoc.org/google.golang.org/grpc/codes#Code
[with-details]: https://godoc.org/google.golang.org/grpc/status#Status.WithDetails
[details]: https://godoc.org/google.golang.org/grpc/status#Status.Details
[status-err]: https://godoc.org/google.golang.org/grpc/status#Status.Err
[status-error]: https://godoc.org/google.golang.org/grpc/status#Error
[example]: https://github.com/grpc/grpc-go/blob/master/examples/rpc_errors

View File

@ -1,151 +0,0 @@
# gRPC Server Reflection Tutorial
gRPC Server Reflection provides information about publicly-accessible gRPC
services on a server, and assists clients at runtime to construct RPC
requests and responses without precompiled service information. It is used by
gRPC CLI, which can be used to introspect server protos and send/receive test
RPCs.
## Enable Server Reflection
gRPC-go Server Reflection is implemented in package [reflection](https://github.com/grpc/grpc-go/tree/master/reflection). To enable server reflection, you need to import this package and register the reflection service on your gRPC server.
For example, to enable server reflection in `examples/helloworld`, we need to make the following changes:
```diff
--- a/examples/helloworld/greeter_server/main.go
+++ b/examples/helloworld/greeter_server/main.go
@@ -40,6 +40,7 @@ import (
"google.golang.org/grpc"
pb "google.golang.org/grpc/examples/helloworld/helloworld"
+ "google.golang.org/grpc/reflection"
)
const (
@@ -61,6 +62,8 @@ func main() {
}
s := grpc.NewServer()
pb.RegisterGreeterServer(s, &server{})
+ // Register reflection service on gRPC server.
+ reflection.Register(s)
if err := s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
```
We have made this change in `examples/helloworld`, and we will use it as an example to show the use of gRPC server reflection and gRPC CLI in this tutorial.
## gRPC CLI
After enabling Server Reflection in a server application, you can use gRPC CLI to check its services.
gRPC CLI is only available in C++. Instructions on how to use gRPC CLI can be found at [command_line_tool.md](https://github.com/grpc/grpc/blob/master/doc/command_line_tool.md).
To build gRPC CLI:
```sh
git clone https://github.com/grpc/grpc
cd grpc
make grpc_cli
cd bins/opt # grpc_cli is in directory bins/opt/
```
## Use gRPC CLI to check services
First, start the helloworld server in grpc-go directory:
```sh
$ cd <grpc-go-directory>
$ go run examples/helloworld/greeter_server/main.go
```
Open a new terminal and make sure you are in the directory where grpc_cli lives:
```sh
$ cd <grpc-cpp-directory>/bins/opt
```
### List services
The `grpc_cli ls` command lists services and methods exposed at a given port:
- List all the services exposed at a given port
```sh
$ ./grpc_cli ls localhost:50051
```
output:
```sh
helloworld.Greeter
grpc.reflection.v1alpha.ServerReflection
```
- List one service with details
The `grpc_cli ls` command can also inspect a service given its full name (in the
format \<package\>.\<service\>). It prints information in a long listing format
when the `-l` flag is set, which can be used to get more details about a
service.
```sh
$ ./grpc_cli ls localhost:50051 helloworld.Greeter -l
```
output:
```sh
filename: helloworld.proto
package: helloworld;
service Greeter {
rpc SayHello(helloworld.HelloRequest) returns (helloworld.HelloReply) {}
}
```
### List methods
- List one method with details
The `grpc_cli ls` command also inspects a method given its full name (in the
format \<package\>.\<service\>.\<method\>).
```sh
$ ./grpc_cli ls localhost:50051 helloworld.Greeter.SayHello -l
```
output:
```sh
rpc SayHello(helloworld.HelloRequest) returns (helloworld.HelloReply) {}
```
### Inspect message types
We can use the `grpc_cli type` command to inspect request/response types given the
full name of the type (in the format \<package\>.\<type\>).
- Get information about the request type
```sh
$ ./grpc_cli type localhost:50051 helloworld.HelloRequest
```
output:
```sh
message HelloRequest {
optional string name = 1[json_name = "name"];
}
```
### Call a remote method
We can send RPCs to a server and get responses using the `grpc_cli call` command.
- Call a unary method
```sh
$ ./grpc_cli call localhost:50051 SayHello "name: 'gRPC CLI'"
```
output:
```sh
message: "Hello gRPC CLI"
```

View File

@ -1,34 +0,0 @@
# Versioning and Releases
Note: This document references terminology defined at http://semver.org.
## Release Frequency
Regular MINOR releases of gRPC-Go are performed every six weeks. Patch releases
to the previous two MINOR releases may be performed on demand or if serious
security problems are discovered.
## Versioning Policy
The gRPC-Go versioning policy follows the Semantic Versioning 2.0.0
specification, with the following exceptions:
- A MINOR version will not _necessarily_ add new functionality.
- MINOR releases will not break backward compatibility, except in the following
circumstances:
- An API was marked as EXPERIMENTAL upon its introduction.
- An API was marked as DEPRECATED in the initial MAJOR release.
- An API is inherently flawed and cannot provide correct or secure behavior.
In these cases, APIs MAY be changed or removed without a MAJOR release.
Otherwise, backward compatibility will be preserved by MINOR releases.
For an API marked as DEPRECATED, an alternative will be available (if
appropriate) for at least three months prior to its removal.
## Release History
Please see our release history on GitHub:
https://github.com/grpc/grpc-go/releases

View File

@ -1,60 +0,0 @@
all: vet test testrace testappengine
build: deps
go build google.golang.org/grpc/...
clean:
go clean -i google.golang.org/grpc/...
deps:
go get -d -v google.golang.org/grpc/...
proto:
@ if ! which protoc > /dev/null; then \
echo "error: protoc not installed" >&2; \
exit 1; \
fi
go generate google.golang.org/grpc/...
test: testdeps
go test -cpu 1,4 -timeout 7m google.golang.org/grpc/...
testappengine: testappenginedeps
goapp test -cpu 1,4 -timeout 7m google.golang.org/grpc/...
testappenginedeps:
goapp get -d -v -t -tags 'appengine appenginevm' google.golang.org/grpc/...
testdeps:
go get -d -v -t google.golang.org/grpc/...
testrace: testdeps
go test -race -cpu 1,4 -timeout 7m google.golang.org/grpc/...
updatedeps:
go get -d -v -u -f google.golang.org/grpc/...
updatetestdeps:
go get -d -v -t -u -f google.golang.org/grpc/...
vet: vetdeps
./vet.sh
vetdeps:
./vet.sh -install
.PHONY: \
all \
build \
clean \
deps \
proto \
test \
testappengine \
testappenginedeps \
testdeps \
testrace \
updatedeps \
updatetestdeps \
vet \
vetdeps

View File

@ -1,67 +0,0 @@
# gRPC-Go
[![Build Status](https://travis-ci.org/grpc/grpc-go.svg)](https://travis-ci.org/grpc/grpc-go) [![GoDoc](https://godoc.org/google.golang.org/grpc?status.svg)](https://godoc.org/google.golang.org/grpc) [![GoReportCard](https://goreportcard.com/badge/grpc/grpc-go)](https://goreportcard.com/report/github.com/grpc/grpc-go)
The Go implementation of [gRPC](https://grpc.io/): A high performance, open source, general RPC framework that puts mobile and HTTP/2 first. For more information see the [gRPC Quick Start: Go](https://grpc.io/docs/quickstart/go.html) guide.
Installation
------------
To install this package, you need to install Go and set up your Go workspace on your computer. The simplest way to install the library is to run:
```
$ go get -u google.golang.org/grpc
```
Prerequisites
-------------
gRPC-Go requires Go 1.9 or later.
Constraints
-----------
The grpc package should only depend on standard Go packages and a small number of exceptions. If your contribution introduces new dependencies which are NOT in the [list](http://godoc.org/google.golang.org/grpc?imports), you need a discussion with gRPC-Go authors and consultants.
Documentation
-------------
See [API documentation](https://godoc.org/google.golang.org/grpc) for package and API descriptions and find examples in the [examples directory](examples/).
Performance
-----------
See the current benchmarks for some of the languages supported in [this dashboard](https://performance-dot-grpc-testing.appspot.com/explore?dashboard=5652536396611584&widget=490377658&container=1286539696).
Status
------
General Availability [Google Cloud Platform Launch Stages](https://cloud.google.com/terms/launch-stages).
FAQ
---
#### Compiling error, undefined: grpc.SupportPackageIsVersion
Please update proto package, gRPC package and rebuild the proto files:
- `go get -u github.com/golang/protobuf/{proto,protoc-gen-go}`
- `go get -u google.golang.org/grpc`
- `protoc --go_out=plugins=grpc:. *.proto`
#### How to turn on logging
The default logger is controlled by environment variables. Turn everything
on by setting:
```
GRPC_GO_LOG_VERBOSITY_LEVEL=99 GRPC_GO_LOG_SEVERITY_LEVEL=info
```
#### The RPC failed with error `"code = Unavailable desc = transport is closing"`
This error means the connection the RPC is using was closed, and there are many
possible reasons, including:
1. mis-configured transport credentials, connection failed on handshaking
1. bytes disrupted, possibly by a proxy in between
1. server shutdown
It can be tricky to debug this because the error happens on the client side but
the root cause of the connection being closed is on the server side. Turn on
logging on __both client and server__, and see if there are any transport
errors.

View File

@ -1,839 +0,0 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: grpc/lb/v1/load_balancer.proto
package grpc_lb_v1 // import "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
import duration "github.com/golang/protobuf/ptypes/duration"
import timestamp "github.com/golang/protobuf/ptypes/timestamp"
import (
context "golang.org/x/net/context"
grpc "google.golang.org/grpc"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type LoadBalanceRequest struct {
// Types that are valid to be assigned to LoadBalanceRequestType:
// *LoadBalanceRequest_InitialRequest
// *LoadBalanceRequest_ClientStats
LoadBalanceRequestType isLoadBalanceRequest_LoadBalanceRequestType `protobuf_oneof:"load_balance_request_type"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *LoadBalanceRequest) Reset() { *m = LoadBalanceRequest{} }
func (m *LoadBalanceRequest) String() string { return proto.CompactTextString(m) }
func (*LoadBalanceRequest) ProtoMessage() {}
func (*LoadBalanceRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_load_balancer_12026aec3f0251ba, []int{0}
}
func (m *LoadBalanceRequest) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_LoadBalanceRequest.Unmarshal(m, b)
}
func (m *LoadBalanceRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_LoadBalanceRequest.Marshal(b, m, deterministic)
}
func (dst *LoadBalanceRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_LoadBalanceRequest.Merge(dst, src)
}
func (m *LoadBalanceRequest) XXX_Size() int {
return xxx_messageInfo_LoadBalanceRequest.Size(m)
}
func (m *LoadBalanceRequest) XXX_DiscardUnknown() {
xxx_messageInfo_LoadBalanceRequest.DiscardUnknown(m)
}
var xxx_messageInfo_LoadBalanceRequest proto.InternalMessageInfo
type isLoadBalanceRequest_LoadBalanceRequestType interface {
isLoadBalanceRequest_LoadBalanceRequestType()
}
type LoadBalanceRequest_InitialRequest struct {
InitialRequest *InitialLoadBalanceRequest `protobuf:"bytes,1,opt,name=initial_request,json=initialRequest,proto3,oneof"`
}
type LoadBalanceRequest_ClientStats struct {
ClientStats *ClientStats `protobuf:"bytes,2,opt,name=client_stats,json=clientStats,proto3,oneof"`
}
func (*LoadBalanceRequest_InitialRequest) isLoadBalanceRequest_LoadBalanceRequestType() {}
func (*LoadBalanceRequest_ClientStats) isLoadBalanceRequest_LoadBalanceRequestType() {}
func (m *LoadBalanceRequest) GetLoadBalanceRequestType() isLoadBalanceRequest_LoadBalanceRequestType {
if m != nil {
return m.LoadBalanceRequestType
}
return nil
}
func (m *LoadBalanceRequest) GetInitialRequest() *InitialLoadBalanceRequest {
if x, ok := m.GetLoadBalanceRequestType().(*LoadBalanceRequest_InitialRequest); ok {
return x.InitialRequest
}
return nil
}
func (m *LoadBalanceRequest) GetClientStats() *ClientStats {
if x, ok := m.GetLoadBalanceRequestType().(*LoadBalanceRequest_ClientStats); ok {
return x.ClientStats
}
return nil
}
// XXX_OneofFuncs is for the internal use of the proto package.
func (*LoadBalanceRequest) XXX_OneofFuncs() (func(msg proto.Message, b *proto.Buffer) error, func(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error), func(msg proto.Message) (n int), []interface{}) {
return _LoadBalanceRequest_OneofMarshaler, _LoadBalanceRequest_OneofUnmarshaler, _LoadBalanceRequest_OneofSizer, []interface{}{
(*LoadBalanceRequest_InitialRequest)(nil),
(*LoadBalanceRequest_ClientStats)(nil),
}
}
func _LoadBalanceRequest_OneofMarshaler(msg proto.Message, b *proto.Buffer) error {
m := msg.(*LoadBalanceRequest)
// load_balance_request_type
switch x := m.LoadBalanceRequestType.(type) {
case *LoadBalanceRequest_InitialRequest:
b.EncodeVarint(1<<3 | proto.WireBytes)
if err := b.EncodeMessage(x.InitialRequest); err != nil {
return err
}
case *LoadBalanceRequest_ClientStats:
b.EncodeVarint(2<<3 | proto.WireBytes)
if err := b.EncodeMessage(x.ClientStats); err != nil {
return err
}
case nil:
default:
return fmt.Errorf("LoadBalanceRequest.LoadBalanceRequestType has unexpected type %T", x)
}
return nil
}
func _LoadBalanceRequest_OneofUnmarshaler(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error) {
m := msg.(*LoadBalanceRequest)
switch tag {
case 1: // load_balance_request_type.initial_request
if wire != proto.WireBytes {
return true, proto.ErrInternalBadWireType
}
msg := new(InitialLoadBalanceRequest)
err := b.DecodeMessage(msg)
m.LoadBalanceRequestType = &LoadBalanceRequest_InitialRequest{msg}
return true, err
case 2: // load_balance_request_type.client_stats
if wire != proto.WireBytes {
return true, proto.ErrInternalBadWireType
}
msg := new(ClientStats)
err := b.DecodeMessage(msg)
m.LoadBalanceRequestType = &LoadBalanceRequest_ClientStats{msg}
return true, err
default:
return false, nil
}
}
func _LoadBalanceRequest_OneofSizer(msg proto.Message) (n int) {
m := msg.(*LoadBalanceRequest)
// load_balance_request_type
switch x := m.LoadBalanceRequestType.(type) {
case *LoadBalanceRequest_InitialRequest:
s := proto.Size(x.InitialRequest)
n += 1 // tag and wire
n += proto.SizeVarint(uint64(s))
n += s
case *LoadBalanceRequest_ClientStats:
s := proto.Size(x.ClientStats)
n += 1 // tag and wire
n += proto.SizeVarint(uint64(s))
n += s
case nil:
default:
panic(fmt.Sprintf("proto: unexpected type %T in oneof", x))
}
return n
}
type InitialLoadBalanceRequest struct {
// The name of the load balanced service (e.g., service.googleapis.com). Its
// length should be less than 256 bytes.
// The name might include a port number. How to handle the port number is up
// to the balancer.
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *InitialLoadBalanceRequest) Reset() { *m = InitialLoadBalanceRequest{} }
func (m *InitialLoadBalanceRequest) String() string { return proto.CompactTextString(m) }
func (*InitialLoadBalanceRequest) ProtoMessage() {}
func (*InitialLoadBalanceRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_load_balancer_12026aec3f0251ba, []int{1}
}
func (m *InitialLoadBalanceRequest) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_InitialLoadBalanceRequest.Unmarshal(m, b)
}
func (m *InitialLoadBalanceRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_InitialLoadBalanceRequest.Marshal(b, m, deterministic)
}
func (dst *InitialLoadBalanceRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_InitialLoadBalanceRequest.Merge(dst, src)
}
func (m *InitialLoadBalanceRequest) XXX_Size() int {
return xxx_messageInfo_InitialLoadBalanceRequest.Size(m)
}
func (m *InitialLoadBalanceRequest) XXX_DiscardUnknown() {
xxx_messageInfo_InitialLoadBalanceRequest.DiscardUnknown(m)
}
var xxx_messageInfo_InitialLoadBalanceRequest proto.InternalMessageInfo
func (m *InitialLoadBalanceRequest) GetName() string {
if m != nil {
return m.Name
}
return ""
}
// Contains the number of calls finished for a particular load balance token.
type ClientStatsPerToken struct {
// See Server.load_balance_token.
LoadBalanceToken string `protobuf:"bytes,1,opt,name=load_balance_token,json=loadBalanceToken,proto3" json:"load_balance_token,omitempty"`
// The total number of RPCs that finished associated with the token.
NumCalls int64 `protobuf:"varint,2,opt,name=num_calls,json=numCalls,proto3" json:"num_calls,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ClientStatsPerToken) Reset() { *m = ClientStatsPerToken{} }
func (m *ClientStatsPerToken) String() string { return proto.CompactTextString(m) }
func (*ClientStatsPerToken) ProtoMessage() {}
func (*ClientStatsPerToken) Descriptor() ([]byte, []int) {
return fileDescriptor_load_balancer_12026aec3f0251ba, []int{2}
}
func (m *ClientStatsPerToken) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ClientStatsPerToken.Unmarshal(m, b)
}
func (m *ClientStatsPerToken) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ClientStatsPerToken.Marshal(b, m, deterministic)
}
func (dst *ClientStatsPerToken) XXX_Merge(src proto.Message) {
xxx_messageInfo_ClientStatsPerToken.Merge(dst, src)
}
func (m *ClientStatsPerToken) XXX_Size() int {
return xxx_messageInfo_ClientStatsPerToken.Size(m)
}
func (m *ClientStatsPerToken) XXX_DiscardUnknown() {
xxx_messageInfo_ClientStatsPerToken.DiscardUnknown(m)
}
var xxx_messageInfo_ClientStatsPerToken proto.InternalMessageInfo
func (m *ClientStatsPerToken) GetLoadBalanceToken() string {
if m != nil {
return m.LoadBalanceToken
}
return ""
}
func (m *ClientStatsPerToken) GetNumCalls() int64 {
if m != nil {
return m.NumCalls
}
return 0
}
// Contains client level statistics that are useful to load balancing. Each
// count except the timestamp should be reset to zero after reporting the stats.
type ClientStats struct {
// The timestamp of generating the report.
Timestamp *timestamp.Timestamp `protobuf:"bytes,1,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
// The total number of RPCs that started.
NumCallsStarted int64 `protobuf:"varint,2,opt,name=num_calls_started,json=numCallsStarted,proto3" json:"num_calls_started,omitempty"`
// The total number of RPCs that finished.
NumCallsFinished int64 `protobuf:"varint,3,opt,name=num_calls_finished,json=numCallsFinished,proto3" json:"num_calls_finished,omitempty"`
// The total number of RPCs that failed to reach a server except dropped RPCs.
NumCallsFinishedWithClientFailedToSend int64 `protobuf:"varint,6,opt,name=num_calls_finished_with_client_failed_to_send,json=numCallsFinishedWithClientFailedToSend,proto3" json:"num_calls_finished_with_client_failed_to_send,omitempty"`
// The total number of RPCs that finished and are known to have been received
// by a server.
NumCallsFinishedKnownReceived int64 `protobuf:"varint,7,opt,name=num_calls_finished_known_received,json=numCallsFinishedKnownReceived,proto3" json:"num_calls_finished_known_received,omitempty"`
// The list of dropped calls.
CallsFinishedWithDrop []*ClientStatsPerToken `protobuf:"bytes,8,rep,name=calls_finished_with_drop,json=callsFinishedWithDrop,proto3" json:"calls_finished_with_drop,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ClientStats) Reset() { *m = ClientStats{} }
func (m *ClientStats) String() string { return proto.CompactTextString(m) }
func (*ClientStats) ProtoMessage() {}
func (*ClientStats) Descriptor() ([]byte, []int) {
return fileDescriptor_load_balancer_12026aec3f0251ba, []int{3}
}
func (m *ClientStats) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ClientStats.Unmarshal(m, b)
}
func (m *ClientStats) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ClientStats.Marshal(b, m, deterministic)
}
func (dst *ClientStats) XXX_Merge(src proto.Message) {
xxx_messageInfo_ClientStats.Merge(dst, src)
}
func (m *ClientStats) XXX_Size() int {
return xxx_messageInfo_ClientStats.Size(m)
}
func (m *ClientStats) XXX_DiscardUnknown() {
xxx_messageInfo_ClientStats.DiscardUnknown(m)
}
var xxx_messageInfo_ClientStats proto.InternalMessageInfo
func (m *ClientStats) GetTimestamp() *timestamp.Timestamp {
if m != nil {
return m.Timestamp
}
return nil
}
func (m *ClientStats) GetNumCallsStarted() int64 {
if m != nil {
return m.NumCallsStarted
}
return 0
}
func (m *ClientStats) GetNumCallsFinished() int64 {
if m != nil {
return m.NumCallsFinished
}
return 0
}
func (m *ClientStats) GetNumCallsFinishedWithClientFailedToSend() int64 {
if m != nil {
return m.NumCallsFinishedWithClientFailedToSend
}
return 0
}
func (m *ClientStats) GetNumCallsFinishedKnownReceived() int64 {
if m != nil {
return m.NumCallsFinishedKnownReceived
}
return 0
}
func (m *ClientStats) GetCallsFinishedWithDrop() []*ClientStatsPerToken {
if m != nil {
return m.CallsFinishedWithDrop
}
return nil
}
type LoadBalanceResponse struct {
// Types that are valid to be assigned to LoadBalanceResponseType:
// *LoadBalanceResponse_InitialResponse
// *LoadBalanceResponse_ServerList
LoadBalanceResponseType isLoadBalanceResponse_LoadBalanceResponseType `protobuf_oneof:"load_balance_response_type"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *LoadBalanceResponse) Reset() { *m = LoadBalanceResponse{} }
func (m *LoadBalanceResponse) String() string { return proto.CompactTextString(m) }
func (*LoadBalanceResponse) ProtoMessage() {}
func (*LoadBalanceResponse) Descriptor() ([]byte, []int) {
return fileDescriptor_load_balancer_12026aec3f0251ba, []int{4}
}
func (m *LoadBalanceResponse) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_LoadBalanceResponse.Unmarshal(m, b)
}
func (m *LoadBalanceResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_LoadBalanceResponse.Marshal(b, m, deterministic)
}
func (dst *LoadBalanceResponse) XXX_Merge(src proto.Message) {
xxx_messageInfo_LoadBalanceResponse.Merge(dst, src)
}
func (m *LoadBalanceResponse) XXX_Size() int {
return xxx_messageInfo_LoadBalanceResponse.Size(m)
}
func (m *LoadBalanceResponse) XXX_DiscardUnknown() {
xxx_messageInfo_LoadBalanceResponse.DiscardUnknown(m)
}
var xxx_messageInfo_LoadBalanceResponse proto.InternalMessageInfo
type isLoadBalanceResponse_LoadBalanceResponseType interface {
isLoadBalanceResponse_LoadBalanceResponseType()
}
type LoadBalanceResponse_InitialResponse struct {
InitialResponse *InitialLoadBalanceResponse `protobuf:"bytes,1,opt,name=initial_response,json=initialResponse,proto3,oneof"`
}
type LoadBalanceResponse_ServerList struct {
ServerList *ServerList `protobuf:"bytes,2,opt,name=server_list,json=serverList,proto3,oneof"`
}
func (*LoadBalanceResponse_InitialResponse) isLoadBalanceResponse_LoadBalanceResponseType() {}
func (*LoadBalanceResponse_ServerList) isLoadBalanceResponse_LoadBalanceResponseType() {}
func (m *LoadBalanceResponse) GetLoadBalanceResponseType() isLoadBalanceResponse_LoadBalanceResponseType {
if m != nil {
return m.LoadBalanceResponseType
}
return nil
}
func (m *LoadBalanceResponse) GetInitialResponse() *InitialLoadBalanceResponse {
if x, ok := m.GetLoadBalanceResponseType().(*LoadBalanceResponse_InitialResponse); ok {
return x.InitialResponse
}
return nil
}
func (m *LoadBalanceResponse) GetServerList() *ServerList {
if x, ok := m.GetLoadBalanceResponseType().(*LoadBalanceResponse_ServerList); ok {
return x.ServerList
}
return nil
}
// XXX_OneofFuncs is for the internal use of the proto package.
func (*LoadBalanceResponse) XXX_OneofFuncs() (func(msg proto.Message, b *proto.Buffer) error, func(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error), func(msg proto.Message) (n int), []interface{}) {
return _LoadBalanceResponse_OneofMarshaler, _LoadBalanceResponse_OneofUnmarshaler, _LoadBalanceResponse_OneofSizer, []interface{}{
(*LoadBalanceResponse_InitialResponse)(nil),
(*LoadBalanceResponse_ServerList)(nil),
}
}
func _LoadBalanceResponse_OneofMarshaler(msg proto.Message, b *proto.Buffer) error {
m := msg.(*LoadBalanceResponse)
// load_balance_response_type
switch x := m.LoadBalanceResponseType.(type) {
case *LoadBalanceResponse_InitialResponse:
b.EncodeVarint(1<<3 | proto.WireBytes)
if err := b.EncodeMessage(x.InitialResponse); err != nil {
return err
}
case *LoadBalanceResponse_ServerList:
b.EncodeVarint(2<<3 | proto.WireBytes)
if err := b.EncodeMessage(x.ServerList); err != nil {
return err
}
case nil:
default:
return fmt.Errorf("LoadBalanceResponse.LoadBalanceResponseType has unexpected type %T", x)
}
return nil
}
func _LoadBalanceResponse_OneofUnmarshaler(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error) {
m := msg.(*LoadBalanceResponse)
switch tag {
case 1: // load_balance_response_type.initial_response
if wire != proto.WireBytes {
return true, proto.ErrInternalBadWireType
}
msg := new(InitialLoadBalanceResponse)
err := b.DecodeMessage(msg)
m.LoadBalanceResponseType = &LoadBalanceResponse_InitialResponse{msg}
return true, err
case 2: // load_balance_response_type.server_list
if wire != proto.WireBytes {
return true, proto.ErrInternalBadWireType
}
msg := new(ServerList)
err := b.DecodeMessage(msg)
m.LoadBalanceResponseType = &LoadBalanceResponse_ServerList{msg}
return true, err
default:
return false, nil
}
}
func _LoadBalanceResponse_OneofSizer(msg proto.Message) (n int) {
m := msg.(*LoadBalanceResponse)
// load_balance_response_type
switch x := m.LoadBalanceResponseType.(type) {
case *LoadBalanceResponse_InitialResponse:
s := proto.Size(x.InitialResponse)
n += 1 // tag and wire
n += proto.SizeVarint(uint64(s))
n += s
case *LoadBalanceResponse_ServerList:
s := proto.Size(x.ServerList)
n += 1 // tag and wire
n += proto.SizeVarint(uint64(s))
n += s
case nil:
default:
panic(fmt.Sprintf("proto: unexpected type %T in oneof", x))
}
return n
}
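// Illustrative sketch, not part of the generated API: one plausible way a
// client could branch on the response oneof above. The helper name and the
// blank-identifier usage are assumptions for illustration only; the getters
// it calls are the generated ones defined in this file.
func exampleHandleResponse(resp *LoadBalanceResponse) {
	switch {
	case resp.GetInitialResponse() != nil:
		// Handshake reply: carries the client stats report interval.
		_ = resp.GetInitialResponse().GetClientStatsReportInterval()
	case resp.GetServerList() != nil:
		// A fresh list of backends to use for subsequent picks.
		_ = resp.GetServerList().GetServers()
	}
}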
type InitialLoadBalanceResponse struct {
// This is an application layer redirect that indicates the client should use
// the specified server for load balancing. When this field is non-empty in
// the response, the client should open a separate connection to the
// load_balancer_delegate and call the BalanceLoad method. Its length should
// be less than 64 bytes.
LoadBalancerDelegate string `protobuf:"bytes,1,opt,name=load_balancer_delegate,json=loadBalancerDelegate,proto3" json:"load_balancer_delegate,omitempty"`
// This interval defines how often the client should send the client stats
// to the load balancer. Stats should only be reported when the duration is
// positive.
ClientStatsReportInterval *duration.Duration `protobuf:"bytes,2,opt,name=client_stats_report_interval,json=clientStatsReportInterval,proto3" json:"client_stats_report_interval,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *InitialLoadBalanceResponse) Reset() { *m = InitialLoadBalanceResponse{} }
func (m *InitialLoadBalanceResponse) String() string { return proto.CompactTextString(m) }
func (*InitialLoadBalanceResponse) ProtoMessage() {}
func (*InitialLoadBalanceResponse) Descriptor() ([]byte, []int) {
return fileDescriptor_load_balancer_12026aec3f0251ba, []int{5}
}
func (m *InitialLoadBalanceResponse) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_InitialLoadBalanceResponse.Unmarshal(m, b)
}
func (m *InitialLoadBalanceResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_InitialLoadBalanceResponse.Marshal(b, m, deterministic)
}
func (dst *InitialLoadBalanceResponse) XXX_Merge(src proto.Message) {
xxx_messageInfo_InitialLoadBalanceResponse.Merge(dst, src)
}
func (m *InitialLoadBalanceResponse) XXX_Size() int {
return xxx_messageInfo_InitialLoadBalanceResponse.Size(m)
}
func (m *InitialLoadBalanceResponse) XXX_DiscardUnknown() {
xxx_messageInfo_InitialLoadBalanceResponse.DiscardUnknown(m)
}
var xxx_messageInfo_InitialLoadBalanceResponse proto.InternalMessageInfo
func (m *InitialLoadBalanceResponse) GetLoadBalancerDelegate() string {
if m != nil {
return m.LoadBalancerDelegate
}
return ""
}
func (m *InitialLoadBalanceResponse) GetClientStatsReportInterval() *duration.Duration {
if m != nil {
return m.ClientStatsReportInterval
}
return nil
}
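// Illustrative sketch, not part of the generated API: an initial response a
// balancer might send to ask clients for stats every ten seconds. The
// ten-second value and the helper name are assumptions; the duration import
// alias matches the one already used by this file.
func exampleInitialResponse() *InitialLoadBalanceResponse {
	return &InitialLoadBalanceResponse{
		// A positive interval enables client-side load reporting.
		ClientStatsReportInterval: &duration.Duration{Seconds: 10},
	}
}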
type ServerList struct {
// Contains a list of servers selected by the load balancer. The list will
// be updated when server resolutions change or as needed to balance load
// across more servers. The client should consume the server list in order
// unless instructed otherwise via the client_config.
Servers []*Server `protobuf:"bytes,1,rep,name=servers,proto3" json:"servers,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ServerList) Reset() { *m = ServerList{} }
func (m *ServerList) String() string { return proto.CompactTextString(m) }
func (*ServerList) ProtoMessage() {}
func (*ServerList) Descriptor() ([]byte, []int) {
return fileDescriptor_load_balancer_12026aec3f0251ba, []int{6}
}
func (m *ServerList) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ServerList.Unmarshal(m, b)
}
func (m *ServerList) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ServerList.Marshal(b, m, deterministic)
}
func (dst *ServerList) XXX_Merge(src proto.Message) {
xxx_messageInfo_ServerList.Merge(dst, src)
}
func (m *ServerList) XXX_Size() int {
return xxx_messageInfo_ServerList.Size(m)
}
func (m *ServerList) XXX_DiscardUnknown() {
xxx_messageInfo_ServerList.DiscardUnknown(m)
}
var xxx_messageInfo_ServerList proto.InternalMessageInfo
func (m *ServerList) GetServers() []*Server {
if m != nil {
return m.Servers
}
return nil
}
// Contains server information. When the drop field is not true, use the other
// fields.
type Server struct {
// A resolved address for the server, serialized in network-byte-order. It may
// either be an IPv4 or IPv6 address.
IpAddress []byte `protobuf:"bytes,1,opt,name=ip_address,json=ipAddress,proto3" json:"ip_address,omitempty"`
// A resolved port number for the server.
Port int32 `protobuf:"varint,2,opt,name=port,proto3" json:"port,omitempty"`
// An opaque but printable token for load reporting. The client must include
// the token of the picked server into the initial metadata when it starts a
// call to that server. The token is used by the server to verify the request
// and to allow the server to report load to the gRPC LB system. The token is
// also used in client stats for reporting dropped calls.
//
// Its length can be variable but must be less than 50 bytes.
LoadBalanceToken string `protobuf:"bytes,3,opt,name=load_balance_token,json=loadBalanceToken,proto3" json:"load_balance_token,omitempty"`
// Indicates whether this particular request should be dropped by the client.
// If the request is dropped, there will be a corresponding entry in
// ClientStats.calls_finished_with_drop.
Drop bool `protobuf:"varint,4,opt,name=drop,proto3" json:"drop,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *Server) Reset() { *m = Server{} }
func (m *Server) String() string { return proto.CompactTextString(m) }
func (*Server) ProtoMessage() {}
func (*Server) Descriptor() ([]byte, []int) {
return fileDescriptor_load_balancer_12026aec3f0251ba, []int{7}
}
func (m *Server) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Server.Unmarshal(m, b)
}
func (m *Server) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_Server.Marshal(b, m, deterministic)
}
func (dst *Server) XXX_Merge(src proto.Message) {
xxx_messageInfo_Server.Merge(dst, src)
}
func (m *Server) XXX_Size() int {
return xxx_messageInfo_Server.Size(m)
}
func (m *Server) XXX_DiscardUnknown() {
xxx_messageInfo_Server.DiscardUnknown(m)
}
var xxx_messageInfo_Server proto.InternalMessageInfo
func (m *Server) GetIpAddress() []byte {
if m != nil {
return m.IpAddress
}
return nil
}
func (m *Server) GetPort() int32 {
if m != nil {
return m.Port
}
return 0
}
func (m *Server) GetLoadBalanceToken() string {
if m != nil {
return m.LoadBalanceToken
}
return ""
}
func (m *Server) GetDrop() bool {
if m != nil {
return m.Drop
}
return false
}
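// Illustrative sketch, not part of the generated API: a ServerList as a
// balancer might send it, with one real backend and one drop-only entry.
// The 10.0.0.1 address, port, and token values are placeholders.
func exampleServerList() *ServerList {
	return &ServerList{
		Servers: []*Server{
			{
				IpAddress:        []byte{10, 0, 0, 1}, // 10.0.0.1 in network byte order
				Port:             8080,
				LoadBalanceToken: "example-token",
			},
			{
				// Drop-only entry: the client fails a matching pick locally and
				// reports it in ClientStats.CallsFinishedWithDrop.
				Drop: true,
			},
		},
	}
}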
func init() {
proto.RegisterType((*LoadBalanceRequest)(nil), "grpc.lb.v1.LoadBalanceRequest")
proto.RegisterType((*InitialLoadBalanceRequest)(nil), "grpc.lb.v1.InitialLoadBalanceRequest")
proto.RegisterType((*ClientStatsPerToken)(nil), "grpc.lb.v1.ClientStatsPerToken")
proto.RegisterType((*ClientStats)(nil), "grpc.lb.v1.ClientStats")
proto.RegisterType((*LoadBalanceResponse)(nil), "grpc.lb.v1.LoadBalanceResponse")
proto.RegisterType((*InitialLoadBalanceResponse)(nil), "grpc.lb.v1.InitialLoadBalanceResponse")
proto.RegisterType((*ServerList)(nil), "grpc.lb.v1.ServerList")
proto.RegisterType((*Server)(nil), "grpc.lb.v1.Server")
}
// Reference imports to suppress errors if they are not otherwise used.
var _ context.Context
var _ grpc.ClientConn
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
const _ = grpc.SupportPackageIsVersion4
// LoadBalancerClient is the client API for LoadBalancer service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
type LoadBalancerClient interface {
// Bidirectional rpc to get a list of servers.
BalanceLoad(ctx context.Context, opts ...grpc.CallOption) (LoadBalancer_BalanceLoadClient, error)
}
type loadBalancerClient struct {
cc *grpc.ClientConn
}
func NewLoadBalancerClient(cc *grpc.ClientConn) LoadBalancerClient {
return &loadBalancerClient{cc}
}
func (c *loadBalancerClient) BalanceLoad(ctx context.Context, opts ...grpc.CallOption) (LoadBalancer_BalanceLoadClient, error) {
stream, err := c.cc.NewStream(ctx, &_LoadBalancer_serviceDesc.Streams[0], "/grpc.lb.v1.LoadBalancer/BalanceLoad", opts...)
if err != nil {
return nil, err
}
x := &loadBalancerBalanceLoadClient{stream}
return x, nil
}
type LoadBalancer_BalanceLoadClient interface {
Send(*LoadBalanceRequest) error
Recv() (*LoadBalanceResponse, error)
grpc.ClientStream
}
type loadBalancerBalanceLoadClient struct {
grpc.ClientStream
}
func (x *loadBalancerBalanceLoadClient) Send(m *LoadBalanceRequest) error {
return x.ClientStream.SendMsg(m)
}
func (x *loadBalancerBalanceLoadClient) Recv() (*LoadBalanceResponse, error) {
m := new(LoadBalanceResponse)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// LoadBalancerServer is the server API for LoadBalancer service.
type LoadBalancerServer interface {
// Bidirectional rpc to get a list of servers.
BalanceLoad(LoadBalancer_BalanceLoadServer) error
}
func RegisterLoadBalancerServer(s *grpc.Server, srv LoadBalancerServer) {
s.RegisterService(&_LoadBalancer_serviceDesc, srv)
}
func _LoadBalancer_BalanceLoad_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(LoadBalancerServer).BalanceLoad(&loadBalancerBalanceLoadServer{stream})
}
type LoadBalancer_BalanceLoadServer interface {
Send(*LoadBalanceResponse) error
Recv() (*LoadBalanceRequest, error)
grpc.ServerStream
}
type loadBalancerBalanceLoadServer struct {
grpc.ServerStream
}
func (x *loadBalancerBalanceLoadServer) Send(m *LoadBalanceResponse) error {
return x.ServerStream.SendMsg(m)
}
func (x *loadBalancerBalanceLoadServer) Recv() (*LoadBalanceRequest, error) {
m := new(LoadBalanceRequest)
if err := x.ServerStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
var _LoadBalancer_serviceDesc = grpc.ServiceDesc{
ServiceName: "grpc.lb.v1.LoadBalancer",
HandlerType: (*LoadBalancerServer)(nil),
Methods: []grpc.MethodDesc{},
Streams: []grpc.StreamDesc{
{
StreamName: "BalanceLoad",
Handler: _LoadBalancer_BalanceLoad_Handler,
ServerStreams: true,
ClientStreams: true,
},
},
Metadata: "grpc/lb/v1/load_balancer.proto",
}
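// Illustrative sketch, not part of the generated API: opening the
// BalanceLoad stream with the generated client and sending the mandatory
// initial request. The connection is assumed to be dialed by the caller and
// "example.service" is a placeholder name.
func exampleBalanceLoad(ctx context.Context, conn *grpc.ClientConn) error {
	stream, err := NewLoadBalancerClient(conn).BalanceLoad(ctx)
	if err != nil {
		return err
	}
	// The first message on the stream must be an InitialLoadBalanceRequest.
	return stream.Send(&LoadBalanceRequest{
		LoadBalanceRequestType: &LoadBalanceRequest_InitialRequest{
			InitialRequest: &InitialLoadBalanceRequest{Name: "example.service"},
		},
	})
}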
func init() {
proto.RegisterFile("grpc/lb/v1/load_balancer.proto", fileDescriptor_load_balancer_12026aec3f0251ba)
}
var fileDescriptor_load_balancer_12026aec3f0251ba = []byte{
// 752 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x55, 0xdd, 0x6e, 0x23, 0x35,
0x14, 0xee, 0x90, 0x69, 0x36, 0x39, 0x29, 0x34, 0xeb, 0x85, 0x65, 0x92, 0xdd, 0x6d, 0x4b, 0x24,
0x56, 0x11, 0x2a, 0x13, 0x52, 0xb8, 0x00, 0x89, 0x0b, 0x48, 0xab, 0x2a, 0x2d, 0xbd, 0x88, 0x9c,
0x4a, 0x45, 0x95, 0x90, 0x99, 0xc9, 0xb8, 0xa9, 0x55, 0xc7, 0x1e, 0x3c, 0x4e, 0x2a, 0xae, 0x79,
0x1f, 0xc4, 0x2b, 0x20, 0x5e, 0x0c, 0x8d, 0xed, 0x49, 0xa6, 0x49, 0xa3, 0xbd, 0xca, 0xf8, 0x9c,
0xcf, 0xdf, 0xf9, 0xfd, 0x1c, 0x38, 0x98, 0xaa, 0x74, 0xd2, 0xe3, 0x71, 0x6f, 0xd1, 0xef, 0x71,
0x19, 0x25, 0x24, 0x8e, 0x78, 0x24, 0x26, 0x54, 0x85, 0xa9, 0x92, 0x5a, 0x22, 0xc8, 0xfd, 0x21,
0x8f, 0xc3, 0x45, 0xbf, 0x7d, 0x30, 0x95, 0x72, 0xca, 0x69, 0xcf, 0x78, 0xe2, 0xf9, 0x5d, 0x2f,
0x99, 0xab, 0x48, 0x33, 0x29, 0x2c, 0xb6, 0x7d, 0xb8, 0xee, 0xd7, 0x6c, 0x46, 0x33, 0x1d, 0xcd,
0x52, 0x0b, 0xe8, 0xfc, 0xeb, 0x01, 0xba, 0x92, 0x51, 0x32, 0xb0, 0x31, 0x30, 0xfd, 0x63, 0x4e,
0x33, 0x8d, 0x46, 0xb0, 0xcf, 0x04, 0xd3, 0x2c, 0xe2, 0x44, 0x59, 0x53, 0xe0, 0x1d, 0x79, 0xdd,
0xc6, 0xc9, 0x97, 0xe1, 0x2a, 0x7a, 0x78, 0x61, 0x21, 0x9b, 0xf7, 0x87, 0x3b, 0xf8, 0x13, 0x77,
0xbf, 0x60, 0xfc, 0x11, 0xf6, 0x26, 0x9c, 0x51, 0xa1, 0x49, 0xa6, 0x23, 0x9d, 0x05, 0x1f, 0x19,
0xba, 0xcf, 0xcb, 0x74, 0xa7, 0xc6, 0x3f, 0xce, 0xdd, 0xc3, 0x1d, 0xdc, 0x98, 0xac, 0x8e, 0x83,
0x37, 0xd0, 0x2a, 0xb7, 0xa2, 0x48, 0x8a, 0xe8, 0x3f, 0x53, 0xda, 0xe9, 0x41, 0x6b, 0x6b, 0x26,
0x08, 0x81, 0x2f, 0xa2, 0x19, 0x35, 0xe9, 0xd7, 0xb1, 0xf9, 0xee, 0xfc, 0x0e, 0xaf, 0x4a, 0xb1,
0x46, 0x54, 0x5d, 0xcb, 0x07, 0x2a, 0xd0, 0x31, 0xa0, 0x27, 0x41, 0x74, 0x6e, 0x75, 0x17, 0x9b,
0x7c, 0x45, 0x6d, 0xd1, 0x6f, 0xa0, 0x2e, 0xe6, 0x33, 0x32, 0x89, 0x38, 0xb7, 0xd5, 0x54, 0x70,
0x4d, 0xcc, 0x67, 0xa7, 0xf9, 0xb9, 0xf3, 0x4f, 0x05, 0x1a, 0xa5, 0x10, 0xe8, 0x7b, 0xa8, 0x2f,
0x3b, 0xef, 0x3a, 0xd9, 0x0e, 0xed, 0x6c, 0xc2, 0x62, 0x36, 0xe1, 0x75, 0x81, 0xc0, 0x2b, 0x30,
0xfa, 0x0a, 0x5e, 0x2e, 0xc3, 0xe4, 0xad, 0x53, 0x9a, 0x26, 0x2e, 0xdc, 0x7e, 0x11, 0x6e, 0x6c,
0xcd, 0x79, 0x01, 0x2b, 0xec, 0x1d, 0x13, 0x2c, 0xbb, 0xa7, 0x49, 0x50, 0x31, 0xe0, 0x66, 0x01,
0x3e, 0x77, 0x76, 0xf4, 0x1b, 0x7c, 0xbd, 0x89, 0x26, 0x8f, 0x4c, 0xdf, 0x13, 0x37, 0xa9, 0xbb,
0x88, 0x71, 0x9a, 0x10, 0x2d, 0x49, 0x46, 0x45, 0x12, 0x54, 0x0d, 0xd1, 0xfb, 0x75, 0xa2, 0x1b,
0xa6, 0xef, 0x6d, 0xad, 0xe7, 0x06, 0x7f, 0x2d, 0xc7, 0x54, 0x24, 0x68, 0x08, 0x5f, 0x3c, 0x43,
0xff, 0x20, 0xe4, 0xa3, 0x20, 0x8a, 0x4e, 0x28, 0x5b, 0xd0, 0x24, 0x78, 0x61, 0x28, 0xdf, 0xad,
0x53, 0xfe, 0x92, 0xa3, 0xb0, 0x03, 0xa1, 0x5f, 0x21, 0x78, 0x2e, 0xc9, 0x44, 0xc9, 0x34, 0xa8,
0x1d, 0x55, 0xba, 0x8d, 0x93, 0xc3, 0x2d, 0x6b, 0x54, 0x8c, 0x16, 0x7f, 0x36, 0x59, 0xcf, 0xf8,
0x4c, 0xc9, 0xf4, 0xd2, 0xaf, 0xf9, 0xcd, 0xdd, 0x4b, 0xbf, 0xb6, 0xdb, 0xac, 0x76, 0xfe, 0xf3,
0xe0, 0xd5, 0x93, 0xfd, 0xc9, 0x52, 0x29, 0x32, 0x8a, 0xc6, 0xd0, 0x5c, 0x49, 0xc1, 0xda, 0xdc,
0x04, 0xdf, 0x7f, 0x48, 0x0b, 0x16, 0x3d, 0xdc, 0xc1, 0xfb, 0x4b, 0x31, 0x38, 0xd2, 0x1f, 0xa0,
0x91, 0x51, 0xb5, 0xa0, 0x8a, 0x70, 0x96, 0x69, 0x27, 0x86, 0xd7, 0x65, 0xbe, 0xb1, 0x71, 0x5f,
0x31, 0x23, 0x26, 0xc8, 0x96, 0xa7, 0xc1, 0x5b, 0x68, 0xaf, 0x49, 0xc1, 0x72, 0x5a, 0x2d, 0xfc,
0xed, 0x41, 0x7b, 0x7b, 0x2a, 0xe8, 0x3b, 0x78, 0xfd, 0xe4, 0x49, 0x21, 0x09, 0xe5, 0x74, 0x1a,
0xe9, 0x42, 0x1f, 0x9f, 0x96, 0xd6, 0x5c, 0x9d, 0x39, 0x1f, 0xba, 0x85, 0xb7, 0x65, 0xed, 0x12,
0x45, 0x53, 0xa9, 0x34, 0x61, 0x42, 0x53, 0xb5, 0x88, 0xb8, 0x4b, 0xbf, 0xb5, 0xb1, 0xd0, 0x67,
0xee, 0x31, 0xc2, 0xad, 0x92, 0x96, 0xb1, 0xb9, 0x7c, 0xe1, 0xee, 0x76, 0x7e, 0x02, 0x58, 0x95,
0x8a, 0x8e, 0xe1, 0x85, 0x2d, 0x35, 0x0b, 0x3c, 0x33, 0x59, 0xb4, 0xd9, 0x13, 0x5c, 0x40, 0x2e,
0xfd, 0x5a, 0xa5, 0xe9, 0x77, 0xfe, 0xf2, 0xa0, 0x6a, 0x3d, 0xe8, 0x1d, 0x00, 0x4b, 0x49, 0x94,
0x24, 0x8a, 0x66, 0x99, 0x29, 0x69, 0x0f, 0xd7, 0x59, 0xfa, 0xb3, 0x35, 0xe4, 0x6f, 0x41, 0x1e,
0xdb, 0xe4, 0xbb, 0x8b, 0xcd, 0xf7, 0x16, 0xd1, 0x57, 0xb6, 0x88, 0x1e, 0x81, 0x6f, 0xd6, 0xce,
0x3f, 0xf2, 0xba, 0x35, 0x6c, 0xbe, 0xed, 0xfa, 0x9c, 0xc4, 0xb0, 0x57, 0x6a, 0xb8, 0x42, 0x18,
0x1a, 0xee, 0x3b, 0x37, 0xa3, 0x83, 0x72, 0x1d, 0x9b, 0xcf, 0x54, 0xfb, 0x70, 0xab, 0xdf, 0x4e,
0xae, 0xeb, 0x7d, 0xe3, 0x0d, 0x6e, 0xe0, 0x63, 0x26, 0x4b, 0xc0, 0xc1, 0xcb, 0x72, 0xc8, 0x51,
0xde, 0xf6, 0x91, 0x77, 0xdb, 0x77, 0x63, 0x98, 0x4a, 0x1e, 0x89, 0x69, 0x28, 0xd5, 0xb4, 0x67,
0xfe, 0x51, 0x8a, 0x99, 0x9b, 0x13, 0x8f, 0xcd, 0x0f, 0xe1, 0x31, 0x59, 0xf4, 0xe3, 0xaa, 0x19,
0xd9, 0xb7, 0xff, 0x07, 0x00, 0x00, 0xff, 0xff, 0x81, 0x14, 0xee, 0xd1, 0x7b, 0x06, 0x00, 0x00,
}

View File

@ -1,392 +0,0 @@
/*
*
* Copyright 2016 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
//go:generate ./regenerate.sh
// Package grpclb defines a grpclb balancer.
//
// To install grpclb balancer, import this package as:
// import _ "google.golang.org/grpc/balancer/grpclb"
package grpclb
import (
"context"
"errors"
"strconv"
"strings"
"sync"
"time"
durationpb "github.com/golang/protobuf/ptypes/duration"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer"
lbpb "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/backoff"
"google.golang.org/grpc/resolver"
)
const (
lbTokenKey = "lb-token"
defaultFallbackTimeout = 10 * time.Second
grpclbName = "grpclb"
)
var (
// defaultBackoffConfig configures the backoff strategy that's used when the
// init handshake in the RPC is unsuccessful. It's not for the clientconn
// reconnect backoff.
//
// It has the same value as the default grpc.DefaultBackoffConfig.
//
// TODO: make backoff configurable.
defaultBackoffConfig = backoff.Exponential{
MaxDelay: 120 * time.Second,
}
errServerTerminatedConnection = errors.New("grpclb: failed to recv server list: server terminated connection")
)
func convertDuration(d *durationpb.Duration) time.Duration {
if d == nil {
return 0
}
return time.Duration(d.Seconds)*time.Second + time.Duration(d.Nanos)*time.Nanosecond
}
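// Illustrative sketch: a worked example of the conversion above. One second
// plus 500 million nanoseconds becomes 1.5s, and a nil duration becomes 0.
// The helper name is an assumption for illustration only.
func exampleConvertDuration() time.Duration {
	return convertDuration(&durationpb.Duration{Seconds: 1, Nanos: 500000000}) // 1.5s
}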
// Client API for LoadBalancer service.
// Mostly copied from generated pb.go file.
// To avoid circular dependency.
type loadBalancerClient struct {
cc *grpc.ClientConn
}
func (c *loadBalancerClient) BalanceLoad(ctx context.Context, opts ...grpc.CallOption) (*balanceLoadClientStream, error) {
desc := &grpc.StreamDesc{
StreamName: "BalanceLoad",
ServerStreams: true,
ClientStreams: true,
}
stream, err := c.cc.NewStream(ctx, desc, "/grpc.lb.v1.LoadBalancer/BalanceLoad", opts...)
if err != nil {
return nil, err
}
x := &balanceLoadClientStream{stream}
return x, nil
}
type balanceLoadClientStream struct {
grpc.ClientStream
}
func (x *balanceLoadClientStream) Send(m *lbpb.LoadBalanceRequest) error {
return x.ClientStream.SendMsg(m)
}
func (x *balanceLoadClientStream) Recv() (*lbpb.LoadBalanceResponse, error) {
m := new(lbpb.LoadBalanceResponse)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
func init() {
balancer.Register(newLBBuilder())
}
// newLBBuilder creates a builder for grpclb.
func newLBBuilder() balancer.Builder {
return newLBBuilderWithFallbackTimeout(defaultFallbackTimeout)
}
// newLBBuilderWithFallbackTimeout creates a grpclb builder with the given
// fallbackTimeout. If no response is received from the remote balancer within
// fallbackTimeout, the backend addresses from the resolved address list will be
// used.
//
// Only call this function when a non-default fallback timeout is needed.
func newLBBuilderWithFallbackTimeout(fallbackTimeout time.Duration) balancer.Builder {
return &lbBuilder{
fallbackTimeout: fallbackTimeout,
}
}
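// Illustrative sketch, not part of the package API: how a caller could
// register a builder with a shorter fallback window. The five-second value
// and the function name are assumptions for illustration only.
func registerWithShortFallback() {
	balancer.Register(newLBBuilderWithFallbackTimeout(5 * time.Second))
}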
type lbBuilder struct {
fallbackTimeout time.Duration
}
func (b *lbBuilder) Name() string {
return grpclbName
}
func (b *lbBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer {
// This generates a manual resolver builder with a random scheme. The
// scheme is used to dial the remote LB, so we can send filtered address
// updates to the remote LB ClientConn through this manual resolver.
scheme := "grpclb_internal_" + strconv.FormatInt(time.Now().UnixNano(), 36)
r := &lbManualResolver{scheme: scheme, ccb: cc}
var target string
targetSplitted := strings.Split(cc.Target(), ":///")
if len(targetSplitted) < 2 {
target = cc.Target()
} else {
target = targetSplitted[1]
}
lb := &lbBalancer{
cc: newLBCacheClientConn(cc),
target: target,
opt: opt,
fallbackTimeout: b.fallbackTimeout,
doneCh: make(chan struct{}),
manualResolver: r,
csEvltr: &balancer.ConnectivityStateEvaluator{},
subConns: make(map[resolver.Address]balancer.SubConn),
scStates: make(map[balancer.SubConn]connectivity.State),
picker: &errPicker{err: balancer.ErrNoSubConnAvailable},
clientStats: newRPCStats(),
backoff: defaultBackoffConfig, // TODO: make backoff configurable.
}
var err error
if opt.CredsBundle != nil {
lb.grpclbClientConnCreds, err = opt.CredsBundle.NewWithMode(internal.CredsBundleModeBalancer)
if err != nil {
grpclog.Warningf("lbBalancer: client connection creds NewWithMode failed: %v", err)
}
lb.grpclbBackendCreds, err = opt.CredsBundle.NewWithMode(internal.CredsBundleModeBackendFromBalancer)
if err != nil {
grpclog.Warningf("lbBalancer: backend creds NewWithMode failed: %v", err)
}
}
return lb
}
type lbBalancer struct {
cc *lbCacheClientConn
target string
opt balancer.BuildOptions
// grpclbClientConnCreds is the creds bundle to be used to connect to grpclb
// servers. If it's nil, use the TransportCredentials from BuildOptions
// instead.
grpclbClientConnCreds credentials.Bundle
// grpclbBackendCreds is the creds bundle to be used for addresses that are
// returned by grpclb server. If it's nil, don't set anything when creating
// SubConns.
grpclbBackendCreds credentials.Bundle
fallbackTimeout time.Duration
doneCh chan struct{}
// manualResolver is used in the remote LB ClientConn inside grpclb. When
// resolved address updates are received by grpclb, filtered updates will be
// sent to the remote LB ClientConn through this resolver.
manualResolver *lbManualResolver
// The ClientConn to talk to the remote balancer.
ccRemoteLB *grpc.ClientConn
// backoff for calling remote balancer.
backoff backoff.Strategy
// Support client side load reporting. Each picker gets a reference to this,
// and will update its content.
clientStats *rpcStats
mu sync.Mutex // guards everything following.
// The full server list including drops, used to check if the newly received
// serverList contains anything new. Each generated picker will also keep a
// reference to this list to do the first-layer pick.
fullServerList []*lbpb.Server
// All backend addresses, with metadata set to nil. This list contains all
// backend addresses in the same order and with the same duplicates as in the
// server list. When generating the picker, a SubConn slice with the same
// order but with only READY SCs will be generated.
backendAddrs []resolver.Address
// Roundrobin functionalities.
csEvltr *balancer.ConnectivityStateEvaluator
state connectivity.State
subConns map[resolver.Address]balancer.SubConn // Used to new/remove SubConn.
scStates map[balancer.SubConn]connectivity.State // Used to filter READY SubConns.
picker balancer.Picker
// Support fallback to resolved backend addresses if there's no response
// from remote balancer within fallbackTimeout.
fallbackTimerExpired bool
serverListReceived bool
// resolvedBackendAddrs is resolvedAddrs minus remote balancers. It's set
// when resolved address updates are received, and read in the goroutine
// handling fallback.
resolvedBackendAddrs []resolver.Address
}
// regeneratePicker takes a snapshot of the balancer, and generates a picker from
// it. The picker
// - always returns ErrTransientFailure if the balancer is in TransientFailure,
// - does a two-layer roundrobin pick otherwise.
// Caller must hold lb.mu.
func (lb *lbBalancer) regeneratePicker() {
if lb.state == connectivity.TransientFailure {
lb.picker = &errPicker{err: balancer.ErrTransientFailure}
return
}
var readySCs []balancer.SubConn
for _, a := range lb.backendAddrs {
if sc, ok := lb.subConns[a]; ok {
if st, ok := lb.scStates[sc]; ok && st == connectivity.Ready {
readySCs = append(readySCs, sc)
}
}
}
if len(lb.fullServerList) <= 0 {
if len(readySCs) <= 0 {
lb.picker = &errPicker{err: balancer.ErrNoSubConnAvailable}
return
}
lb.picker = &rrPicker{subConns: readySCs}
return
}
lb.picker = &lbPicker{
serverList: lb.fullServerList,
subConns: readySCs,
stats: lb.clientStats,
}
}
func (lb *lbBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) {
grpclog.Infof("lbBalancer: handle SubConn state change: %p, %v", sc, s)
lb.mu.Lock()
defer lb.mu.Unlock()
oldS, ok := lb.scStates[sc]
if !ok {
grpclog.Infof("lbBalancer: got state changes for an unknown SubConn: %p, %v", sc, s)
return
}
lb.scStates[sc] = s
switch s {
case connectivity.Idle:
sc.Connect()
case connectivity.Shutdown:
// When an address is removed by the resolver, the balancer calls
// RemoveSubConn but keeps the sc's state in scStates. Remove the state for
// this sc here.
delete(lb.scStates, sc)
}
oldAggrState := lb.state
lb.state = lb.csEvltr.RecordTransition(oldS, s)
// Regenerate picker when one of the following happens:
// - this sc became ready from not-ready
// - this sc became not-ready from ready
// - the aggregated state of balancer became TransientFailure from non-TransientFailure
// - the aggregated state of balancer became non-TransientFailure from TransientFailure
if (oldS == connectivity.Ready) != (s == connectivity.Ready) ||
(lb.state == connectivity.TransientFailure) != (oldAggrState == connectivity.TransientFailure) {
lb.regeneratePicker()
}
lb.cc.UpdateBalancerState(lb.state, lb.picker)
}
// fallbackToBackendsAfter blocks for fallbackTimeout and falls back to using
// the resolved backends (backends received from the resolver, not from the
// remote balancer) if no connection to the remote balancers was successful.
func (lb *lbBalancer) fallbackToBackendsAfter(fallbackTimeout time.Duration) {
timer := time.NewTimer(fallbackTimeout)
defer timer.Stop()
select {
case <-timer.C:
case <-lb.doneCh:
return
}
lb.mu.Lock()
if lb.serverListReceived {
lb.mu.Unlock()
return
}
lb.fallbackTimerExpired = true
lb.refreshSubConns(lb.resolvedBackendAddrs, false)
lb.mu.Unlock()
}
// HandleResolvedAddrs sends the updated remoteLB addresses to remoteLB
// clientConn. The remoteLB clientConn will handle creating/removing remoteLB
// connections.
func (lb *lbBalancer) HandleResolvedAddrs(addrs []resolver.Address, err error) {
grpclog.Infof("lbBalancer: handleResolvedResult: %+v", addrs)
if len(addrs) <= 0 {
return
}
var remoteBalancerAddrs, backendAddrs []resolver.Address
for _, a := range addrs {
if a.Type == resolver.GRPCLB {
remoteBalancerAddrs = append(remoteBalancerAddrs, a)
} else {
backendAddrs = append(backendAddrs, a)
}
}
if lb.ccRemoteLB == nil {
if len(remoteBalancerAddrs) <= 0 {
grpclog.Errorf("grpclb: no remote balancer address is available, should never happen")
return
}
// First time receiving resolved addresses, create a cc to remote
// balancers.
lb.dialRemoteLB(remoteBalancerAddrs[0].ServerName)
// Start the fallback goroutine.
go lb.fallbackToBackendsAfter(lb.fallbackTimeout)
}
// cc to remote balancers uses lb.manualResolver. Send the updated remote
// balancer addresses to it through manualResolver.
lb.manualResolver.NewAddress(remoteBalancerAddrs)
lb.mu.Lock()
lb.resolvedBackendAddrs = backendAddrs
// If serverListReceived is true, connection to remote balancer was
// successful and there's no need to do fallback anymore.
// If fallbackTimerExpired is false, fallback hasn't happened yet.
if !lb.serverListReceived && lb.fallbackTimerExpired {
// This means we received a new list of resolved backends, and we are
// still in fallback mode. Need to update the list of backends we are
// using to the new list of backends.
lb.refreshSubConns(lb.resolvedBackendAddrs, false)
}
lb.mu.Unlock()
}
func (lb *lbBalancer) Close() {
select {
case <-lb.doneCh:
return
default:
}
close(lb.doneCh)
if lb.ccRemoteLB != nil {
lb.ccRemoteLB.Close()
}
lb.cc.close()
}

View File

@ -1,170 +0,0 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpclb
import (
"context"
"sync"
"sync/atomic"
"google.golang.org/grpc/balancer"
lbpb "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// rpcStats is the same as lbpb.ClientStats, except that numCallsDropped is a
// map instead of a slice.
type rpcStats struct {
// Only access the following fields atomically.
numCallsStarted int64
numCallsFinished int64
numCallsFinishedWithClientFailedToSend int64
numCallsFinishedKnownReceived int64
mu sync.Mutex
// map load_balance_token -> num_calls_dropped
numCallsDropped map[string]int64
}
func newRPCStats() *rpcStats {
return &rpcStats{
numCallsDropped: make(map[string]int64),
}
}
// toClientStats converts rpcStats to lbpb.ClientStats, and clears rpcStats.
func (s *rpcStats) toClientStats() *lbpb.ClientStats {
stats := &lbpb.ClientStats{
NumCallsStarted: atomic.SwapInt64(&s.numCallsStarted, 0),
NumCallsFinished: atomic.SwapInt64(&s.numCallsFinished, 0),
NumCallsFinishedWithClientFailedToSend: atomic.SwapInt64(&s.numCallsFinishedWithClientFailedToSend, 0),
NumCallsFinishedKnownReceived: atomic.SwapInt64(&s.numCallsFinishedKnownReceived, 0),
}
s.mu.Lock()
dropped := s.numCallsDropped
s.numCallsDropped = make(map[string]int64)
s.mu.Unlock()
for token, count := range dropped {
stats.CallsFinishedWithDrop = append(stats.CallsFinishedWithDrop, &lbpb.ClientStatsPerToken{
LoadBalanceToken: token,
NumCalls: count,
})
}
return stats
}
func (s *rpcStats) drop(token string) {
atomic.AddInt64(&s.numCallsStarted, 1)
s.mu.Lock()
s.numCallsDropped[token]++
s.mu.Unlock()
atomic.AddInt64(&s.numCallsFinished, 1)
}
func (s *rpcStats) failedToSend() {
atomic.AddInt64(&s.numCallsStarted, 1)
atomic.AddInt64(&s.numCallsFinishedWithClientFailedToSend, 1)
atomic.AddInt64(&s.numCallsFinished, 1)
}
func (s *rpcStats) knownReceived() {
atomic.AddInt64(&s.numCallsStarted, 1)
atomic.AddInt64(&s.numCallsFinishedKnownReceived, 1)
atomic.AddInt64(&s.numCallsFinished, 1)
}
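// Illustrative sketch, not part of the picker API: how the per-RPC hooks
// above roll up into one report. After one drop, one send failure and one
// successful call, toClientStats returns NumCallsStarted == 3,
// NumCallsFinished == 3, one CallsFinishedWithDrop entry for "example-token"
// with NumCalls == 1, and resets the counters as a side effect. The helper
// name and token value are assumptions.
func exampleStatsRollup() *lbpb.ClientStats {
	s := newRPCStats()
	s.drop("example-token")
	s.failedToSend()
	s.knownReceived()
	return s.toClientStats()
}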
type errPicker struct {
// Pick always returns this err.
err error
}
func (p *errPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
return nil, nil, p.err
}
// rrPicker does roundrobin on subConns. It's typically used when there's no
// response from the remote balancer, and grpclb falls back to the resolved
// backends.
//
// It is guaranteed that len(subConns) > 0.
type rrPicker struct {
mu sync.Mutex
subConns []balancer.SubConn // The subConns that were READY when taking the snapshot.
subConnsNext int
}
func (p *rrPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
p.mu.Lock()
defer p.mu.Unlock()
sc := p.subConns[p.subConnsNext]
p.subConnsNext = (p.subConnsNext + 1) % len(p.subConns)
return sc, nil, nil
}
// lbPicker does two layers of picks:
//
// First layer: roundrobin on all servers in serverList, including drops and backends.
// - If it picks a drop, the RPC will fail as being dropped.
// - If it picks a backend, do a second layer pick to pick the real backend.
//
// Second layer: roundrobin on all READY backends.
//
// It's guaranteed that len(serverList) > 0.
type lbPicker struct {
mu sync.Mutex
serverList []*lbpb.Server
serverListNext int
subConns []balancer.SubConn // The subConns that were READY when taking the snapshot.
subConnsNext int
stats *rpcStats
}
func (p *lbPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
p.mu.Lock()
defer p.mu.Unlock()
// Layer one roundrobin on serverList.
s := p.serverList[p.serverListNext]
p.serverListNext = (p.serverListNext + 1) % len(p.serverList)
// If it's a drop, return an error and fail the RPC.
if s.Drop {
p.stats.drop(s.LoadBalanceToken)
return nil, nil, status.Errorf(codes.Unavailable, "request dropped by grpclb")
}
// If it's not a drop but there are no ready subConns.
if len(p.subConns) <= 0 {
return nil, nil, balancer.ErrNoSubConnAvailable
}
// Return the next ready subConn in the list, also collect rpc stats.
sc := p.subConns[p.subConnsNext]
p.subConnsNext = (p.subConnsNext + 1) % len(p.subConns)
done := func(info balancer.DoneInfo) {
if !info.BytesSent {
p.stats.failedToSend()
} else if info.BytesReceived {
p.stats.knownReceived()
}
}
return sc, done, nil
}
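// Illustrative sketch, not part of the picker API: assembling a two-layer
// picker from a snapshot. The READY SubConns are assumed to be supplied by
// the caller (regeneratePicker builds them from scStates); the helper name
// is an assumption for illustration only.
func newExampleLBPicker(serverList []*lbpb.Server, readySCs []balancer.SubConn) balancer.Picker {
	return &lbPicker{
		serverList: serverList, // layer one: includes drop entries
		subConns:   readySCs,   // layer two: only READY backends
		stats:      newRPCStats(),
	}
}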

View File

@ -1,304 +0,0 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpclb
import (
"context"
"fmt"
"io"
"net"
"reflect"
"time"
timestamppb "github.com/golang/protobuf/ptypes/timestamp"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer"
lbpb "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/resolver"
)
// processServerList updates the balancer's internal state, creates/removes
// SubConns, and regenerates the picker using the received serverList.
func (lb *lbBalancer) processServerList(l *lbpb.ServerList) {
grpclog.Infof("lbBalancer: processing server list: %+v", l)
lb.mu.Lock()
defer lb.mu.Unlock()
// Set serverListReceived to true so that fallback will not take effect even
// if the fallback timer has not fired yet.
lb.serverListReceived = true
// If the new server list == old server list, do nothing.
if reflect.DeepEqual(lb.fullServerList, l.Servers) {
grpclog.Infof("lbBalancer: new serverlist same as the previous one, ignoring")
return
}
lb.fullServerList = l.Servers
var backendAddrs []resolver.Address
for i, s := range l.Servers {
if s.Drop {
continue
}
md := metadata.Pairs(lbTokenKey, s.LoadBalanceToken)
ip := net.IP(s.IpAddress)
ipStr := ip.String()
if ip.To4() == nil {
// Add square brackets to ipv6 addresses, otherwise net.Dial() and
// net.SplitHostPort() will return too many colons error.
ipStr = fmt.Sprintf("[%s]", ipStr)
}
addr := resolver.Address{
Addr: fmt.Sprintf("%s:%d", ipStr, s.Port),
Metadata: &md,
}
grpclog.Infof("lbBalancer: server list entry[%d]: ipStr:|%s|, port:|%d|, load balancer token:|%v|",
i, ipStr, s.Port, s.LoadBalanceToken)
backendAddrs = append(backendAddrs, addr)
}
// Call refreshSubConns to create/remove SubConns.
lb.refreshSubConns(backendAddrs, true)
// Regenerate and update the picker regardless of whether there was an update
// to the backends (whether any SubConn was created/removed). Since the full
// serverList was different, there might be updates in drops or pick weights
// (a different number of duplicates). We need to update the picker with the
// full list.
//
// Now, with the cache, even if a SubConn was created/removed, there might be
// no state changes.
lb.regeneratePicker()
lb.cc.UpdateBalancerState(lb.state, lb.picker)
}
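// Illustrative sketch, not part of the balancer API: the address formatting
// applied in processServerList, shown on its own. An IPv4 entry becomes
// "10.0.0.1:443", while an IPv6 entry is bracketed to "[2001:db8::1]:443" so
// that net.Dial and net.SplitHostPort parse it correctly. The helper name is
// an assumption for illustration only.
func formatBackendAddr(ipBytes []byte, port int32) string {
	ip := net.IP(ipBytes)
	ipStr := ip.String()
	if ip.To4() == nil {
		ipStr = fmt.Sprintf("[%s]", ipStr)
	}
	return fmt.Sprintf("%s:%d", ipStr, port)
}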
// refreshSubConns creates/removes SubConns with backendAddrs. It returns a bool
// indicating whether the backendAddrs are different from the cached
// backendAddrs (whether any SubConn was created/removed).
// Caller must hold lb.mu.
func (lb *lbBalancer) refreshSubConns(backendAddrs []resolver.Address, fromGRPCLBServer bool) bool {
opts := balancer.NewSubConnOptions{}
if fromGRPCLBServer {
opts.CredsBundle = lb.grpclbBackendCreds
}
lb.backendAddrs = nil
var backendsUpdated bool
// addrsSet is the set converted from backendAddrs; it's used for quick
// lookup of an address.
addrsSet := make(map[resolver.Address]struct{})
// Create new SubConns.
for _, addr := range backendAddrs {
addrWithoutMD := addr
addrWithoutMD.Metadata = nil
addrsSet[addrWithoutMD] = struct{}{}
lb.backendAddrs = append(lb.backendAddrs, addrWithoutMD)
if _, ok := lb.subConns[addrWithoutMD]; !ok {
backendsUpdated = true
// Use the address with metadata (addr) to create the SubConn.
sc, err := lb.cc.NewSubConn([]resolver.Address{addr}, opts)
if err != nil {
grpclog.Warningf("roundrobinBalancer: failed to create new SubConn: %v", err)
continue
}
lb.subConns[addrWithoutMD] = sc // Use the addr without MD as key for the map.
if _, ok := lb.scStates[sc]; !ok {
// Only set state of new sc to IDLE. The state could already be
// READY for cached SubConns.
lb.scStates[sc] = connectivity.Idle
}
sc.Connect()
}
}
for a, sc := range lb.subConns {
// a was removed by resolver.
if _, ok := addrsSet[a]; !ok {
backendsUpdated = true
lb.cc.RemoveSubConn(sc)
delete(lb.subConns, a)
// Keep the state of this sc in lb.scStates until sc's state becomes Shutdown.
// The entry will be deleted in HandleSubConnStateChange.
}
}
return backendsUpdated
}
func (lb *lbBalancer) readServerList(s *balanceLoadClientStream) error {
for {
reply, err := s.Recv()
if err != nil {
if err == io.EOF {
return errServerTerminatedConnection
}
return fmt.Errorf("grpclb: failed to recv server list: %v", err)
}
if serverList := reply.GetServerList(); serverList != nil {
lb.processServerList(serverList)
}
}
}
func (lb *lbBalancer) sendLoadReport(s *balanceLoadClientStream, interval time.Duration) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
case <-s.Context().Done():
return
}
stats := lb.clientStats.toClientStats()
t := time.Now()
stats.Timestamp = &timestamppb.Timestamp{
Seconds: t.Unix(),
Nanos: int32(t.Nanosecond()),
}
if err := s.Send(&lbpb.LoadBalanceRequest{
LoadBalanceRequestType: &lbpb.LoadBalanceRequest_ClientStats{
ClientStats: stats,
},
}); err != nil {
return
}
}
}
func (lb *lbBalancer) callRemoteBalancer() (backoff bool, _ error) {
lbClient := &loadBalancerClient{cc: lb.ccRemoteLB}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
stream, err := lbClient.BalanceLoad(ctx, grpc.FailFast(false))
if err != nil {
return true, fmt.Errorf("grpclb: failed to perform RPC to the remote balancer %v", err)
}
// grpclb handshake on the stream.
initReq := &lbpb.LoadBalanceRequest{
LoadBalanceRequestType: &lbpb.LoadBalanceRequest_InitialRequest{
InitialRequest: &lbpb.InitialLoadBalanceRequest{
Name: lb.target,
},
},
}
if err := stream.Send(initReq); err != nil {
return true, fmt.Errorf("grpclb: failed to send init request: %v", err)
}
reply, err := stream.Recv()
if err != nil {
return true, fmt.Errorf("grpclb: failed to recv init response: %v", err)
}
initResp := reply.GetInitialResponse()
if initResp == nil {
return true, fmt.Errorf("grpclb: reply from remote balancer did not include initial response")
}
if initResp.LoadBalancerDelegate != "" {
return true, fmt.Errorf("grpclb: Delegation is not supported")
}
go func() {
if d := convertDuration(initResp.ClientStatsReportInterval); d > 0 {
lb.sendLoadReport(stream, d)
}
}()
// No backoff if init req/resp handshake was successful.
return false, lb.readServerList(stream)
}
func (lb *lbBalancer) watchRemoteBalancer() {
var retryCount int
for {
doBackoff, err := lb.callRemoteBalancer()
select {
case <-lb.doneCh:
return
default:
if err != nil {
if err == errServerTerminatedConnection {
grpclog.Info(err)
} else {
grpclog.Warning(err)
}
}
}
if !doBackoff {
retryCount = 0
continue
}
timer := time.NewTimer(lb.backoff.Backoff(retryCount))
select {
case <-timer.C:
case <-lb.doneCh:
timer.Stop()
return
}
retryCount++
}
}
func (lb *lbBalancer) dialRemoteLB(remoteLBName string) {
var dopts []grpc.DialOption
if creds := lb.opt.DialCreds; creds != nil {
if err := creds.OverrideServerName(remoteLBName); err == nil {
dopts = append(dopts, grpc.WithTransportCredentials(creds))
} else {
grpclog.Warningf("grpclb: failed to override the server name in the credentials: %v, using Insecure", err)
dopts = append(dopts, grpc.WithInsecure())
}
} else if bundle := lb.grpclbClientConnCreds; bundle != nil {
dopts = append(dopts, grpc.WithCredentialsBundle(bundle))
} else {
dopts = append(dopts, grpc.WithInsecure())
}
if lb.opt.Dialer != nil {
// WithDialer takes a different type of function, so we instead use a
// special DialOption here.
wcd := internal.WithContextDialer.(func(func(context.Context, string) (net.Conn, error)) grpc.DialOption)
dopts = append(dopts, wcd(lb.opt.Dialer))
}
// Explicitly set pickfirst as the balancer.
dopts = append(dopts, grpc.WithBalancerName(grpc.PickFirstBalancerName))
wrb := internal.WithResolverBuilder.(func(resolver.Builder) grpc.DialOption)
dopts = append(dopts, wrb(lb.manualResolver))
if channelz.IsOn() {
dopts = append(dopts, grpc.WithChannelzParentID(lb.opt.ChannelzParentID))
}
// DialContext using manualResolver.Scheme, which is a random scheme
// generated when grpclb is initialized. The target scheme here is not
// important.
//
// The grpc dial target will be used by the creds (ALTS) as the authority,
// so it has to be set to the remoteLBName that comes from the resolver.
cc, err := grpc.DialContext(context.Background(), remoteLBName, dopts...)
if err != nil {
grpclog.Fatalf("failed to dial: %v", err)
}
lb.ccRemoteLB = cc
go lb.watchRemoteBalancer()
}

View File

@ -1,970 +0,0 @@
/*
*
* Copyright 2016 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpclb
import (
"context"
"errors"
"fmt"
"io"
"net"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
durationpb "github.com/golang/protobuf/ptypes/duration"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer"
lbgrpc "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
lbpb "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
_ "google.golang.org/grpc/grpclog/glogger"
"google.golang.org/grpc/internal/leakcheck"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
"google.golang.org/grpc/status"
testpb "google.golang.org/grpc/test/grpc_testing"
)
var (
lbServerName = "bar.com"
beServerName = "foo.com"
lbToken = "iamatoken"
// Resolver replaces localhost with fakeName in Next().
// Dialer replaces fakeName with localhost when dialing.
// This will test that custom dialer is passed from Dial to grpclb.
fakeName = "fake.Name"
)
type serverNameCheckCreds struct {
mu sync.Mutex
sn string
expected string
}
func (c *serverNameCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
if _, err := io.WriteString(rawConn, c.sn); err != nil {
fmt.Printf("Failed to write the server name %s to the client %v", c.sn, err)
return nil, nil, err
}
return rawConn, nil, nil
}
func (c *serverNameCheckCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
c.mu.Lock()
defer c.mu.Unlock()
b := make([]byte, len(c.expected))
errCh := make(chan error, 1)
go func() {
_, err := rawConn.Read(b)
errCh <- err
}()
select {
case err := <-errCh:
if err != nil {
fmt.Printf("Failed to read the server name from the server %v", err)
return nil, nil, err
}
case <-ctx.Done():
return nil, nil, ctx.Err()
}
if c.expected != string(b) {
fmt.Printf("Read the server name %s want %s", string(b), c.expected)
return nil, nil, errors.New("received unexpected server name")
}
return rawConn, nil, nil
}
func (c *serverNameCheckCreds) Info() credentials.ProtocolInfo {
c.mu.Lock()
defer c.mu.Unlock()
return credentials.ProtocolInfo{}
}
func (c *serverNameCheckCreds) Clone() credentials.TransportCredentials {
c.mu.Lock()
defer c.mu.Unlock()
return &serverNameCheckCreds{
expected: c.expected,
}
}
func (c *serverNameCheckCreds) OverrideServerName(s string) error {
c.mu.Lock()
defer c.mu.Unlock()
c.expected = s
return nil
}
// fakeNameDialer replaces fakeName with localhost when dialing.
// This will test that custom dialer is passed from Dial to grpclb.
func fakeNameDialer(addr string, timeout time.Duration) (net.Conn, error) {
addr = strings.Replace(addr, fakeName, "localhost", 1)
return net.DialTimeout("tcp", addr, timeout)
}
// merge merges the new client stats into current stats.
//
// It's a test-only method. rpcStats is defined in grpclb_picker.
func (s *rpcStats) merge(cs *lbpb.ClientStats) {
atomic.AddInt64(&s.numCallsStarted, cs.NumCallsStarted)
atomic.AddInt64(&s.numCallsFinished, cs.NumCallsFinished)
atomic.AddInt64(&s.numCallsFinishedWithClientFailedToSend, cs.NumCallsFinishedWithClientFailedToSend)
atomic.AddInt64(&s.numCallsFinishedKnownReceived, cs.NumCallsFinishedKnownReceived)
s.mu.Lock()
for _, perToken := range cs.CallsFinishedWithDrop {
s.numCallsDropped[perToken.LoadBalanceToken] += perToken.NumCalls
}
s.mu.Unlock()
}
func mapsEqual(a, b map[string]int64) bool {
if len(a) != len(b) {
return false
}
for k, v1 := range a {
if v2, ok := b[k]; !ok || v1 != v2 {
return false
}
}
return true
}
func atomicEqual(a, b *int64) bool {
return atomic.LoadInt64(a) == atomic.LoadInt64(b)
}
// equal compares two rpcStats.
//
// It's a test-only method. rpcStats is defined in grpclb_picker.
func (s *rpcStats) equal(o *rpcStats) bool {
if !atomicEqual(&s.numCallsStarted, &o.numCallsStarted) {
return false
}
if !atomicEqual(&s.numCallsFinished, &o.numCallsFinished) {
return false
}
if !atomicEqual(&s.numCallsFinishedWithClientFailedToSend, &o.numCallsFinishedWithClientFailedToSend) {
return false
}
if !atomicEqual(&s.numCallsFinishedKnownReceived, &o.numCallsFinishedKnownReceived) {
return false
}
s.mu.Lock()
defer s.mu.Unlock()
o.mu.Lock()
defer o.mu.Unlock()
if !mapsEqual(s.numCallsDropped, o.numCallsDropped) {
return false
}
return true
}
type remoteBalancer struct {
sls chan *lbpb.ServerList
statsDura time.Duration
done chan struct{}
stats *rpcStats
}
func newRemoteBalancer(intervals []time.Duration) *remoteBalancer {
return &remoteBalancer{
sls: make(chan *lbpb.ServerList, 1),
done: make(chan struct{}),
stats: newRPCStats(),
}
}
func (b *remoteBalancer) stop() {
close(b.sls)
close(b.done)
}
func (b *remoteBalancer) BalanceLoad(stream lbgrpc.LoadBalancer_BalanceLoadServer) error {
req, err := stream.Recv()
if err != nil {
return err
}
initReq := req.GetInitialRequest()
if initReq.Name != beServerName {
return status.Errorf(codes.InvalidArgument, "invalid service name: %v", initReq.Name)
}
resp := &lbpb.LoadBalanceResponse{
LoadBalanceResponseType: &lbpb.LoadBalanceResponse_InitialResponse{
InitialResponse: &lbpb.InitialLoadBalanceResponse{
ClientStatsReportInterval: &durationpb.Duration{
Seconds: int64(b.statsDura.Seconds()),
Nanos: int32(b.statsDura.Nanoseconds() - int64(b.statsDura.Seconds())*1e9),
},
},
},
}
if err := stream.Send(resp); err != nil {
return err
}
go func() {
for {
var (
req *lbpb.LoadBalanceRequest
err error
)
if req, err = stream.Recv(); err != nil {
return
}
b.stats.merge(req.GetClientStats())
}
}()
for v := range b.sls {
resp = &lbpb.LoadBalanceResponse{
LoadBalanceResponseType: &lbpb.LoadBalanceResponse_ServerList{
ServerList: v,
},
}
if err := stream.Send(resp); err != nil {
return err
}
}
<-b.done
return nil
}
type testServer struct {
testpb.TestServiceServer
addr string
fallback bool
}
const testmdkey = "testmd"
func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Error(codes.Internal, "failed to receive metadata")
}
if !s.fallback && (md == nil || md["lb-token"][0] != lbToken) {
return nil, status.Errorf(codes.Internal, "received unexpected metadata: %v", md)
}
grpc.SetTrailer(ctx, metadata.Pairs(testmdkey, s.addr))
return &testpb.Empty{}, nil
}
func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
return nil
}
func startBackends(sn string, fallback bool, lis ...net.Listener) (servers []*grpc.Server) {
for _, l := range lis {
creds := &serverNameCheckCreds{
sn: sn,
}
s := grpc.NewServer(grpc.Creds(creds))
testpb.RegisterTestServiceServer(s, &testServer{addr: l.Addr().String(), fallback: fallback})
servers = append(servers, s)
go func(s *grpc.Server, l net.Listener) {
s.Serve(l)
}(s, l)
}
return
}
func stopBackends(servers []*grpc.Server) {
for _, s := range servers {
s.Stop()
}
}
type testServers struct {
lbAddr string
ls *remoteBalancer
lb *grpc.Server
beIPs []net.IP
bePorts []int
}
func newLoadBalancer(numberOfBackends int) (tss *testServers, cleanup func(), err error) {
var (
beListeners []net.Listener
ls *remoteBalancer
lb *grpc.Server
beIPs []net.IP
bePorts []int
)
for i := 0; i < numberOfBackends; i++ {
// Start a backend.
beLis, e := net.Listen("tcp", "localhost:0")
if e != nil {
err = fmt.Errorf("Failed to listen %v", err)
return
}
beIPs = append(beIPs, beLis.Addr().(*net.TCPAddr).IP)
bePorts = append(bePorts, beLis.Addr().(*net.TCPAddr).Port)
beListeners = append(beListeners, beLis)
}
backends := startBackends(beServerName, false, beListeners...)
// Start a load balancer.
lbLis, err := net.Listen("tcp", "localhost:0")
if err != nil {
err = fmt.Errorf("Failed to create the listener for the load balancer %v", err)
return
}
lbCreds := &serverNameCheckCreds{
sn: lbServerName,
}
lb = grpc.NewServer(grpc.Creds(lbCreds))
ls = newRemoteBalancer(nil)
lbgrpc.RegisterLoadBalancerServer(lb, ls)
go func() {
lb.Serve(lbLis)
}()
tss = &testServers{
lbAddr: fakeName + ":" + strconv.Itoa(lbLis.Addr().(*net.TCPAddr).Port),
ls: ls,
lb: lb,
beIPs: beIPs,
bePorts: bePorts,
}
cleanup = func() {
defer stopBackends(backends)
defer func() {
ls.stop()
lb.Stop()
}()
}
return
}
func TestGRPCLB(t *testing.T) {
defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
tss, cleanup, err := newLoadBalancer(1)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
defer cleanup()
be := &lbpb.Server{
IpAddress: tss.beIPs[0],
Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken,
}
var bes []*lbpb.Server
bes = append(bes, be)
sl := &lbpb.ServerList{
Servers: bes,
}
tss.ls.sls <- sl
creds := serverNameCheckCreds{
expected: beServerName,
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
defer cc.Close()
testC := testpb.NewTestServiceClient(cc)
r.NewAddress([]resolver.Address{{
Addr: tss.lbAddr,
Type: resolver.GRPCLB,
ServerName: lbServerName,
}})
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
}
// The remote balancer sends a response containing duplicate entries to the grpclb client.
func TestGRPCLBWeighted(t *testing.T) {
defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
tss, cleanup, err := newLoadBalancer(2)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
defer cleanup()
beServers := []*lbpb.Server{{
IpAddress: tss.beIPs[0],
Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken,
}, {
IpAddress: tss.beIPs[1],
Port: int32(tss.bePorts[1]),
LoadBalanceToken: lbToken,
}}
portsToIndex := make(map[int]int)
for i := range beServers {
portsToIndex[tss.bePorts[i]] = i
}
creds := serverNameCheckCreds{
expected: beServerName,
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
defer cc.Close()
testC := testpb.NewTestServiceClient(cc)
r.NewAddress([]resolver.Address{{
Addr: tss.lbAddr,
Type: resolver.GRPCLB,
ServerName: lbServerName,
}})
sequences := []string{"00101", "00011"}
for _, seq := range sequences {
var (
bes []*lbpb.Server
p peer.Peer
result string
)
for _, s := range seq {
bes = append(bes, beServers[s-'0'])
}
tss.ls.sls <- &lbpb.ServerList{Servers: bes}
for i := 0; i < 1000; i++ {
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
result += strconv.Itoa(portsToIndex[p.Addr.(*net.TCPAddr).Port])
}
// The generated result will be in the format "0010100101".
if !strings.Contains(result, strings.Repeat(seq, 2)) {
t.Errorf("got result sequence %q, want pattern %q", result, seq)
}
}
}
func TestDropRequest(t *testing.T) {
defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
tss, cleanup, err := newLoadBalancer(1)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
defer cleanup()
tss.ls.sls <- &lbpb.ServerList{
Servers: []*lbpb.Server{{
IpAddress: tss.beIPs[0],
Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken,
Drop: false,
}, {
Drop: true,
}},
}
creds := serverNameCheckCreds{
expected: beServerName,
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
defer cc.Close()
testC := testpb.NewTestServiceClient(cc)
r.NewAddress([]resolver.Address{{
Addr: tss.lbAddr,
Type: resolver.GRPCLB,
ServerName: lbServerName,
}})
// Wait for the first non-fail-fast RPC to succeed. This ensures the
// connection to the real backend is up; the second entry in the server
// list is a drop-only entry (Drop set to true).
var i int
for i = 0; i < 1000; i++ {
if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil {
break
}
time.Sleep(time.Millisecond)
}
if i >= 1000 {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", testC, err)
}
select {
case <-ctx.Done():
t.Fatal("timed out", ctx.Err())
default:
}
for _, failfast := range []bool{true, false} {
for i := 0; i < 3; i++ {
// Even RPCs should fail, because the second entry in the server list
// has Drop set to true.
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(failfast)); status.Code(err) != codes.Unavailable {
t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.Unavailable)
}
// Odd RPCs should succeed since they choose the non-drop-request
// backend according to the round robin policy.
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(failfast)); err != nil {
t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
}
}
}
// When the balancer in use disconnects, grpclb should connect to the next address from the resolved balancer address list.
func TestBalancerDisconnects(t *testing.T) {
defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
var (
tests []*testServers
lbs []*grpc.Server
)
for i := 0; i < 2; i++ {
tss, cleanup, err := newLoadBalancer(1)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
defer cleanup()
be := &lbpb.Server{
IpAddress: tss.beIPs[0],
Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken,
}
var bes []*lbpb.Server
bes = append(bes, be)
sl := &lbpb.ServerList{
Servers: bes,
}
tss.ls.sls <- sl
tests = append(tests, tss)
lbs = append(lbs, tss.lb)
}
creds := serverNameCheckCreds{
expected: beServerName,
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
defer cc.Close()
testC := testpb.NewTestServiceClient(cc)
r.NewAddress([]resolver.Address{{
Addr: tests[0].lbAddr,
Type: resolver.GRPCLB,
ServerName: lbServerName,
}, {
Addr: tests[1].lbAddr,
Type: resolver.GRPCLB,
ServerName: lbServerName,
}})
var p peer.Peer
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
if p.Addr.(*net.TCPAddr).Port != tests[0].bePorts[0] {
t.Fatalf("got peer: %v, want peer port: %v", p.Addr, tests[0].bePorts[0])
}
lbs[0].Stop()
// Since balancer[0] is stopped, balancer[1] should be used by grpclb.
// Check the peer address to see if that happened.
for i := 0; i < 1000; i++ {
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
if p.Addr.(*net.TCPAddr).Port == tests[1].bePorts[0] {
return
}
time.Sleep(time.Millisecond)
}
t.Fatalf("No RPC sent to second backend after 1 second")
}
type customGRPCLBBuilder struct {
balancer.Builder
name string
}
func (b *customGRPCLBBuilder) Name() string {
return b.name
}
const grpclbCustomFallbackName = "grpclb_with_custom_fallback_timeout"
func init() {
balancer.Register(&customGRPCLBBuilder{
Builder: newLBBuilderWithFallbackTimeout(100 * time.Millisecond),
name: grpclbCustomFallbackName,
})
}
func TestFallback(t *testing.T) {
defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
tss, cleanup, err := newLoadBalancer(1)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
defer cleanup()
// Start a standalone backend.
beLis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to listen %v", err)
}
defer beLis.Close()
standaloneBEs := startBackends(beServerName, true, beLis)
defer stopBackends(standaloneBEs)
be := &lbpb.Server{
IpAddress: tss.beIPs[0],
Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken,
}
var bes []*lbpb.Server
bes = append(bes, be)
sl := &lbpb.ServerList{
Servers: bes,
}
tss.ls.sls <- sl
creds := serverNameCheckCreds{
expected: beServerName,
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
grpc.WithBalancerName(grpclbCustomFallbackName),
grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
defer cc.Close()
testC := testpb.NewTestServiceClient(cc)
r.NewAddress([]resolver.Address{{
Addr: "",
Type: resolver.GRPCLB,
ServerName: lbServerName,
}, {
Addr: beLis.Addr().String(),
Type: resolver.Backend,
ServerName: beServerName,
}})
var p peer.Peer
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
if p.Addr.String() != beLis.Addr().String() {
t.Fatalf("got peer: %v, want peer: %v", p.Addr, beLis.Addr())
}
r.NewAddress([]resolver.Address{{
Addr: tss.lbAddr,
Type: resolver.GRPCLB,
ServerName: lbServerName,
}, {
Addr: beLis.Addr().String(),
Type: resolver.Backend,
ServerName: beServerName,
}})
for i := 0; i < 1000; i++ {
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
if p.Addr.(*net.TCPAddr).Port == tss.bePorts[0] {
return
}
time.Sleep(time.Millisecond)
}
t.Fatalf("No RPC sent to backend behind remote balancer after 1 second")
}
type failPreRPCCred struct{}
func (failPreRPCCred) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
if strings.Contains(uri[0], failtosendURI) {
return nil, fmt.Errorf("rpc should fail to send")
}
return nil, nil
}
func (failPreRPCCred) RequireTransportSecurity() bool {
return false
}
func checkStats(stats, expected *rpcStats) error {
if !stats.equal(expected) {
return fmt.Errorf("stats not equal: got %+v, want %+v", stats, expected)
}
return nil
}
func runAndGetStats(t *testing.T, drop bool, runRPCs func(*grpc.ClientConn)) *rpcStats {
defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
tss, cleanup, err := newLoadBalancer(1)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
defer cleanup()
tss.ls.sls <- &lbpb.ServerList{
Servers: []*lbpb.Server{{
IpAddress: tss.beIPs[0],
Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken,
Drop: drop,
}},
}
tss.ls.statsDura = 100 * time.Millisecond
creds := serverNameCheckCreds{expected: beServerName}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
grpc.WithTransportCredentials(&creds),
grpc.WithPerRPCCredentials(failPreRPCCred{}),
grpc.WithDialer(fakeNameDialer))
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
defer cc.Close()
r.NewAddress([]resolver.Address{{
Addr: tss.lbAddr,
Type: resolver.GRPCLB,
ServerName: lbServerName,
}})
runRPCs(cc)
time.Sleep(1 * time.Second)
stats := tss.ls.stats
return stats
}
const (
countRPC = 40
failtosendURI = "failtosend"
dropErrDesc = "request dropped by grpclb"
)
func TestGRPCLBStatsUnarySuccess(t *testing.T) {
defer leakcheck.Check(t)
stats := runAndGetStats(t, false, func(cc *grpc.ClientConn) {
testC := testpb.NewTestServiceClient(cc)
// The first non-failfast RPC succeeds, all connections are up.
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
for i := 0; i < countRPC-1; i++ {
testC.EmptyCall(context.Background(), &testpb.Empty{})
}
})
if err := checkStats(stats, &rpcStats{
numCallsStarted: int64(countRPC),
numCallsFinished: int64(countRPC),
numCallsFinishedKnownReceived: int64(countRPC),
}); err != nil {
t.Fatal(err)
}
}
func TestGRPCLBStatsUnaryDrop(t *testing.T) {
defer leakcheck.Check(t)
c := 0
stats := runAndGetStats(t, true, func(cc *grpc.ClientConn) {
testC := testpb.NewTestServiceClient(cc)
for {
c++
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
if strings.Contains(err.Error(), dropErrDesc) {
break
}
}
}
for i := 0; i < countRPC; i++ {
testC.EmptyCall(context.Background(), &testpb.Empty{})
}
})
if err := checkStats(stats, &rpcStats{
numCallsStarted: int64(countRPC + c),
numCallsFinished: int64(countRPC + c),
numCallsFinishedWithClientFailedToSend: int64(c - 1),
numCallsDropped: map[string]int64{lbToken: int64(countRPC + 1)},
}); err != nil {
t.Fatal(err)
}
}
func TestGRPCLBStatsUnaryFailedToSend(t *testing.T) {
defer leakcheck.Check(t)
stats := runAndGetStats(t, false, func(cc *grpc.ClientConn) {
testC := testpb.NewTestServiceClient(cc)
// The first non-failfast RPC succeeds, all connections are up.
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
for i := 0; i < countRPC-1; i++ {
cc.Invoke(context.Background(), failtosendURI, &testpb.Empty{}, nil)
}
})
if err := checkStats(stats, &rpcStats{
numCallsStarted: int64(countRPC),
numCallsFinished: int64(countRPC),
numCallsFinishedWithClientFailedToSend: int64(countRPC - 1),
numCallsFinishedKnownReceived: 1,
}); err != nil {
t.Fatal(err)
}
}
func TestGRPCLBStatsStreamingSuccess(t *testing.T) {
defer leakcheck.Check(t)
stats := runAndGetStats(t, false, func(cc *grpc.ClientConn) {
testC := testpb.NewTestServiceClient(cc)
// The first non-failfast RPC succeeds, all connections are up.
stream, err := testC.FullDuplexCall(context.Background(), grpc.FailFast(false))
if err != nil {
t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, <nil>", testC, err)
}
for {
if _, err = stream.Recv(); err == io.EOF {
break
}
}
for i := 0; i < countRPC-1; i++ {
stream, err = testC.FullDuplexCall(context.Background())
if err == nil {
// Wait for stream to end if err is nil.
for {
if _, err = stream.Recv(); err == io.EOF {
break
}
}
}
}
})
if err := checkStats(stats, &rpcStats{
numCallsStarted: int64(countRPC),
numCallsFinished: int64(countRPC),
numCallsFinishedKnownReceived: int64(countRPC),
}); err != nil {
t.Fatal(err)
}
}
func TestGRPCLBStatsStreamingDrop(t *testing.T) {
defer leakcheck.Check(t)
c := 0
stats := runAndGetStats(t, true, func(cc *grpc.ClientConn) {
testC := testpb.NewTestServiceClient(cc)
for {
c++
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
if strings.Contains(err.Error(), dropErrDesc) {
break
}
}
}
for i := 0; i < countRPC; i++ {
testC.FullDuplexCall(context.Background())
}
})
if err := checkStats(stats, &rpcStats{
numCallsStarted: int64(countRPC + c),
numCallsFinished: int64(countRPC + c),
numCallsFinishedWithClientFailedToSend: int64(c - 1),
numCallsDropped: map[string]int64{lbToken: int64(countRPC + 1)},
}); err != nil {
t.Fatal(err)
}
}
func TestGRPCLBStatsStreamingFailedToSend(t *testing.T) {
defer leakcheck.Check(t)
stats := runAndGetStats(t, false, func(cc *grpc.ClientConn) {
testC := testpb.NewTestServiceClient(cc)
// The first non-failfast RPC succeeds, all connections are up.
stream, err := testC.FullDuplexCall(context.Background(), grpc.FailFast(false))
if err != nil {
t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, <nil>", testC, err)
}
for {
if _, err = stream.Recv(); err == io.EOF {
break
}
}
for i := 0; i < countRPC-1; i++ {
cc.NewStream(context.Background(), &grpc.StreamDesc{}, failtosendURI)
}
})
if err := checkStats(stats, &rpcStats{
numCallsStarted: int64(countRPC),
numCallsFinished: int64(countRPC),
numCallsFinishedWithClientFailedToSend: int64(countRPC - 1),
numCallsFinishedKnownReceived: 1,
}); err != nil {
t.Fatal(err)
}
}

View File

@ -1,214 +0,0 @@
/*
*
* Copyright 2016 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpclb
import (
"fmt"
"sync"
"time"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/resolver"
)
// The parent ClientConn should re-resolve when grpclb loses connection to the
// remote balancer. When the ClientConn inside grpclb gets a TransientFailure,
// it calls lbManualResolver.ResolveNow(), which calls parent ClientConn's
// ResolveNow, and eventually results in re-resolve happening in parent
// ClientConn's resolver (DNS for example).
//
// parent
// ClientConn
// +-----------------------------------------------------------------+
// | parent +---------------------------------+ |
// | DNS ClientConn | grpclb | |
// | resolver balancerWrapper | | |
// | + + | grpclb grpclb | |
// | | | | ManualResolver ClientConn | |
// | | | | + + | |
// | | | | | | Transient | |
// | | | | | | Failure | |
// | | | | | <--------- | | |
// | | | <--------------- | ResolveNow | | |
// | | <--------- | ResolveNow | | | | |
// | | ResolveNow | | | | | |
// | | | | | | | |
// | + + | + + | |
// | +---------------------------------+ |
// +-----------------------------------------------------------------+
// lbManualResolver is used by the ClientConn inside grpclb. It's a manual
// resolver with a special ResolveNow() function.
//
// When ResolveNow() is called, it calls ResolveNow() on the parent ClientConn,
// so when the grpclb client loses contact with the remote balancers, the parent
// ClientConn's resolver will re-resolve.
type lbManualResolver struct {
scheme string
ccr resolver.ClientConn
ccb balancer.ClientConn
}
func (r *lbManualResolver) Build(_ resolver.Target, cc resolver.ClientConn, _ resolver.BuildOption) (resolver.Resolver, error) {
r.ccr = cc
return r, nil
}
func (r *lbManualResolver) Scheme() string {
return r.scheme
}
// ResolveNow calls resolveNow on the parent ClientConn.
func (r *lbManualResolver) ResolveNow(o resolver.ResolveNowOption) {
r.ccb.ResolveNow(o)
}
// Close is a noop for Resolver.
func (*lbManualResolver) Close() {}
// NewAddress calls cc.NewAddress.
func (r *lbManualResolver) NewAddress(addrs []resolver.Address) {
r.ccr.NewAddress(addrs)
}
// NewServiceConfig calls cc.NewServiceConfig.
func (r *lbManualResolver) NewServiceConfig(sc string) {
r.ccr.NewServiceConfig(sc)
}
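// Illustrative sketch (assumption, not in the original file): the grpclb
// balancer is expected to construct this resolver with the parent
// balancer.ClientConn so that ResolveNow on the inner ClientConn reaches the
// parent's resolver. The scheme value below is hypothetical.
//
//	r := &lbManualResolver{scheme: "grpclb-internal", ccb: parentCC}
//	// The inner grpclb ClientConn is dialed with this resolver registered;
//	// when that connection hits TransientFailure it calls r.ResolveNow(...),
//	// which forwards to parentCC.ResolveNow(...) and eventually triggers
//	// re-resolution in the parent ClientConn's resolver (DNS, for example).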
const subConnCacheTime = time.Second * 10
// lbCacheClientConn wraps a balancer.ClientConn with a SubConn cache.
// SubConns will be kept in the cache for subConnCacheTime before being removed.
//
// Its NewSubConn and RemoveSubConn methods check the cache first.
type lbCacheClientConn struct {
cc balancer.ClientConn
timeout time.Duration
mu sync.Mutex
// subConnCache only keeps subConns that are being deleted.
subConnCache map[resolver.Address]*subConnCacheEntry
subConnToAddr map[balancer.SubConn]resolver.Address
}
type subConnCacheEntry struct {
sc balancer.SubConn
cancel func()
abortDeleting bool
}
func newLBCacheClientConn(cc balancer.ClientConn) *lbCacheClientConn {
return &lbCacheClientConn{
cc: cc,
timeout: subConnCacheTime,
subConnCache: make(map[resolver.Address]*subConnCacheEntry),
subConnToAddr: make(map[balancer.SubConn]resolver.Address),
}
}
func (ccc *lbCacheClientConn) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) {
if len(addrs) != 1 {
return nil, fmt.Errorf("grpclb calling NewSubConn with addrs of length %v", len(addrs))
}
addrWithoutMD := addrs[0]
addrWithoutMD.Metadata = nil
ccc.mu.Lock()
defer ccc.mu.Unlock()
if entry, ok := ccc.subConnCache[addrWithoutMD]; ok {
// If entry is in subConnCache, the SubConn was being deleted.
// cancel function will never be nil.
entry.cancel()
delete(ccc.subConnCache, addrWithoutMD)
return entry.sc, nil
}
scNew, err := ccc.cc.NewSubConn(addrs, opts)
if err != nil {
return nil, err
}
ccc.subConnToAddr[scNew] = addrWithoutMD
return scNew, nil
}
func (ccc *lbCacheClientConn) RemoveSubConn(sc balancer.SubConn) {
ccc.mu.Lock()
defer ccc.mu.Unlock()
addr, ok := ccc.subConnToAddr[sc]
if !ok {
return
}
if entry, ok := ccc.subConnCache[addr]; ok {
if entry.sc != sc {
// This could happen if NewSubConn was called multiple times for the
// same address, and those SubConns are all removed. We remove sc
// immediately here.
delete(ccc.subConnToAddr, sc)
ccc.cc.RemoveSubConn(sc)
}
return
}
entry := &subConnCacheEntry{
sc: sc,
}
ccc.subConnCache[addr] = entry
timer := time.AfterFunc(ccc.timeout, func() {
ccc.mu.Lock()
// Unlock on every return path; the abort path must not leave the mutex held.
defer ccc.mu.Unlock()
if entry.abortDeleting {
return
}
ccc.cc.RemoveSubConn(sc)
delete(ccc.subConnToAddr, sc)
delete(ccc.subConnCache, addr)
})
entry.cancel = func() {
if !timer.Stop() {
// If stop was not successful, the timer has fired (this can only
// happen in a race). But the deleting function is blocked on ccc.mu
// because the mutex was held by the caller of this function.
//
// Set abortDeleting to true to abort the deleting function. When
// the lock is released, the deleting function will acquire the
// lock, check the value of abortDeleting and return.
entry.abortDeleting = true
}
}
}
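// Behavior sketch (illustrative, derived from the methods above): removal is
// deferred, and re-adding the same address within the timeout reuses the
// cached SubConn. The address below is hypothetical.
//
//	ccc := newLBCacheClientConn(cc)
//	sc, _ := ccc.NewSubConn([]resolver.Address{{Addr: "10.0.0.1:80"}}, balancer.NewSubConnOptions{})
//	ccc.RemoveSubConn(sc) // kept in subConnCache for ccc.timeout
//	sc2, _ := ccc.NewSubConn([]resolver.Address{{Addr: "10.0.0.1:80"}}, balancer.NewSubConnOptions{})
//	// sc2 == sc: the pending removal was cancelled and the SubConn reused.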
func (ccc *lbCacheClientConn) UpdateBalancerState(s connectivity.State, p balancer.Picker) {
ccc.cc.UpdateBalancerState(s, p)
}
func (ccc *lbCacheClientConn) close() {
ccc.mu.Lock()
// Only cancel all existing timers. There's no need to remove SubConns.
for _, entry := range ccc.subConnCache {
entry.cancel()
}
ccc.mu.Unlock()
}

View File

@ -1,219 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpclb
import (
"fmt"
"sync"
"testing"
"time"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/resolver"
)
type mockSubConn struct {
balancer.SubConn
}
type mockClientConn struct {
balancer.ClientConn
mu sync.Mutex
subConns map[balancer.SubConn]resolver.Address
}
func newMockClientConn() *mockClientConn {
return &mockClientConn{
subConns: make(map[balancer.SubConn]resolver.Address),
}
}
func (mcc *mockClientConn) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) {
sc := &mockSubConn{}
mcc.mu.Lock()
defer mcc.mu.Unlock()
mcc.subConns[sc] = addrs[0]
return sc, nil
}
func (mcc *mockClientConn) RemoveSubConn(sc balancer.SubConn) {
mcc.mu.Lock()
defer mcc.mu.Unlock()
delete(mcc.subConns, sc)
}
const testCacheTimeout = 100 * time.Millisecond
func checkMockCC(mcc *mockClientConn, scLen int) error {
mcc.mu.Lock()
defer mcc.mu.Unlock()
if len(mcc.subConns) != scLen {
return fmt.Errorf("mcc = %+v, want len(mcc.subConns) = %v", mcc.subConns, scLen)
}
return nil
}
func checkCacheCC(ccc *lbCacheClientConn, sccLen, sctaLen int) error {
ccc.mu.Lock()
defer ccc.mu.Unlock()
if len(ccc.subConnCache) != sccLen {
return fmt.Errorf("ccc = %+v, want len(ccc.subConnCache) = %v", ccc.subConnCache, sccLen)
}
if len(ccc.subConnToAddr) != sctaLen {
return fmt.Errorf("ccc = %+v, want len(ccc.subConnToAddr) = %v", ccc.subConnToAddr, sctaLen)
}
return nil
}
// Test that SubConn won't be immediately removed.
func TestLBCacheClientConnExpire(t *testing.T) {
mcc := newMockClientConn()
if err := checkMockCC(mcc, 0); err != nil {
t.Fatal(err)
}
ccc := newLBCacheClientConn(mcc)
ccc.timeout = testCacheTimeout
if err := checkCacheCC(ccc, 0, 0); err != nil {
t.Fatal(err)
}
sc, _ := ccc.NewSubConn([]resolver.Address{{Addr: "address1"}}, balancer.NewSubConnOptions{})
// One subconn in MockCC.
if err := checkMockCC(mcc, 1); err != nil {
t.Fatal(err)
}
// No subconn being deleted, and one in CacheCC.
if err := checkCacheCC(ccc, 0, 1); err != nil {
t.Fatal(err)
}
ccc.RemoveSubConn(sc)
// One subconn in MockCC before timeout.
if err := checkMockCC(mcc, 1); err != nil {
t.Fatal(err)
}
// One subconn being deleted, and one in CacheCC.
if err := checkCacheCC(ccc, 1, 1); err != nil {
t.Fatal(err)
}
// Should all become empty after timeout.
var err error
for i := 0; i < 2; i++ {
time.Sleep(testCacheTimeout)
err = checkMockCC(mcc, 0)
if err != nil {
continue
}
err = checkCacheCC(ccc, 0, 0)
if err != nil {
continue
}
}
if err != nil {
t.Fatal(err)
}
}
// Test that NewSubConn with the same address as a SubConn being removed will
// reuse that SubConn and cancel its removal.
func TestLBCacheClientConnReuse(t *testing.T) {
mcc := newMockClientConn()
if err := checkMockCC(mcc, 0); err != nil {
t.Fatal(err)
}
ccc := newLBCacheClientConn(mcc)
ccc.timeout = testCacheTimeout
if err := checkCacheCC(ccc, 0, 0); err != nil {
t.Fatal(err)
}
sc, _ := ccc.NewSubConn([]resolver.Address{{Addr: "address1"}}, balancer.NewSubConnOptions{})
// One subconn in MockCC.
if err := checkMockCC(mcc, 1); err != nil {
t.Fatal(err)
}
// No subconn being deleted, and one in CacheCC.
if err := checkCacheCC(ccc, 0, 1); err != nil {
t.Fatal(err)
}
ccc.RemoveSubConn(sc)
// One subconn in MockCC before timeout.
if err := checkMockCC(mcc, 1); err != nil {
t.Fatal(err)
}
// One subconn being deleted, and one in CacheCC.
if err := checkCacheCC(ccc, 1, 1); err != nil {
t.Fatal(err)
}
// Recreate the old subconn, this should cancel the deleting process.
sc, _ = ccc.NewSubConn([]resolver.Address{{Addr: "address1"}}, balancer.NewSubConnOptions{})
// One subconn in MockCC.
if err := checkMockCC(mcc, 1); err != nil {
t.Fatal(err)
}
// No subconn being deleted, and one in CacheCC.
if err := checkCacheCC(ccc, 0, 1); err != nil {
t.Fatal(err)
}
var err error
// Should not become empty after 2*timeout.
time.Sleep(2 * testCacheTimeout)
err = checkMockCC(mcc, 1)
if err != nil {
t.Fatal(err)
}
err = checkCacheCC(ccc, 0, 1)
if err != nil {
t.Fatal(err)
}
// Call remove again, will delete after timeout.
ccc.RemoveSubConn(sc)
// One subconn in MockCC before timeout.
if err := checkMockCC(mcc, 1); err != nil {
t.Fatal(err)
}
// One subconn being deleted, and one in CacheCC.
if err := checkCacheCC(ccc, 1, 1); err != nil {
t.Fatal(err)
}
// Should all become empty after timeout.
for i := 0; i < 2; i++ {
time.Sleep(testCacheTimeout)
err = checkMockCC(mcc, 0)
if err != nil {
continue
}
err = checkCacheCC(ccc, 0, 0)
if err != nil {
continue
}
}
if err != nil {
t.Fatal(err)
}
}

View File

@ -1,33 +0,0 @@
#!/bin/bash
# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -eux -o pipefail
TMP=$(mktemp -d)
function finish {
rm -rf "$TMP"
}
trap finish EXIT
pushd "$TMP"
mkdir -p grpc/lb/v1
curl https://raw.githubusercontent.com/grpc/grpc-proto/master/grpc/lb/v1/load_balancer.proto > grpc/lb/v1/load_balancer.proto
protoc --go_out=plugins=grpc,paths=source_relative:. -I. grpc/lb/v1/*.proto
popd
rm -f grpc_lb_v1/*.pb.go
cp "$TMP"/grpc/lb/v1/*.pb.go grpc_lb_v1/

View File

@ -1,477 +0,0 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package roundrobin_test
import (
"context"
"fmt"
"net"
"sync"
"testing"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer/roundrobin"
"google.golang.org/grpc/codes"
_ "google.golang.org/grpc/grpclog/glogger"
"google.golang.org/grpc/internal/leakcheck"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
"google.golang.org/grpc/status"
testpb "google.golang.org/grpc/test/grpc_testing"
)
type testServer struct {
testpb.TestServiceServer
}
func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
return &testpb.Empty{}, nil
}
func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
return nil
}
type test struct {
servers []*grpc.Server
addresses []string
}
func (t *test) cleanup() {
for _, s := range t.servers {
s.Stop()
}
}
func startTestServers(count int) (_ *test, err error) {
t := &test{}
defer func() {
if err != nil {
for _, s := range t.servers {
s.Stop()
}
}
}()
for i := 0; i < count; i++ {
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, fmt.Errorf("Failed to listen %v", err)
}
s := grpc.NewServer()
testpb.RegisterTestServiceServer(s, &testServer{})
t.servers = append(t.servers, s)
t.addresses = append(t.addresses, lis.Addr().String())
go func(s *grpc.Server, l net.Listener) {
s.Serve(l)
}(s, lis)
}
return t, nil
}
func TestOneBackend(t *testing.T) {
defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
test, err := startTestServers(1)
if err != nil {
t.Fatalf("failed to start servers: %v", err)
}
defer test.cleanup()
cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName(roundrobin.Name))
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer cc.Close()
testc := testpb.NewTestServiceClient(cc)
// The first RPC should fail because there's no address.
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
}
r.NewAddress([]resolver.Address{{Addr: test.addresses[0]}})
// The second RPC should succeed.
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
}
}
func TestBackendsRoundRobin(t *testing.T) {
defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
backendCount := 5
test, err := startTestServers(backendCount)
if err != nil {
t.Fatalf("failed to start servers: %v", err)
}
defer test.cleanup()
cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName(roundrobin.Name))
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer cc.Close()
testc := testpb.NewTestServiceClient(cc)
// The first RPC should fail because there's no address.
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
}
var resolvedAddrs []resolver.Address
for i := 0; i < backendCount; i++ {
resolvedAddrs = append(resolvedAddrs, resolver.Address{Addr: test.addresses[i]})
}
r.NewAddress(resolvedAddrs)
var p peer.Peer
// Make sure connections to all servers are up.
for si := 0; si < backendCount; si++ {
var connected bool
for i := 0; i < 1000; i++ {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
}
if p.Addr.String() == test.addresses[si] {
connected = true
break
}
time.Sleep(time.Millisecond)
}
if !connected {
t.Fatalf("Connection to %v was not up after more than 1 second", test.addresses[si])
}
}
for i := 0; i < 3*backendCount; i++ {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
}
if p.Addr.String() != test.addresses[i%backendCount] {
t.Fatalf("Index %d: want peer %v, got peer %v", i, test.addresses[i%backendCount], p.Addr.String())
}
}
}
func TestAddressesRemoved(t *testing.T) {
defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
test, err := startTestServers(1)
if err != nil {
t.Fatalf("failed to start servers: %v", err)
}
defer test.cleanup()
cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName(roundrobin.Name))
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer cc.Close()
testc := testpb.NewTestServiceClient(cc)
// The first RPC should fail because there's no address.
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
}
r.NewAddress([]resolver.Address{{Addr: test.addresses[0]}})
// The second RPC should succeed.
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
}
r.NewAddress([]resolver.Address{})
for i := 0; i < 1000; i++ {
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); status.Code(err) == codes.DeadlineExceeded {
return
}
time.Sleep(time.Millisecond)
}
t.Fatalf("No RPC failed after removing all addresses, want RPC to fail with DeadlineExceeded")
}
func TestCloseWithPendingRPC(t *testing.T) {
defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
test, err := startTestServers(1)
if err != nil {
t.Fatalf("failed to start servers: %v", err)
}
defer test.cleanup()
cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName(roundrobin.Name))
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
testc := testpb.NewTestServiceClient(cc)
var wg sync.WaitGroup
for i := 0; i < 3; i++ {
wg.Add(1)
go func() {
defer wg.Done()
// This RPC blocks until cc is closed.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) == codes.DeadlineExceeded {
t.Errorf("RPC failed because of deadline after cc is closed; want error the client connection is closing")
}
cancel()
}()
}
cc.Close()
wg.Wait()
}
func TestNewAddressWhileBlocking(t *testing.T) {
defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
test, err := startTestServers(1)
if err != nil {
t.Fatalf("failed to start servers: %v", err)
}
defer test.cleanup()
cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName(roundrobin.Name))
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer cc.Close()
testc := testpb.NewTestServiceClient(cc)
// The first RPC should fail because there's no address.
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
}
r.NewAddress([]resolver.Address{{Addr: test.addresses[0]}})
// The second RPC should succeed.
ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, nil", err)
}
r.NewAddress([]resolver.Address{})
var wg sync.WaitGroup
for i := 0; i < 3; i++ {
wg.Add(1)
go func() {
defer wg.Done()
// This RPC blocks until NewAddress is called.
testc.EmptyCall(context.Background(), &testpb.Empty{})
}()
}
time.Sleep(50 * time.Millisecond)
r.NewAddress([]resolver.Address{{Addr: test.addresses[0]}})
wg.Wait()
}
func TestOneServerDown(t *testing.T) {
defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
backendCount := 3
test, err := startTestServers(backendCount)
if err != nil {
t.Fatalf("failed to start servers: %v", err)
}
defer test.cleanup()
cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName(roundrobin.Name), grpc.WithWaitForHandshake())
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer cc.Close()
testc := testpb.NewTestServiceClient(cc)
// The first RPC should fail because there's no address.
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
}
var resolvedAddrs []resolver.Address
for i := 0; i < backendCount; i++ {
resolvedAddrs = append(resolvedAddrs, resolver.Address{Addr: test.addresses[i]})
}
r.NewAddress(resolvedAddrs)
var p peer.Peer
// Make sure connections to all servers are up.
for si := 0; si < backendCount; si++ {
var connected bool
for i := 0; i < 1000; i++ {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
}
if p.Addr.String() == test.addresses[si] {
connected = true
break
}
time.Sleep(time.Millisecond)
}
if !connected {
t.Fatalf("Connection to %v was not up after more than 1 second", test.addresses[si])
}
}
for i := 0; i < 3*backendCount; i++ {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
}
if p.Addr.String() != test.addresses[i%backendCount] {
t.Fatalf("Index %d: want peer %v, got peer %v", i, test.addresses[i%backendCount], p.Addr.String())
}
}
// Stop one server, RPCs should roundrobin among the remaining servers.
backendCount--
test.servers[backendCount].Stop()
// Loop until we see server[backendCount-1] twice without seeing server[backendCount].
var targetSeen int
for i := 0; i < 1000; i++ {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
targetSeen = 0
t.Logf("EmptyCall() = _, %v, want _, <nil>", err)
// Due to a race, this RPC could possibly get the connection that
// was closing, and this RPC may fail. Keep trying when this
// happens.
continue
}
switch p.Addr.String() {
case test.addresses[backendCount-1]:
targetSeen++
case test.addresses[backendCount]:
// Reset targetSeen if peer is server[backendCount].
targetSeen = 0
}
// Break to make sure the last picked address is servers[backendCount-1], so the following for loop won't be flaky.
if targetSeen >= 2 {
break
}
}
if targetSeen != 2 {
t.Fatal("Failed to see server[backendCount-1] twice without seeing server[backendCount]")
}
for i := 0; i < 3*backendCount; i++ {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
}
if p.Addr.String() != test.addresses[i%backendCount] {
t.Errorf("Index %d: want peer %v, got peer %v", i, test.addresses[i%backendCount], p.Addr.String())
}
}
}
func TestAllServersDown(t *testing.T) {
defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
backendCount := 3
test, err := startTestServers(backendCount)
if err != nil {
t.Fatalf("failed to start servers: %v", err)
}
defer test.cleanup()
cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName(roundrobin.Name), grpc.WithWaitForHandshake())
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer cc.Close()
testc := testpb.NewTestServiceClient(cc)
// The first RPC should fail because there's no address.
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
}
var resolvedAddrs []resolver.Address
for i := 0; i < backendCount; i++ {
resolvedAddrs = append(resolvedAddrs, resolver.Address{Addr: test.addresses[i]})
}
r.NewAddress(resolvedAddrs)
var p peer.Peer
// Make sure connections to all servers are up.
for si := 0; si < backendCount; si++ {
var connected bool
for i := 0; i < 1000; i++ {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
}
if p.Addr.String() == test.addresses[si] {
connected = true
break
}
time.Sleep(time.Millisecond)
}
if !connected {
t.Fatalf("Connection to %v was not up after more than 1 second", test.addresses[si])
}
}
for i := 0; i < 3*backendCount; i++ {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
}
if p.Addr.String() != test.addresses[i%backendCount] {
t.Fatalf("Index %d: want peer %v, got peer %v", i, test.addresses[i%backendCount], p.Addr.String())
}
}
// All servers are stopped, failfast RPC should fail with unavailable.
for i := 0; i < backendCount; i++ {
test.servers[i].Stop()
}
time.Sleep(100 * time.Millisecond)
for i := 0; i < 1000; i++ {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}); status.Code(err) == codes.Unavailable {
return
}
time.Sleep(time.Millisecond)
}
t.Fatalf("Failfast RPCs didn't fail with Unavailable after all servers are stopped")
}

View File

@ -1,469 +0,0 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"context"
"fmt"
"math"
"testing"
"time"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/roundrobin"
"google.golang.org/grpc/connectivity"
_ "google.golang.org/grpc/grpclog/glogger"
"google.golang.org/grpc/internal/leakcheck"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
)
var _ balancer.Builder = &magicalLB{}
var _ balancer.Balancer = &magicalLB{}
// magicalLB is a ringer for grpclb. It is used to avoid circular dependencies on the grpclb package
type magicalLB struct{}
func (b *magicalLB) Name() string {
return "grpclb"
}
func (b *magicalLB) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
return b
}
func (b *magicalLB) HandleSubConnStateChange(balancer.SubConn, connectivity.State) {}
func (b *magicalLB) HandleResolvedAddrs([]resolver.Address, error) {}
func (b *magicalLB) Close() {}
func init() {
balancer.Register(&magicalLB{})
}
func checkPickFirst(cc *ClientConn, servers []*server) error {
var (
req = "port"
reply string
err error
)
connected := false
for i := 0; i < 5000; i++ {
if err = cc.Invoke(context.Background(), "/foo/bar", &req, &reply); errorDesc(err) == servers[0].port {
if connected {
// connected is set to false if peer is not server[0]. So if
// connected is true here, this is the second time we saw
// server[0] in a row. Break because pickfirst is in effect.
break
}
connected = true
} else {
connected = false
}
time.Sleep(time.Millisecond)
}
if !connected {
return fmt.Errorf("pickfirst is not in effect after 5 second, EmptyCall() = _, %v, want _, %v", err, servers[0].port)
}
// The following RPCs should all succeed with the first server.
for i := 0; i < 3; i++ {
err = cc.Invoke(context.Background(), "/foo/bar", &req, &reply)
if errorDesc(err) != servers[0].port {
return fmt.Errorf("Index %d: want peer %v, got peer %v", i, servers[0].port, err)
}
}
return nil
}
func checkRoundRobin(cc *ClientConn, servers []*server) error {
var (
req = "port"
reply string
err error
)
// Make sure connections to all servers are up.
for i := 0; i < 2; i++ {
// Do this check twice, otherwise the first RPC's transport may still be
// picked by the closing pickfirst balancer, and the test becomes flaky.
for _, s := range servers {
var up bool
for i := 0; i < 5000; i++ {
if err = cc.Invoke(context.Background(), "/foo/bar", &req, &reply); errorDesc(err) == s.port {
up = true
break
}
time.Sleep(time.Millisecond)
}
if !up {
return fmt.Errorf("server %v is not up within 5 second", s.port)
}
}
}
serverCount := len(servers)
for i := 0; i < 3*serverCount; i++ {
err = cc.Invoke(context.Background(), "/foo/bar", &req, &reply)
if errorDesc(err) != servers[i%serverCount].port {
return fmt.Errorf("Index %d: want peer %v, got peer %v", i, servers[i%serverCount].port, err)
}
}
return nil
}
func TestSwitchBalancer(t *testing.T) {
defer leakcheck.Check(t)
r, rcleanup := manual.GenerateAndRegisterManualResolver()
defer rcleanup()
numServers := 2
servers, _, scleanup := startServers(t, numServers, math.MaxInt32)
defer scleanup()
cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer cc.Close()
r.NewAddress([]resolver.Address{{Addr: servers[0].addr}, {Addr: servers[1].addr}})
// The default balancer is pickfirst.
if err := checkPickFirst(cc, servers); err != nil {
t.Fatalf("check pickfirst returned non-nil error: %v", err)
}
// Switch to roundrobin.
cc.handleServiceConfig(`{"loadBalancingPolicy": "round_robin"}`)
if err := checkRoundRobin(cc, servers); err != nil {
t.Fatalf("check roundrobin returned non-nil error: %v", err)
}
// Switch to pickfirst.
cc.handleServiceConfig(`{"loadBalancingPolicy": "pick_first"}`)
if err := checkPickFirst(cc, servers); err != nil {
t.Fatalf("check pickfirst returned non-nil error: %v", err)
}
}
// Test that balancer specified by dial option will not be overridden.
func TestBalancerDialOption(t *testing.T) {
defer leakcheck.Check(t)
r, rcleanup := manual.GenerateAndRegisterManualResolver()
defer rcleanup()
numServers := 2
servers, _, scleanup := startServers(t, numServers, math.MaxInt32)
defer scleanup()
cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithCodec(testCodec{}), WithBalancerName(roundrobin.Name))
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer cc.Close()
r.NewAddress([]resolver.Address{{Addr: servers[0].addr}, {Addr: servers[1].addr}})
// The init balancer is roundrobin.
if err := checkRoundRobin(cc, servers); err != nil {
t.Fatalf("check roundrobin returned non-nil error: %v", err)
}
// Switch to pickfirst.
cc.handleServiceConfig(`{"loadBalancingPolicy": "pick_first"}`)
// Balancer is still roundrobin.
if err := checkRoundRobin(cc, servers); err != nil {
t.Fatalf("check roundrobin returned non-nil error: %v", err)
}
}
// First addr update contains grpclb.
func TestSwitchBalancerGRPCLBFirst(t *testing.T) {
defer leakcheck.Check(t)
r, rcleanup := manual.GenerateAndRegisterManualResolver()
defer rcleanup()
cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer cc.Close()
// ClientConn will switch the balancer to grpclb when it receives an address
// of type GRPCLB.
r.NewAddress([]resolver.Address{{Addr: "backend"}, {Addr: "grpclb", Type: resolver.GRPCLB}})
var isGRPCLB bool
for i := 0; i < 5000; i++ {
cc.mu.Lock()
isGRPCLB = cc.curBalancerName == "grpclb"
cc.mu.Unlock()
if isGRPCLB {
break
}
time.Sleep(time.Millisecond)
}
if !isGRPCLB {
t.Fatalf("after 5 second, cc.balancer is of type %v, not grpclb", cc.curBalancerName)
}
// New update containing new backend and new grpclb. Should not switch
// balancer.
r.NewAddress([]resolver.Address{{Addr: "backend2"}, {Addr: "grpclb2", Type: resolver.GRPCLB}})
for i := 0; i < 200; i++ {
cc.mu.Lock()
isGRPCLB = cc.curBalancerName == "grpclb"
cc.mu.Unlock()
if !isGRPCLB {
break
}
time.Sleep(time.Millisecond)
}
if !isGRPCLB {
t.Fatalf("within 200 ms, cc.balancer switched to !grpclb, want grpclb")
}
var isPickFirst bool
// Switch balancer to pickfirst.
r.NewAddress([]resolver.Address{{Addr: "backend"}})
for i := 0; i < 5000; i++ {
cc.mu.Lock()
isPickFirst = cc.curBalancerName == PickFirstBalancerName
cc.mu.Unlock()
if isPickFirst {
break
}
time.Sleep(time.Millisecond)
}
if !isPickFirst {
t.Fatalf("after 5 second, cc.balancer is of type %v, not pick_first", cc.curBalancerName)
}
}
// First addr update does not contain grpclb.
func TestSwitchBalancerGRPCLBSecond(t *testing.T) {
defer leakcheck.Check(t)
r, rcleanup := manual.GenerateAndRegisterManualResolver()
defer rcleanup()
cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer cc.Close()
r.NewAddress([]resolver.Address{{Addr: "backend"}})
var isPickFirst bool
for i := 0; i < 5000; i++ {
cc.mu.Lock()
isPickFirst = cc.curBalancerName == PickFirstBalancerName
cc.mu.Unlock()
if isPickFirst {
break
}
time.Sleep(time.Millisecond)
}
if !isPickFirst {
t.Fatalf("after 5 second, cc.balancer is of type %v, not pick_first", cc.curBalancerName)
}
// ClientConn will switch the balancer to grpclb when it receives an address
// of type GRPCLB.
r.NewAddress([]resolver.Address{{Addr: "backend"}, {Addr: "grpclb", Type: resolver.GRPCLB}})
var isGRPCLB bool
for i := 0; i < 5000; i++ {
cc.mu.Lock()
isGRPCLB = cc.curBalancerName == "grpclb"
cc.mu.Unlock()
if isGRPCLB {
break
}
time.Sleep(time.Millisecond)
}
if !isGRPCLB {
t.Fatalf("after 5 second, cc.balancer is of type %v, not grpclb", cc.curBalancerName)
}
// New update containing new backend and new grpclb. Should not switch
// balancer.
r.NewAddress([]resolver.Address{{Addr: "backend2"}, {Addr: "grpclb2", Type: resolver.GRPCLB}})
for i := 0; i < 200; i++ {
cc.mu.Lock()
isGRPCLB = cc.curBalancerName == "grpclb"
cc.mu.Unlock()
if !isGRPCLB {
break
}
time.Sleep(time.Millisecond)
}
if !isGRPCLB {
t.Fatalf("within 200 ms, cc.balancer switched to !grpclb, want grpclb")
}
// Switch balancer back.
r.NewAddress([]resolver.Address{{Addr: "backend"}})
for i := 0; i < 5000; i++ {
cc.mu.Lock()
isPickFirst = cc.curBalancerName == PickFirstBalancerName
cc.mu.Unlock()
if isPickFirst {
break
}
time.Sleep(time.Millisecond)
}
if !isPickFirst {
t.Fatalf("after 5 second, cc.balancer is of type %v, not pick_first", cc.curBalancerName)
}
}
// Test that if the current balancer is roundrobin, then after switching to
// grpclb, the balancer is switched back to roundrobin once the resolved
// address list no longer contains grpclb addresses.
func TestSwitchBalancerGRPCLBRoundRobin(t *testing.T) {
defer leakcheck.Check(t)
r, rcleanup := manual.GenerateAndRegisterManualResolver()
defer rcleanup()
cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer cc.Close()
r.NewServiceConfig(`{"loadBalancingPolicy": "round_robin"}`)
r.NewAddress([]resolver.Address{{Addr: "backend"}})
var isRoundRobin bool
for i := 0; i < 5000; i++ {
cc.mu.Lock()
isRoundRobin = cc.curBalancerName == "round_robin"
cc.mu.Unlock()
if isRoundRobin {
break
}
time.Sleep(time.Millisecond)
}
if !isRoundRobin {
t.Fatalf("after 5 second, cc.balancer is of type %v, not round_robin", cc.curBalancerName)
}
// ClientConn will switch the balancer to grpclb when it receives an address
// of type GRPCLB.
r.NewAddress([]resolver.Address{{Addr: "grpclb", Type: resolver.GRPCLB}})
var isGRPCLB bool
for i := 0; i < 5000; i++ {
cc.mu.Lock()
isGRPCLB = cc.curBalancerName == "grpclb"
cc.mu.Unlock()
if isGRPCLB {
break
}
time.Sleep(time.Millisecond)
}
if !isGRPCLB {
t.Fatalf("after 5 second, cc.balancer is of type %v, not grpclb", cc.curBalancerName)
}
// Switch balancer back.
r.NewAddress([]resolver.Address{{Addr: "backend"}})
for i := 0; i < 5000; i++ {
cc.mu.Lock()
isRoundRobin = cc.curBalancerName == "round_robin"
cc.mu.Unlock()
if isRoundRobin {
break
}
time.Sleep(time.Millisecond)
}
if !isRoundRobin {
t.Fatalf("after 5 second, cc.balancer is of type %v, not round_robin", cc.curBalancerName)
}
}
// Test that if the resolved address list contains grpclb, the balancer option
// in the service config does not take effect. But when a new resolved address
// list contains no grpclb address, the balancer is switched to the new one.
func TestSwitchBalancerGRPCLBServiceConfig(t *testing.T) {
defer leakcheck.Check(t)
r, rcleanup := manual.GenerateAndRegisterManualResolver()
defer rcleanup()
cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer cc.Close()
r.NewAddress([]resolver.Address{{Addr: "backend"}})
var isPickFirst bool
for i := 0; i < 5000; i++ {
cc.mu.Lock()
isPickFirst = cc.curBalancerName == PickFirstBalancerName
cc.mu.Unlock()
if isPickFirst {
break
}
time.Sleep(time.Millisecond)
}
if !isPickFirst {
t.Fatalf("after 5 second, cc.balancer is of type %v, not pick_first", cc.curBalancerName)
}
// ClientConn will switch the balancer to grpclb when it receives an address
// of type GRPCLB.
r.NewAddress([]resolver.Address{{Addr: "grpclb", Type: resolver.GRPCLB}})
var isGRPCLB bool
for i := 0; i < 5000; i++ {
cc.mu.Lock()
isGRPCLB = cc.curBalancerName == "grpclb"
cc.mu.Unlock()
if isGRPCLB {
break
}
time.Sleep(time.Millisecond)
}
if !isGRPCLB {
t.Fatalf("after 5 second, cc.balancer is of type %v, not grpclb", cc.curBalancerName)
}
r.NewServiceConfig(`{"loadBalancingPolicy": "round_robin"}`)
var isRoundRobin bool
for i := 0; i < 200; i++ {
cc.mu.Lock()
isRoundRobin = cc.curBalancerName == "round_robin"
cc.mu.Unlock()
if isRoundRobin {
break
}
time.Sleep(time.Millisecond)
}
// Balancer should NOT switch to round_robin because resolved list contains
// grpclb.
if isRoundRobin {
t.Fatalf("within 200 ms, cc.balancer switched to round_robin, want grpclb")
}
// Switch balancer back.
r.NewAddress([]resolver.Address{{Addr: "backend"}})
for i := 0; i < 5000; i++ {
cc.mu.Lock()
isRoundRobin = cc.curBalancerName == "round_robin"
cc.mu.Unlock()
if isRoundRobin {
break
}
time.Sleep(time.Millisecond)
}
if !isRoundRobin {
t.Fatalf("after 5 second, cc.balancer is of type %v, not round_robin", cc.curBalancerName)
}
}

View File

@ -1,808 +0,0 @@
/*
*
* Copyright 2016 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"context"
"fmt"
"math"
"strconv"
"sync"
"testing"
"time"
"google.golang.org/grpc/codes"
_ "google.golang.org/grpc/grpclog/glogger"
"google.golang.org/grpc/internal/leakcheck"
"google.golang.org/grpc/naming"
"google.golang.org/grpc/status"
// V1 balancer tests use passthrough resolver instead of dns.
// TODO(bar) remove this when removing the v1 balancer entirely.
_ "google.golang.org/grpc/resolver/passthrough"
)
func pickFirstBalancerV1(r naming.Resolver) Balancer {
return &pickFirst{&roundRobin{r: r}}
}
type testWatcher struct {
// the channel to receive name resolution updates
update chan *naming.Update
// the side channel used to learn how many updates are in a batch
side chan int
// the channel to notify the update injector that reading the updates is done
readDone chan int
}
func (w *testWatcher) Next() (updates []*naming.Update, err error) {
n := <-w.side
if n == 0 {
return nil, fmt.Errorf("w.side is closed")
}
for i := 0; i < n; i++ {
u := <-w.update
if u != nil {
updates = append(updates, u)
}
}
w.readDone <- 0
return
}
func (w *testWatcher) Close() {
close(w.side)
}
// Inject naming resolution updates to the testWatcher.
func (w *testWatcher) inject(updates []*naming.Update) {
w.side <- len(updates)
for _, u := range updates {
w.update <- u
}
<-w.readDone
}
type testNameResolver struct {
w *testWatcher
addr string
}
func (r *testNameResolver) Resolve(target string) (naming.Watcher, error) {
r.w = &testWatcher{
update: make(chan *naming.Update, 1),
side: make(chan int, 1),
readDone: make(chan int),
}
r.w.side <- 1
r.w.update <- &naming.Update{
Op: naming.Add,
Addr: r.addr,
}
go func() {
<-r.w.readDone
}()
return r.w, nil
}
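The fake resolver above hands the balancer one pre-buffered update when Resolve is called (Dial does this internally) and then blocks in inject until each batch has been read. A hedged sketch of pushing a follow-up batch, with made-up addresses:

// exampleInject swaps one hypothetical backend for another through the fake
// resolver. Illustrative only; inject blocks until the balancer consumes the batch.
func exampleInject(r *testNameResolver) {
	r.w.inject([]*naming.Update{
		{Op: naming.Delete, Addr: "localhost:1234"}, // hypothetical old address
		{Op: naming.Add, Addr: "localhost:5678"},    // hypothetical new address
	})
}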
func startServers(t *testing.T, numServers int, maxStreams uint32) ([]*server, *testNameResolver, func()) {
var servers []*server
for i := 0; i < numServers; i++ {
s := newTestServer()
servers = append(servers, s)
go s.start(t, 0, maxStreams)
s.wait(t, 2*time.Second)
}
// Point to server[0]
addr := "localhost:" + servers[0].port
return servers, &testNameResolver{
addr: addr,
}, func() {
for i := 0; i < numServers; i++ {
servers[i].stop()
}
}
}
func TestNameDiscovery(t *testing.T) {
defer leakcheck.Check(t)
// Start 2 servers on 2 ports.
numServers := 2
servers, r, cleanup := startServers(t, numServers, math.MaxUint32)
defer cleanup()
cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
defer cc.Close()
req := "port"
var reply string
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[0].port {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port)
}
// Inject the name resolution change to remove servers[0] and add servers[1].
var updates []*naming.Update
updates = append(updates, &naming.Update{
Op: naming.Delete,
Addr: "localhost:" + servers[0].port,
})
updates = append(updates, &naming.Update{
Op: naming.Add,
Addr: "localhost:" + servers[1].port,
})
r.w.inject(updates)
// Loop until the RPCs in flight talk to servers[1].
for {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[1].port {
break
}
time.Sleep(10 * time.Millisecond)
}
}
func TestEmptyAddrs(t *testing.T) {
defer leakcheck.Check(t)
servers, r, cleanup := startServers(t, 1, math.MaxUint32)
defer cleanup()
cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
defer cc.Close()
var reply string
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, reply = %q, want %q, <nil>", err, reply, expectedResponse)
}
// Inject name resolution change to remove the server so that there is no address
// available after that.
u := &naming.Update{
Op: naming.Delete,
Addr: "localhost:" + servers[0].port,
}
r.w.inject([]*naming.Update{u})
// Loop until the above updates apply.
for {
time.Sleep(10 * time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
if err := cc.Invoke(ctx, "/foo/bar", &expectedRequest, &reply); err != nil {
cancel()
break
}
cancel()
}
}
func TestRoundRobin(t *testing.T) {
defer leakcheck.Check(t)
// Start 3 servers on 3 ports.
numServers := 3
servers, r, cleanup := startServers(t, numServers, math.MaxUint32)
defer cleanup()
cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
defer cc.Close()
// Add servers[1] to the service discovery.
u := &naming.Update{
Op: naming.Add,
Addr: "localhost:" + servers[1].port,
}
r.w.inject([]*naming.Update{u})
req := "port"
var reply string
// Loop until servers[1] is up
for {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[1].port {
break
}
time.Sleep(10 * time.Millisecond)
}
// Add servers[2] to the service discovery.
u = &naming.Update{
Op: naming.Add,
Addr: "localhost:" + servers[2].port,
}
r.w.inject([]*naming.Update{u})
// Loop until servers[2] is up.
for {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[2].port {
break
}
time.Sleep(10 * time.Millisecond)
}
// Check that the incoming RPCs are served in a round-robin manner.
for i := 0; i < 10; i++ {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[i%numServers].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", i, err, servers[i%numServers].port)
}
}
}
func TestCloseWithPendingRPC(t *testing.T) {
defer leakcheck.Check(t)
servers, r, cleanup := startServers(t, 1, math.MaxUint32)
defer cleanup()
cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
defer cc.Close()
var reply string
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, FailFast(false)); err != nil {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port)
}
// Remove the server.
updates := []*naming.Update{{
Op: naming.Delete,
Addr: "localhost:" + servers[0].port,
}}
r.w.inject(updates)
// Loop until the above update applies.
for {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
if err := cc.Invoke(ctx, "/foo/bar", &expectedRequest, &reply, FailFast(false)); status.Code(err) == codes.DeadlineExceeded {
cancel()
break
}
time.Sleep(10 * time.Millisecond)
cancel()
}
// Issue 2 RPCs which should be completed with error status once cc is closed.
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
var reply string
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, FailFast(false)); err == nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
}
}()
go func() {
defer wg.Done()
var reply string
time.Sleep(5 * time.Millisecond)
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, FailFast(false)); err == nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
}
}()
time.Sleep(5 * time.Millisecond)
cc.Close()
wg.Wait()
}
func TestGetOnWaitChannel(t *testing.T) {
defer leakcheck.Check(t)
servers, r, cleanup := startServers(t, 1, math.MaxUint32)
defer cleanup()
cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
defer cc.Close()
// Remove all servers so that all upcoming RPCs will block on waitCh.
updates := []*naming.Update{{
Op: naming.Delete,
Addr: "localhost:" + servers[0].port,
}}
r.w.inject(updates)
for {
var reply string
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
if err := cc.Invoke(ctx, "/foo/bar", &expectedRequest, &reply, FailFast(false)); status.Code(err) == codes.DeadlineExceeded {
cancel()
break
}
cancel()
time.Sleep(10 * time.Millisecond)
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
var reply string
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, FailFast(false)); err != nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
}
}()
// Add a connected server to get the above RPC through.
updates = []*naming.Update{{
Op: naming.Add,
Addr: "localhost:" + servers[0].port,
}}
r.w.inject(updates)
// Wait until the above RPC succeeds.
wg.Wait()
}
func TestOneServerDown(t *testing.T) {
defer leakcheck.Check(t)
// Start 2 servers.
numServers := 2
servers, r, cleanup := startServers(t, numServers, math.MaxUint32)
defer cleanup()
cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}), WithWaitForHandshake())
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
defer cc.Close()
// Add servers[1] to the service discovery.
var updates []*naming.Update
updates = append(updates, &naming.Update{
Op: naming.Add,
Addr: "localhost:" + servers[1].port,
})
r.w.inject(updates)
req := "port"
var reply string
// Loop until servers[1] is up
for {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[1].port {
break
}
time.Sleep(10 * time.Millisecond)
}
var wg sync.WaitGroup
numRPC := 100
sleepDuration := 10 * time.Millisecond
wg.Add(1)
go func() {
time.Sleep(sleepDuration)
// After sleepDuration, kill server[0].
servers[0].stop()
wg.Done()
}()
// All non-failfast RPCs should not block because there's at least one connection available.
for i := 0; i < numRPC; i++ {
wg.Add(1)
go func() {
time.Sleep(sleepDuration)
// After sleepDuration, invoke RPC.
// server[0] is killed around the same time to make it racy between balancer and gRPC internals.
cc.Invoke(context.Background(), "/foo/bar", &req, &reply, FailFast(false))
wg.Done()
}()
}
wg.Wait()
}
func TestOneAddressRemoval(t *testing.T) {
defer leakcheck.Check(t)
// Start 2 servers.
numServers := 2
servers, r, cleanup := startServers(t, numServers, math.MaxUint32)
defer cleanup()
cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
defer cc.Close()
// Add servers[1] to the service discovery.
var updates []*naming.Update
updates = append(updates, &naming.Update{
Op: naming.Add,
Addr: "localhost:" + servers[1].port,
})
r.w.inject(updates)
req := "port"
var reply string
// Loop until servers[1] is up
for {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[1].port {
break
}
time.Sleep(10 * time.Millisecond)
}
var wg sync.WaitGroup
numRPC := 100
sleepDuration := 10 * time.Millisecond
wg.Add(1)
go func() {
time.Sleep(sleepDuration)
// After sleepDuration, delete server[0].
var updates []*naming.Update
updates = append(updates, &naming.Update{
Op: naming.Delete,
Addr: "localhost:" + servers[0].port,
})
r.w.inject(updates)
wg.Done()
}()
// All non-failfast RPCs should not fail because there's at least one connection available.
for i := 0; i < numRPC; i++ {
wg.Add(1)
go func() {
var reply string
time.Sleep(sleepDuration)
// After sleepDuration, invoke RPC.
// server[0] is removed around the same time to make it racy between balancer and gRPC internals.
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, FailFast(false)); err != nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want nil", err)
}
wg.Done()
}()
}
wg.Wait()
}
func checkServerUp(t *testing.T, currentServer *server) {
req := "port"
port := currentServer.port
cc, err := Dial("passthrough:///localhost:"+port, WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
defer cc.Close()
var reply string
for {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == port {
break
}
time.Sleep(10 * time.Millisecond)
}
}
func TestPickFirstEmptyAddrs(t *testing.T) {
defer leakcheck.Check(t)
servers, r, cleanup := startServers(t, 1, math.MaxUint32)
defer cleanup()
cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
defer cc.Close()
var reply string
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, reply = %q, want %q, <nil>", err, reply, expectedResponse)
}
// Inject name resolution change to remove the server so that there is no address
// available after that.
u := &naming.Update{
Op: naming.Delete,
Addr: "localhost:" + servers[0].port,
}
r.w.inject([]*naming.Update{u})
// Loop until the above updates apply.
for {
time.Sleep(10 * time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
if err := cc.Invoke(ctx, "/foo/bar", &expectedRequest, &reply); err != nil {
cancel()
break
}
cancel()
}
}
func TestPickFirstCloseWithPendingRPC(t *testing.T) {
defer leakcheck.Check(t)
servers, r, cleanup := startServers(t, 1, math.MaxUint32)
defer cleanup()
cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
defer cc.Close()
var reply string
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, FailFast(false)); err != nil {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port)
}
// Remove the server.
updates := []*naming.Update{{
Op: naming.Delete,
Addr: "localhost:" + servers[0].port,
}}
r.w.inject(updates)
// Loop until the above update applies.
for {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
if err := cc.Invoke(ctx, "/foo/bar", &expectedRequest, &reply, FailFast(false)); status.Code(err) == codes.DeadlineExceeded {
cancel()
break
}
time.Sleep(10 * time.Millisecond)
cancel()
}
// Issue 2 RPCs which should be completed with error status once cc is closed.
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
var reply string
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, FailFast(false)); err == nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
}
}()
go func() {
defer wg.Done()
var reply string
time.Sleep(5 * time.Millisecond)
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, FailFast(false)); err == nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
}
}()
time.Sleep(5 * time.Millisecond)
cc.Close()
wg.Wait()
}
func TestPickFirstOrderAllServerUp(t *testing.T) {
defer leakcheck.Check(t)
// Start 3 servers on 3 ports.
numServers := 3
servers, r, cleanup := startServers(t, numServers, math.MaxUint32)
defer cleanup()
cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
defer cc.Close()
// Add servers[1] and [2] to the service discovery.
u := &naming.Update{
Op: naming.Add,
Addr: "localhost:" + servers[1].port,
}
r.w.inject([]*naming.Update{u})
u = &naming.Update{
Op: naming.Add,
Addr: "localhost:" + servers[2].port,
}
r.w.inject([]*naming.Update{u})
// Loop until all 3 servers are up
checkServerUp(t, servers[0])
checkServerUp(t, servers[1])
checkServerUp(t, servers[2])
// Check that the incoming RPCs are served by servers[0].
req := "port"
var reply string
for i := 0; i < 20; i++ {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[0].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port)
}
time.Sleep(10 * time.Millisecond)
}
// Delete servers[0] in the balancer; the incoming RPCs should now be served by servers[1].
// For the addrConn test, close servers[0] instead of deleting it.
u = &naming.Update{
Op: naming.Delete,
Addr: "localhost:" + servers[0].port,
}
r.w.inject([]*naming.Update{u})
// Loop until it changes to server[1]
for {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[1].port {
break
}
time.Sleep(10 * time.Millisecond)
}
for i := 0; i < 20; i++ {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[1].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port)
}
time.Sleep(10 * time.Millisecond)
}
// Add servers[0] back to the balancer; the incoming RPCs are still served by servers[1].
// Add is an append operation, so the Notify order is now {servers[1].port servers[2].port servers[0].port}.
u = &naming.Update{
Op: naming.Add,
Addr: "localhost:" + servers[0].port,
}
r.w.inject([]*naming.Update{u})
for i := 0; i < 20; i++ {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[1].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port)
}
time.Sleep(10 * time.Millisecond)
}
// Delete servers[1] in the balancer; the incoming RPCs should now be served by servers[2].
u = &naming.Update{
Op: naming.Delete,
Addr: "localhost:" + servers[1].port,
}
r.w.inject([]*naming.Update{u})
for {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[2].port {
break
}
time.Sleep(1 * time.Second)
}
for i := 0; i < 20; i++ {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[2].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 2, err, servers[2].port)
}
time.Sleep(10 * time.Millisecond)
}
// Delete servers[2] in the balancer; the incoming RPCs should now be served by servers[0].
u = &naming.Update{
Op: naming.Delete,
Addr: "localhost:" + servers[2].port,
}
r.w.inject([]*naming.Update{u})
for {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[0].port {
break
}
time.Sleep(1 * time.Second)
}
for i := 0; i < 20; i++ {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[0].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port)
}
time.Sleep(10 * time.Millisecond)
}
}
func TestPickFirstOrderOneServerDown(t *testing.T) {
defer leakcheck.Check(t)
// Start 3 servers on 3 ports.
numServers := 3
servers, r, cleanup := startServers(t, numServers, math.MaxUint32)
defer cleanup()
cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}), WithWaitForHandshake())
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
defer cc.Close()
// Add servers[1] and [2] to the service discovery.
u := &naming.Update{
Op: naming.Add,
Addr: "localhost:" + servers[1].port,
}
r.w.inject([]*naming.Update{u})
u = &naming.Update{
Op: naming.Add,
Addr: "localhost:" + servers[2].port,
}
r.w.inject([]*naming.Update{u})
// Loop until all 3 servers are up
checkServerUp(t, servers[0])
checkServerUp(t, servers[1])
checkServerUp(t, servers[2])
// Check that the incoming RPCs are served by servers[0].
req := "port"
var reply string
for i := 0; i < 20; i++ {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[0].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port)
}
time.Sleep(10 * time.Millisecond)
}
// With servers[0] down, incoming RPCs are served by servers[1], but the Notify order remains
// {servers[0] servers[1] servers[2]}.
servers[0].stop()
// Loop until it changes to server[1]
for {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[1].port {
break
}
time.Sleep(10 * time.Millisecond)
}
for i := 0; i < 20; i++ {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[1].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port)
}
time.Sleep(10 * time.Millisecond)
}
// Bring servers[0] back up; the incoming RPCs are still served by servers[1].
p, _ := strconv.Atoi(servers[0].port)
servers[0] = newTestServer()
go servers[0].start(t, p, math.MaxUint32)
defer servers[0].stop()
servers[0].wait(t, 2*time.Second)
checkServerUp(t, servers[0])
for i := 0; i < 20; i++ {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[1].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port)
}
time.Sleep(10 * time.Millisecond)
}
// Delete servers[1] in the balancer; the incoming RPCs should now be served by servers[0].
u = &naming.Update{
Op: naming.Delete,
Addr: "localhost:" + servers[1].port,
}
r.w.inject([]*naming.Update{u})
for {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[0].port {
break
}
time.Sleep(1 * time.Second)
}
for i := 0; i < 20; i++ {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[0].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port)
}
time.Sleep(10 * time.Millisecond)
}
}
func TestPickFirstOneAddressRemoval(t *testing.T) {
defer leakcheck.Check(t)
// Start 2 servers.
numServers := 2
servers, r, cleanup := startServers(t, numServers, math.MaxUint32)
defer cleanup()
cc, err := Dial("passthrough:///localhost:"+servers[0].port, WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
defer cc.Close()
// Add servers[1] to the service discovery.
var updates []*naming.Update
updates = append(updates, &naming.Update{
Op: naming.Add,
Addr: "localhost:" + servers[1].port,
})
r.w.inject(updates)
// Loop (via a fresh cc created inside checkServerUp) until servers[0] and servers[1] are up.
checkServerUp(t, servers[0])
checkServerUp(t, servers[1])
var wg sync.WaitGroup
numRPC := 100
sleepDuration := 10 * time.Millisecond
wg.Add(1)
go func() {
time.Sleep(sleepDuration)
// After sleepDuration, delete server[0].
var updates []*naming.Update
updates = append(updates, &naming.Update{
Op: naming.Delete,
Addr: "localhost:" + servers[0].port,
})
r.w.inject(updates)
wg.Done()
}()
// All non-failfast RPCs should not fail because there's at least one connection available.
for i := 0; i < numRPC; i++ {
wg.Add(1)
go func() {
var reply string
time.Sleep(sleepDuration)
// After sleepDuration, invoke RPC.
// server[0] is removed around the same time to make it racy between balancer and gRPC internals.
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, FailFast(false)); err != nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want nil", err)
}
wg.Done()
}()
}
wg.Wait()
}

View File

@ -1,547 +0,0 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
/*
Package main provides a benchmark runner that is configured through command-line flags.
An example to run some benchmarks with profiling enabled:
go run benchmark/benchmain/main.go -benchtime=10s -workloads=all \
-compression=on -maxConcurrentCalls=1 -trace=off \
-reqSizeBytes=1,1048576 -respSizeBytes=1,1048576 -networkMode=Local \
-cpuProfile=cpuProf -memProfile=memProf -memProfileRate=10000 -resultFile=result
As a suggestion, when creating a branch, you can run this benchmark and save the result
file with -resultFile=basePerf; later, while the work is in progress or once it is finished,
you can rerun the benchmark and compare the result with that baseline at any time.
Assume there are two result files, named "basePerf" and "curPerf", created by adding
-resultFile=basePerf and -resultFile=curPerf.
To format the curPerf, run:
go run benchmark/benchresult/main.go curPerf
To observe how the performance changes based on a base result, run:
go run benchmark/benchresult/main.go basePerf curPerf
*/
package main
import (
"context"
"encoding/gob"
"errors"
"flag"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"os"
"reflect"
"runtime"
"runtime/pprof"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"google.golang.org/grpc"
bm "google.golang.org/grpc/benchmark"
testpb "google.golang.org/grpc/benchmark/grpc_testing"
"google.golang.org/grpc/benchmark/latency"
"google.golang.org/grpc/benchmark/stats"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/test/bufconn"
)
const (
modeOn = "on"
modeOff = "off"
modeBoth = "both"
)
var allCompressionModes = []string{modeOn, modeOff, modeBoth}
var allTraceModes = []string{modeOn, modeOff, modeBoth}
const (
workloadsUnary = "unary"
workloadsStreaming = "streaming"
workloadsAll = "all"
)
var allWorkloads = []string{workloadsUnary, workloadsStreaming, workloadsAll}
var (
runMode = []bool{true, true} // {runUnary, runStream}
// When the latency is set to 0 (no delay), the result is slower than a true no-delay run
// because the latency simulation adds extra operations.
ltc = []time.Duration{0, 40 * time.Millisecond} // if non-positive, no delay.
kbps = []int{0, 10240} // if non-positive, infinite
mtu = []int{0} // if non-positive, infinite
maxConcurrentCalls = []int{1, 8, 64, 512}
reqSizeBytes = []int{1, 1024, 1024 * 1024}
respSizeBytes = []int{1, 1024, 1024 * 1024}
enableTrace []bool
benchtime time.Duration
memProfile, cpuProfile string
memProfileRate int
enableCompressor []bool
enableChannelz []bool
networkMode string
benchmarkResultFile string
networks = map[string]latency.Network{
"Local": latency.Local,
"LAN": latency.LAN,
"WAN": latency.WAN,
"Longhaul": latency.Longhaul,
}
)
func unaryBenchmark(startTimer func(), stopTimer func(int32), benchFeatures stats.Features, benchtime time.Duration, s *stats.Stats) {
caller, cleanup := makeFuncUnary(benchFeatures)
defer cleanup()
runBenchmark(caller, startTimer, stopTimer, benchFeatures, benchtime, s)
}
func streamBenchmark(startTimer func(), stopTimer func(int32), benchFeatures stats.Features, benchtime time.Duration, s *stats.Stats) {
caller, cleanup := makeFuncStream(benchFeatures)
defer cleanup()
runBenchmark(caller, startTimer, stopTimer, benchFeatures, benchtime, s)
}
func makeFuncUnary(benchFeatures stats.Features) (func(int), func()) {
nw := &latency.Network{Kbps: benchFeatures.Kbps, Latency: benchFeatures.Latency, MTU: benchFeatures.Mtu}
opts := []grpc.DialOption{}
sopts := []grpc.ServerOption{}
if benchFeatures.EnableCompressor {
sopts = append(sopts,
grpc.RPCCompressor(nopCompressor{}),
grpc.RPCDecompressor(nopDecompressor{}),
)
opts = append(opts,
grpc.WithCompressor(nopCompressor{}),
grpc.WithDecompressor(nopDecompressor{}),
)
}
sopts = append(sopts, grpc.MaxConcurrentStreams(uint32(benchFeatures.MaxConcurrentCalls+1)))
opts = append(opts, grpc.WithInsecure())
var lis net.Listener
if *useBufconn {
bcLis := bufconn.Listen(256 * 1024)
lis = bcLis
opts = append(opts, grpc.WithDialer(func(string, time.Duration) (net.Conn, error) {
return nw.TimeoutDialer(
func(string, string, time.Duration) (net.Conn, error) {
return bcLis.Dial()
})("", "", 0)
}))
} else {
var err error
lis, err = net.Listen("tcp", "localhost:0")
if err != nil {
grpclog.Fatalf("Failed to listen: %v", err)
}
opts = append(opts, grpc.WithDialer(func(_ string, timeout time.Duration) (net.Conn, error) {
return nw.TimeoutDialer(net.DialTimeout)("tcp", lis.Addr().String(), timeout)
}))
}
lis = nw.Listener(lis)
stopper := bm.StartServer(bm.ServerInfo{Type: "protobuf", Listener: lis}, sopts...)
conn := bm.NewClientConn("" /* target not used */, opts...)
tc := testpb.NewBenchmarkServiceClient(conn)
return func(int) {
unaryCaller(tc, benchFeatures.ReqSizeBytes, benchFeatures.RespSizeBytes)
}, func() {
conn.Close()
stopper()
}
}
func makeFuncStream(benchFeatures stats.Features) (func(int), func()) {
// TODO: Refactor to remove duplication with makeFuncUnary.
nw := &latency.Network{Kbps: benchFeatures.Kbps, Latency: benchFeatures.Latency, MTU: benchFeatures.Mtu}
opts := []grpc.DialOption{}
sopts := []grpc.ServerOption{}
if benchFeatures.EnableCompressor {
sopts = append(sopts,
grpc.RPCCompressor(grpc.NewGZIPCompressor()),
grpc.RPCDecompressor(grpc.NewGZIPDecompressor()),
)
opts = append(opts,
grpc.WithCompressor(grpc.NewGZIPCompressor()),
grpc.WithDecompressor(grpc.NewGZIPDecompressor()),
)
}
sopts = append(sopts, grpc.MaxConcurrentStreams(uint32(benchFeatures.MaxConcurrentCalls+1)))
opts = append(opts, grpc.WithInsecure())
var lis net.Listener
if *useBufconn {
bcLis := bufconn.Listen(256 * 1024)
lis = bcLis
opts = append(opts, grpc.WithDialer(func(string, time.Duration) (net.Conn, error) {
return nw.TimeoutDialer(
func(string, string, time.Duration) (net.Conn, error) {
return bcLis.Dial()
})("", "", 0)
}))
} else {
var err error
lis, err = net.Listen("tcp", "localhost:0")
if err != nil {
grpclog.Fatalf("Failed to listen: %v", err)
}
opts = append(opts, grpc.WithDialer(func(_ string, timeout time.Duration) (net.Conn, error) {
return nw.TimeoutDialer(net.DialTimeout)("tcp", lis.Addr().String(), timeout)
}))
}
lis = nw.Listener(lis)
stopper := bm.StartServer(bm.ServerInfo{Type: "protobuf", Listener: lis}, sopts...)
conn := bm.NewClientConn("" /* target not used */, opts...)
tc := testpb.NewBenchmarkServiceClient(conn)
streams := make([]testpb.BenchmarkService_StreamingCallClient, benchFeatures.MaxConcurrentCalls)
for i := 0; i < benchFeatures.MaxConcurrentCalls; i++ {
stream, err := tc.StreamingCall(context.Background())
if err != nil {
grpclog.Fatalf("%v.StreamingCall(_) = _, %v", tc, err)
}
streams[i] = stream
}
return func(pos int) {
streamCaller(streams[pos], benchFeatures.ReqSizeBytes, benchFeatures.RespSizeBytes)
}, func() {
conn.Close()
stopper()
}
}
func unaryCaller(client testpb.BenchmarkServiceClient, reqSize, respSize int) {
if err := bm.DoUnaryCall(client, reqSize, respSize); err != nil {
grpclog.Fatalf("DoUnaryCall failed: %v", err)
}
}
func streamCaller(stream testpb.BenchmarkService_StreamingCallClient, reqSize, respSize int) {
if err := bm.DoStreamingRoundTrip(stream, reqSize, respSize); err != nil {
grpclog.Fatalf("DoStreamingRoundTrip failed: %v", err)
}
}
func runBenchmark(caller func(int), startTimer func(), stopTimer func(int32), benchFeatures stats.Features, benchtime time.Duration, s *stats.Stats) {
// Warm up connection.
for i := 0; i < 10; i++ {
caller(0)
}
// Run benchmark.
startTimer()
var (
mu sync.Mutex
wg sync.WaitGroup
)
wg.Add(benchFeatures.MaxConcurrentCalls)
bmEnd := time.Now().Add(benchtime)
var count int32
for i := 0; i < benchFeatures.MaxConcurrentCalls; i++ {
go func(pos int) {
for {
t := time.Now()
if t.After(bmEnd) {
break
}
start := time.Now()
caller(pos)
elapse := time.Since(start)
atomic.AddInt32(&count, 1)
mu.Lock()
s.Add(elapse)
mu.Unlock()
}
wg.Done()
}(i)
}
wg.Wait()
stopTimer(count)
}
var useBufconn = flag.Bool("bufconn", false, "Use in-memory connection instead of system network I/O")
// init parses the command-line flags and derives the feature settings.
func init() {
var (
workloads, traceMode, compressorMode, readLatency, channelzOn string
readKbps, readMtu, readMaxConcurrentCalls intSliceType
readReqSizeBytes, readRespSizeBytes intSliceType
)
flag.StringVar(&workloads, "workloads", workloadsAll,
fmt.Sprintf("Workloads to execute - One of: %v", strings.Join(allWorkloads, ", ")))
flag.StringVar(&traceMode, "trace", modeOff,
fmt.Sprintf("Trace mode - One of: %v", strings.Join(allTraceModes, ", ")))
flag.StringVar(&readLatency, "latency", "", "Simulated one-way network latency - may be a comma-separated list")
flag.StringVar(&channelzOn, "channelz", modeOff, "whether channelz should be turned on")
flag.DurationVar(&benchtime, "benchtime", time.Second, "Configures the amount of time to run each benchmark")
flag.Var(&readKbps, "kbps", "Simulated network throughput (in kbps) - may be a comma-separated list")
flag.Var(&readMtu, "mtu", "Simulated network MTU (Maximum Transmission Unit) - may be a comma-separated list")
flag.Var(&readMaxConcurrentCalls, "maxConcurrentCalls", "Number of concurrent RPCs during benchmarks")
flag.Var(&readReqSizeBytes, "reqSizeBytes", "Request size in bytes - may be a comma-separated list")
flag.Var(&readRespSizeBytes, "respSizeBytes", "Response size in bytes - may be a comma-separated list")
flag.StringVar(&memProfile, "memProfile", "", "Enables memory profiling output to the filename provided.")
flag.IntVar(&memProfileRate, "memProfileRate", 512*1024, "Configures the memory profiling rate. \n"+
"memProfile should be set before setting profile rate. To include every allocated block in the profile, "+
"set MemProfileRate to 1. To turn off profiling entirely, set MemProfileRate to 0. 512 * 1024 by default.")
flag.StringVar(&cpuProfile, "cpuProfile", "", "Enables CPU profiling output to the filename provided")
flag.StringVar(&compressorMode, "compression", modeOff,
fmt.Sprintf("Compression mode - One of: %v", strings.Join(allCompressionModes, ", ")))
flag.StringVar(&benchmarkResultFile, "resultFile", "", "Save the benchmark result into a binary file")
flag.StringVar(&networkMode, "networkMode", "", "Network mode includes LAN, WAN, Local and Longhaul")
flag.Parse()
if flag.NArg() != 0 {
log.Fatal("Error: unparsed arguments: ", flag.Args())
}
switch workloads {
case workloadsUnary:
runMode[0] = true
runMode[1] = false
case workloadsStreaming:
runMode[0] = false
runMode[1] = true
case workloadsAll:
runMode[0] = true
runMode[1] = true
default:
log.Fatalf("Unknown workloads setting: %v (want one of: %v)",
workloads, strings.Join(allWorkloads, ", "))
}
enableCompressor = setMode(compressorMode)
enableTrace = setMode(traceMode)
enableChannelz = setMode(channelzOn)
// Latency input is given as a duration string (value + unit), e.g. "40ms".
readTimeFromInput(&ltc, readLatency)
readIntFromIntSlice(&kbps, readKbps)
readIntFromIntSlice(&mtu, readMtu)
readIntFromIntSlice(&maxConcurrentCalls, readMaxConcurrentCalls)
readIntFromIntSlice(&reqSizeBytes, readReqSizeBytes)
readIntFromIntSlice(&respSizeBytes, readRespSizeBytes)
// Overwrite latency, kbps and mtu if a network mode is set.
if network, ok := networks[networkMode]; ok {
ltc = []time.Duration{network.Latency}
kbps = []int{network.Kbps}
mtu = []int{network.MTU}
}
}
func setMode(name string) []bool {
switch name {
case modeOn:
return []bool{true}
case modeOff:
return []bool{false}
case modeBoth:
return []bool{false, true}
default:
log.Fatalf("Unknown %s setting: %v (want one of: %v)",
name, name, strings.Join(allCompressionModes, ", "))
return []bool{}
}
}
type intSliceType []int
func (intSlice *intSliceType) String() string {
return fmt.Sprintf("%v", *intSlice)
}
func (intSlice *intSliceType) Set(value string) error {
if len(*intSlice) > 0 {
return errors.New("interval flag already set")
}
for _, num := range strings.Split(value, ",") {
next, err := strconv.Atoi(num)
if err != nil {
return err
}
*intSlice = append(*intSlice, next)
}
return nil
}
func readIntFromIntSlice(values *[]int, replace intSliceType) {
// If the flag was not set (replace is empty), keep the default settings.
if len(replace) == 0 {
return
}
*values = replace
}
func readTimeFromInput(values *[]time.Duration, replace string) {
if strings.Compare(replace, "") != 0 {
*values = []time.Duration{}
for _, ltc := range strings.Split(replace, ",") {
duration, err := time.ParseDuration(ltc)
if err != nil {
log.Fatal(err.Error())
}
*values = append(*values, duration)
}
}
}
func main() {
before()
featuresPos := make([]int, 9)
// 0:enableTracing 1:ltc 2:kbps 3:mtu 4:maxC 5:reqSize 6:respSize 7:compressor 8:channelz
featuresNum := []int{len(enableTrace), len(ltc), len(kbps), len(mtu),
len(maxConcurrentCalls), len(reqSizeBytes), len(respSizeBytes), len(enableCompressor), len(enableChannelz)}
initalPos := make([]int, len(featuresPos))
s := stats.NewStats(10)
s.SortLatency()
var memStats runtime.MemStats
var results testing.BenchmarkResult
var startAllocs, startBytes uint64
var startTime time.Time
start := true
var startTimer = func() {
runtime.ReadMemStats(&memStats)
startAllocs = memStats.Mallocs
startBytes = memStats.TotalAlloc
startTime = time.Now()
}
var stopTimer = func(count int32) {
runtime.ReadMemStats(&memStats)
results = testing.BenchmarkResult{N: int(count), T: time.Since(startTime),
Bytes: 0, MemAllocs: memStats.Mallocs - startAllocs, MemBytes: memStats.TotalAlloc - startBytes}
}
sharedPos := make([]bool, len(featuresPos))
for i := 0; i < len(featuresPos); i++ {
if featuresNum[i] <= 1 {
sharedPos[i] = true
}
}
// Run benchmarks
resultSlice := []stats.BenchResults{}
for !reflect.DeepEqual(featuresPos, initalPos) || start {
start = false
benchFeature := stats.Features{
NetworkMode: networkMode,
EnableTrace: enableTrace[featuresPos[0]],
Latency: ltc[featuresPos[1]],
Kbps: kbps[featuresPos[2]],
Mtu: mtu[featuresPos[3]],
MaxConcurrentCalls: maxConcurrentCalls[featuresPos[4]],
ReqSizeBytes: reqSizeBytes[featuresPos[5]],
RespSizeBytes: respSizeBytes[featuresPos[6]],
EnableCompressor: enableCompressor[featuresPos[7]],
EnableChannelz: enableChannelz[featuresPos[8]],
}
grpc.EnableTracing = enableTrace[featuresPos[0]]
if enableChannelz[featuresPos[8]] {
channelz.TurnOn()
}
if runMode[0] {
unaryBenchmark(startTimer, stopTimer, benchFeature, benchtime, s)
s.SetBenchmarkResult("Unary", benchFeature, results.N,
results.AllocedBytesPerOp(), results.AllocsPerOp(), sharedPos)
fmt.Println(s.BenchString())
fmt.Println(s.String())
resultSlice = append(resultSlice, s.GetBenchmarkResults())
s.Clear()
}
if runMode[1] {
streamBenchmark(startTimer, stopTimer, benchFeature, benchtime, s)
s.SetBenchmarkResult("Stream", benchFeature, results.N,
results.AllocedBytesPerOp(), results.AllocsPerOp(), sharedPos)
fmt.Println(s.BenchString())
fmt.Println(s.String())
resultSlice = append(resultSlice, s.GetBenchmarkResults())
s.Clear()
}
bm.AddOne(featuresPos, featuresNum)
}
after(resultSlice)
}
func before() {
if memProfile != "" {
runtime.MemProfileRate = memProfileRate
}
if cpuProfile != "" {
f, err := os.Create(cpuProfile)
if err != nil {
fmt.Fprintf(os.Stderr, "testing: %s\n", err)
return
}
if err := pprof.StartCPUProfile(f); err != nil {
fmt.Fprintf(os.Stderr, "testing: can't start cpu profile: %s\n", err)
f.Close()
return
}
}
}
func after(data []stats.BenchResults) {
if cpuProfile != "" {
pprof.StopCPUProfile() // flushes profile to disk
}
if memProfile != "" {
f, err := os.Create(memProfile)
if err != nil {
fmt.Fprintf(os.Stderr, "testing: %s\n", err)
os.Exit(2)
}
runtime.GC() // materialize all statistics
if err = pprof.WriteHeapProfile(f); err != nil {
fmt.Fprintf(os.Stderr, "testing: can't write heap profile %s: %s\n", memProfile, err)
os.Exit(2)
}
f.Close()
}
if benchmarkResultFile != "" {
f, err := os.Create(benchmarkResultFile)
if err != nil {
log.Fatalf("testing: can't write benchmark result %s: %s\n", benchmarkResultFile, err)
}
dataEncoder := gob.NewEncoder(f)
dataEncoder.Encode(data)
f.Close()
}
}
// nopCompressor is a compressor that just copies data.
type nopCompressor struct{}
func (nopCompressor) Do(w io.Writer, p []byte) error {
n, err := w.Write(p)
if err != nil {
return err
}
if n != len(p) {
return fmt.Errorf("nopCompressor.Write: wrote %v bytes; want %v", n, len(p))
}
return nil
}
func (nopCompressor) Type() string { return "nop" }
// nopDecompressor is a decompressor that just copies data.
type nopDecompressor struct{}
func (nopDecompressor) Do(r io.Reader) ([]byte, error) { return ioutil.ReadAll(r) }
func (nopDecompressor) Type() string { return "nop" }

View File

@ -1,369 +0,0 @@
/*
*
* Copyright 2014 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
//go:generate protoc -I grpc_testing --go_out=plugins=grpc:grpc_testing grpc_testing/control.proto grpc_testing/messages.proto grpc_testing/payloads.proto grpc_testing/services.proto grpc_testing/stats.proto
/*
Package benchmark implements the building blocks to set up end-to-end gRPC benchmarks.
*/
package benchmark
import (
"context"
"fmt"
"io"
"net"
"sync"
"testing"
"time"
"google.golang.org/grpc"
testpb "google.golang.org/grpc/benchmark/grpc_testing"
"google.golang.org/grpc/benchmark/latency"
"google.golang.org/grpc/benchmark/stats"
"google.golang.org/grpc/grpclog"
)
// AddOne adds 1 to the features slice.
func AddOne(features []int, featuresMaxPosition []int) {
for i := len(features) - 1; i >= 0; i-- {
features[i] = (features[i] + 1)
if features[i]/featuresMaxPosition[i] == 0 {
break
}
features[i] = features[i] % featuresMaxPosition[i]
}
}
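AddOne treats the features slice as a mixed-radix odometer: the last feature varies fastest and each position wraps at the corresponding entry in featuresMaxPosition. A small hedged illustration with made-up radices:

// exampleAddOne prints the enumeration order AddOne produces for two features
// with 2 and 3 possible values respectively (radices chosen for illustration).
func exampleAddOne() {
	pos := []int{0, 0}
	max := []int{2, 3}
	for i := 0; i < 6; i++ {
		fmt.Println(pos) // [0 0] [0 1] [0 2] [1 0] [1 1] [1 2], then wraps to [0 0]
		AddOne(pos, max)
	}
}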
// Allows reuse of the same testpb.Payload object.
func setPayload(p *testpb.Payload, t testpb.PayloadType, size int) {
if size < 0 {
grpclog.Fatalf("Requested a response with invalid length %d", size)
}
body := make([]byte, size)
switch t {
case testpb.PayloadType_COMPRESSABLE:
case testpb.PayloadType_UNCOMPRESSABLE:
grpclog.Fatalf("PayloadType UNCOMPRESSABLE is not supported")
default:
grpclog.Fatalf("Unsupported payload type: %d", t)
}
p.Type = t
p.Body = body
}
func newPayload(t testpb.PayloadType, size int) *testpb.Payload {
p := new(testpb.Payload)
setPayload(p, t, size)
return p
}
type testServer struct {
}
func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
return &testpb.SimpleResponse{
Payload: newPayload(in.ResponseType, int(in.ResponseSize)),
}, nil
}
func (s *testServer) StreamingCall(stream testpb.BenchmarkService_StreamingCallServer) error {
response := &testpb.SimpleResponse{
Payload: new(testpb.Payload),
}
in := new(testpb.SimpleRequest)
for {
// use ServerStream directly to reuse the same testpb.SimpleRequest object
err := stream.(grpc.ServerStream).RecvMsg(in)
if err == io.EOF {
// read done.
return nil
}
if err != nil {
return err
}
setPayload(response.Payload, in.ResponseType, int(in.ResponseSize))
if err := stream.Send(response); err != nil {
return err
}
}
}
// byteBufServer is a gRPC server that sends and receives byte buffers.
// The purpose is to benchmark the gRPC performance without protobuf serialization/deserialization overhead.
type byteBufServer struct {
respSize int32
}
// UnaryCall is an empty function and is not used for benchmarks.
// If a bytebuf UnaryCall benchmark is needed later, the function body needs to be updated.
func (s *byteBufServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
return &testpb.SimpleResponse{}, nil
}
func (s *byteBufServer) StreamingCall(stream testpb.BenchmarkService_StreamingCallServer) error {
for {
var in []byte
err := stream.(grpc.ServerStream).RecvMsg(&in)
if err == io.EOF {
return nil
}
if err != nil {
return err
}
out := make([]byte, s.respSize)
if err := stream.(grpc.ServerStream).SendMsg(&out); err != nil {
return err
}
}
}
// ServerInfo contains the information to create a gRPC benchmark server.
type ServerInfo struct {
// Type is the type of the server.
// It should be "protobuf" or "bytebuf".
Type string
// Metadata is an optional configuration.
// For "protobuf", it's ignored.
// For "bytebuf", it should be an int representing response size.
Metadata interface{}
// Listener is the network listener for the server to use
Listener net.Listener
}
// StartServer starts a gRPC server serving a benchmark service according to info.
// It returns a function to stop the server.
func StartServer(info ServerInfo, opts ...grpc.ServerOption) func() {
opts = append(opts, grpc.WriteBufferSize(128*1024))
opts = append(opts, grpc.ReadBufferSize(128*1024))
s := grpc.NewServer(opts...)
switch info.Type {
case "protobuf":
testpb.RegisterBenchmarkServiceServer(s, &testServer{})
case "bytebuf":
respSize, ok := info.Metadata.(int32)
if !ok {
grpclog.Fatalf("failed to StartServer, invalid metadata: %v, for Type: %v", info.Metadata, info.Type)
}
testpb.RegisterBenchmarkServiceServer(s, &byteBufServer{respSize: respSize})
default:
grpclog.Fatalf("failed to StartServer, unknown Type: %v", info.Type)
}
go s.Serve(info.Listener)
return func() {
s.Stop()
}
}
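StartServer returns a stop function and NewClientConn (defined below) fails fatally on error, so a typical caller pairs them with deferred cleanup. A minimal sketch, assuming an ephemeral local TCP listener:

// exampleSetup starts a protobuf benchmark server on a local port and dials it.
// Illustrative sketch only; error handling is reduced to a fatal log.
func exampleSetup() (testpb.BenchmarkServiceClient, func()) {
	lis, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		grpclog.Fatalf("failed to listen: %v", err)
	}
	stop := StartServer(ServerInfo{Type: "protobuf", Listener: lis})
	conn := NewClientConn(lis.Addr().String(), grpc.WithInsecure())
	return testpb.NewBenchmarkServiceClient(conn), func() {
		conn.Close()
		stop()
	}
}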
// DoUnaryCall performs a unary RPC with the given stub and request/response sizes.
func DoUnaryCall(tc testpb.BenchmarkServiceClient, reqSize, respSize int) error {
pl := newPayload(testpb.PayloadType_COMPRESSABLE, reqSize)
req := &testpb.SimpleRequest{
ResponseType: pl.Type,
ResponseSize: int32(respSize),
Payload: pl,
}
if _, err := tc.UnaryCall(context.Background(), req); err != nil {
return fmt.Errorf("/BenchmarkService/UnaryCall(_, _) = _, %v, want _, <nil>", err)
}
return nil
}
// DoStreamingRoundTrip performs a round trip for a single streaming RPC.
func DoStreamingRoundTrip(stream testpb.BenchmarkService_StreamingCallClient, reqSize, respSize int) error {
pl := newPayload(testpb.PayloadType_COMPRESSABLE, reqSize)
req := &testpb.SimpleRequest{
ResponseType: pl.Type,
ResponseSize: int32(respSize),
Payload: pl,
}
if err := stream.Send(req); err != nil {
return fmt.Errorf("/BenchmarkService/StreamingCall.Send(_) = %v, want <nil>", err)
}
if _, err := stream.Recv(); err != nil {
// EOF is a valid error here.
if err == io.EOF {
return nil
}
return fmt.Errorf("/BenchmarkService/StreamingCall.Recv(_) = %v, want <nil>", err)
}
return nil
}
// DoByteBufStreamingRoundTrip performs a round trip for a single streaming RPC, using a custom codec for byte buffers.
func DoByteBufStreamingRoundTrip(stream testpb.BenchmarkService_StreamingCallClient, reqSize, respSize int) error {
out := make([]byte, reqSize)
if err := stream.(grpc.ClientStream).SendMsg(&out); err != nil {
return fmt.Errorf("/BenchmarkService/StreamingCall.(ClientStream).SendMsg(_) = %v, want <nil>", err)
}
var in []byte
if err := stream.(grpc.ClientStream).RecvMsg(&in); err != nil {
// EOF is a valid error here.
if err == io.EOF {
return nil
}
return fmt.Errorf("/BenchmarkService/StreamingCall.(ClientStream).RecvMsg(_) = %v, want <nil>", err)
}
return nil
}
// NewClientConn creates a gRPC client connection to addr.
func NewClientConn(addr string, opts ...grpc.DialOption) *grpc.ClientConn {
return NewClientConnWithContext(context.Background(), addr, opts...)
}
// NewClientConnWithContext creates a gRPC client connection to addr using ctx.
func NewClientConnWithContext(ctx context.Context, addr string, opts ...grpc.DialOption) *grpc.ClientConn {
opts = append(opts, grpc.WithWriteBufferSize(128*1024))
opts = append(opts, grpc.WithReadBufferSize(128*1024))
conn, err := grpc.DialContext(ctx, addr, opts...)
if err != nil {
grpclog.Fatalf("NewClientConn(%q) failed to create a ClientConn %v", addr, err)
}
return conn
}
func runUnary(b *testing.B, benchFeatures stats.Features) {
s := stats.AddStats(b, 38)
nw := &latency.Network{Kbps: benchFeatures.Kbps, Latency: benchFeatures.Latency, MTU: benchFeatures.Mtu}
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
grpclog.Fatalf("Failed to listen: %v", err)
}
target := lis.Addr().String()
lis = nw.Listener(lis)
stopper := StartServer(ServerInfo{Type: "protobuf", Listener: lis}, grpc.MaxConcurrentStreams(uint32(benchFeatures.MaxConcurrentCalls+1)))
defer stopper()
conn := NewClientConn(
target, grpc.WithInsecure(),
grpc.WithDialer(func(address string, timeout time.Duration) (net.Conn, error) {
return nw.TimeoutDialer(net.DialTimeout)("tcp", address, timeout)
}),
)
tc := testpb.NewBenchmarkServiceClient(conn)
// Warm up connection.
for i := 0; i < 10; i++ {
unaryCaller(tc, benchFeatures.ReqSizeBytes, benchFeatures.RespSizeBytes)
}
ch := make(chan int, benchFeatures.MaxConcurrentCalls*4)
var (
mu sync.Mutex
wg sync.WaitGroup
)
wg.Add(benchFeatures.MaxConcurrentCalls)
// Distribute the b.N calls over maxConcurrentCalls workers.
for i := 0; i < benchFeatures.MaxConcurrentCalls; i++ {
go func() {
for range ch {
start := time.Now()
unaryCaller(tc, benchFeatures.ReqSizeBytes, benchFeatures.RespSizeBytes)
elapse := time.Since(start)
mu.Lock()
s.Add(elapse)
mu.Unlock()
}
wg.Done()
}()
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
ch <- i
}
close(ch)
wg.Wait()
b.StopTimer()
conn.Close()
}
func runStream(b *testing.B, benchFeatures stats.Features) {
s := stats.AddStats(b, 38)
nw := &latency.Network{Kbps: benchFeatures.Kbps, Latency: benchFeatures.Latency, MTU: benchFeatures.Mtu}
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
grpclog.Fatalf("Failed to listen: %v", err)
}
target := lis.Addr().String()
lis = nw.Listener(lis)
stopper := StartServer(ServerInfo{Type: "protobuf", Listener: lis}, grpc.MaxConcurrentStreams(uint32(benchFeatures.MaxConcurrentCalls+1)))
defer stopper()
conn := NewClientConn(
target, grpc.WithInsecure(),
grpc.WithDialer(func(address string, timeout time.Duration) (net.Conn, error) {
return nw.TimeoutDialer(net.DialTimeout)("tcp", address, timeout)
}),
)
tc := testpb.NewBenchmarkServiceClient(conn)
// Warm up connection.
stream, err := tc.StreamingCall(context.Background())
if err != nil {
b.Fatalf("%v.StreamingCall(_) = _, %v", tc, err)
}
for i := 0; i < 10; i++ {
streamCaller(stream, benchFeatures.ReqSizeBytes, benchFeatures.RespSizeBytes)
}
ch := make(chan struct{}, benchFeatures.MaxConcurrentCalls*4)
var (
mu sync.Mutex
wg sync.WaitGroup
)
wg.Add(benchFeatures.MaxConcurrentCalls)
// Distribute the b.N calls over maxConcurrentCalls workers.
for i := 0; i < benchFeatures.MaxConcurrentCalls; i++ {
stream, err := tc.StreamingCall(context.Background())
if err != nil {
b.Fatalf("%v.StreamingCall(_) = _, %v", tc, err)
}
go func() {
for range ch {
start := time.Now()
streamCaller(stream, benchFeatures.ReqSizeBytes, benchFeatures.RespSizeBytes)
elapse := time.Since(start)
mu.Lock()
s.Add(elapse)
mu.Unlock()
}
wg.Done()
}()
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
ch <- struct{}{}
}
close(ch)
wg.Wait()
b.StopTimer()
conn.Close()
}
func unaryCaller(client testpb.BenchmarkServiceClient, reqSize, respSize int) {
if err := DoUnaryCall(client, reqSize, respSize); err != nil {
grpclog.Fatalf("DoUnaryCall failed: %v", err)
}
}
func streamCaller(stream testpb.BenchmarkService_StreamingCallClient, reqSize, respSize int) {
if err := DoStreamingRoundTrip(stream, reqSize, respSize); err != nil {
grpclog.Fatalf("DoStreamingRoundTrip failed: %v", err)
}
}

View File

@ -1,83 +0,0 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package benchmark
import (
"fmt"
"os"
"reflect"
"testing"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/benchmark/stats"
)
func BenchmarkClient(b *testing.B) {
enableTrace := []bool{true, false} // run both enable and disable by default
// When the latency is set to 0 (no delay), the result is slower than a true no-delay run
// because the latency simulation adds extra operations.
latency := []time.Duration{0, 40 * time.Millisecond} // if non-positive, no delay.
kbps := []int{0, 10240} // if non-positive, infinite
mtu := []int{0} // if non-positive, infinite
maxConcurrentCalls := []int{1, 8, 64, 512}
reqSizeBytes := []int{1, 1024 * 1024}
respSizeBytes := []int{1, 1024 * 1024}
featuresCurPos := make([]int, 7)
// 0:enableTracing 1:ltc 2:kbps 3:mtu 4:maxC 5:reqSize 6:respSize
featuresMaxPosition := []int{len(enableTrace), len(latency), len(kbps), len(mtu), len(maxConcurrentCalls), len(reqSizeBytes), len(respSizeBytes)}
initalPos := make([]int, len(featuresCurPos))
// run benchmarks
start := true
for !reflect.DeepEqual(featuresCurPos, initalPos) || start {
start = false
tracing := "Trace"
if !enableTrace[featuresCurPos[0]] {
tracing = "noTrace"
}
benchFeature := stats.Features{
EnableTrace: enableTrace[featuresCurPos[0]],
Latency: latency[featuresCurPos[1]],
Kbps: kbps[featuresCurPos[2]],
Mtu: mtu[featuresCurPos[3]],
MaxConcurrentCalls: maxConcurrentCalls[featuresCurPos[4]],
ReqSizeBytes: reqSizeBytes[featuresCurPos[5]],
RespSizeBytes: respSizeBytes[featuresCurPos[6]],
}
grpc.EnableTracing = enableTrace[featuresCurPos[0]]
b.Run(fmt.Sprintf("Unary-%s-%s",
tracing, benchFeature.String()), func(b *testing.B) {
runUnary(b, benchFeature)
})
b.Run(fmt.Sprintf("Stream-%s-%s",
tracing, benchFeature.String()), func(b *testing.B) {
runStream(b, benchFeature)
})
AddOne(featuresCurPos, featuresMaxPosition)
}
}
func TestMain(m *testing.M) {
os.Exit(stats.RunTestMain(m))
}

View File

@ -1,133 +0,0 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
/*
To format the benchmark result:
go run benchmark/benchresult/main.go resultfile
To see the performance change relative to an old result:
go run benchmark/benchresult/main.go resultfile_old resultfile
It prints the comparison for the benchmarks that appear in both files.
*/
package main
import (
"encoding/gob"
"fmt"
"log"
"os"
"strconv"
"strings"
"time"
"google.golang.org/grpc/benchmark/stats"
)
func createMap(fileName string, m map[string]stats.BenchResults) {
f, err := os.Open(fileName)
if err != nil {
log.Fatalf("Read file %s error: %s\n", fileName, err)
}
defer f.Close()
var data []stats.BenchResults
decoder := gob.NewDecoder(f)
if err = decoder.Decode(&data); err != nil {
log.Fatalf("Decode file %s error: %s\n", fileName, err)
}
for _, d := range data {
m[d.RunMode+"-"+d.Features.String()] = d
}
}
func intChange(title string, val1, val2 int64) string {
return fmt.Sprintf("%10s %12s %12s %8.2f%%\n", title, strconv.FormatInt(val1, 10),
strconv.FormatInt(val2, 10), float64(val2-val1)*100/float64(val1))
}
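As a quick worked example of the percentage column (numbers invented for illustration): a Bytes/op change from 2000 to 1500 is reported by intChange as (1500-2000)*100/2000 = -25.00%.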
func timeChange(title int, val1, val2 time.Duration) string {
return fmt.Sprintf("%10s %12s %12s %8.2f%%\n", strconv.Itoa(title)+" latency", val1.String(),
val2.String(), float64(val2-val1)*100/float64(val1))
}
func compareTwoMap(m1, m2 map[string]stats.BenchResults) {
for k2, v2 := range m2 {
if v1, ok := m1[k2]; ok {
changes := k2 + "\n"
changes += fmt.Sprintf("%10s %12s %12s %8s\n", "Title", "Before", "After", "Percentage")
changes += intChange("Bytes/op", v1.AllocedBytesPerOp, v2.AllocedBytesPerOp)
changes += intChange("Allocs/op", v1.AllocsPerOp, v2.AllocsPerOp)
changes += timeChange(v1.Latency[1].Percent, v1.Latency[1].Value, v2.Latency[1].Value)
changes += timeChange(v1.Latency[2].Percent, v1.Latency[2].Value, v2.Latency[2].Value)
fmt.Printf("%s\n", changes)
}
}
}
func compareBenchmark(file1, file2 string) {
var BenchValueFile1 map[string]stats.BenchResults
var BenchValueFile2 map[string]stats.BenchResults
BenchValueFile1 = make(map[string]stats.BenchResults)
BenchValueFile2 = make(map[string]stats.BenchResults)
createMap(file1, BenchValueFile1)
createMap(file2, BenchValueFile2)
compareTwoMap(BenchValueFile1, BenchValueFile2)
}
func printline(benchName, ltc50, ltc90, allocByte, allocsOp interface{}) {
fmt.Printf("%-80v%12v%12v%12v%12v\n", benchName, ltc50, ltc90, allocByte, allocsOp)
}
func formatBenchmark(fileName string) {
f, err := os.Open(fileName)
if err != nil {
log.Fatalf("Read file %s error: %s\n", fileName, err)
}
defer f.Close()
var data []stats.BenchResults
decoder := gob.NewDecoder(f)
if err = decoder.Decode(&data); err != nil {
log.Fatalf("Decode file %s error: %s\n", fileName, err)
}
if len(data) == 0 {
log.Fatalf("No data in file %s\n", fileName)
}
printPos := data[0].SharedPosion
fmt.Println("\nShared features:\n" + strings.Repeat("-", 20))
fmt.Print(stats.PartialPrintString(printPos, data[0].Features, true))
fmt.Println(strings.Repeat("-", 35))
for i := 0; i < len(data[0].SharedPosion); i++ {
printPos[i] = !printPos[i]
}
printline("Name", "latency-50", "latency-90", "Alloc (B)", "Alloc (#)")
for _, d := range data {
name := d.RunMode + stats.PartialPrintString(printPos, d.Features, false)
printline(name, d.Latency[1].Value.String(), d.Latency[2].Value.String(),
d.AllocedBytesPerOp, d.AllocsPerOp)
}
}
func main() {
if len(os.Args) == 2 {
formatBenchmark(os.Args[1])
} else {
compareBenchmark(os.Args[1], os.Args[2])
}
}

View File

@ -1,187 +0,0 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package main
import (
"context"
"flag"
"fmt"
"os"
"runtime"
"runtime/pprof"
"sync"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/benchmark"
testpb "google.golang.org/grpc/benchmark/grpc_testing"
"google.golang.org/grpc/benchmark/stats"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/syscall"
)
var (
port = flag.String("port", "50051", "Localhost port to connect to.")
numRPC = flag.Int("r", 1, "The number of concurrent RPCs on each connection.")
numConn = flag.Int("c", 1, "The number of parallel connections.")
warmupDur = flag.Int("w", 10, "Warm-up duration in seconds.")
duration = flag.Int("d", 60, "Benchmark duration in seconds.")
rqSize = flag.Int("req", 1, "Request message size in bytes.")
rspSize = flag.Int("resp", 1, "Response message size in bytes.")
rpcType = flag.String("rpc_type", "unary",
`Configure different client rpc type. Valid options are:
unary;
streaming.`)
testName = flag.String("test_name", "", "Name of the test used for creating profiles.")
wg sync.WaitGroup
hopts = stats.HistogramOptions{
NumBuckets: 2495,
GrowthFactor: .01,
}
mu sync.Mutex
hists []*stats.Histogram
)
func main() {
flag.Parse()
if *testName == "" {
grpclog.Fatalf("test_name not set")
}
req := &testpb.SimpleRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE,
ResponseSize: int32(*rspSize),
Payload: &testpb.Payload{
Type: testpb.PayloadType_COMPRESSABLE,
Body: make([]byte, *rqSize),
},
}
connectCtx, connectCancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second))
defer connectCancel()
ccs := buildConnections(connectCtx)
warmDeadline := time.Now().Add(time.Duration(*warmupDur) * time.Second)
endDeadline := warmDeadline.Add(time.Duration(*duration) * time.Second)
cf, err := os.Create("/tmp/" + *testName + ".cpu")
if err != nil {
grpclog.Fatalf("Error creating file: %v", err)
}
defer cf.Close()
if err := pprof.StartCPUProfile(cf); err != nil {
grpclog.Fatalf("Error starting CPU profile: %v", err)
}
cpuBeg := syscall.GetCPUTime()
for _, cc := range ccs {
runWithConn(cc, req, warmDeadline, endDeadline)
}
wg.Wait()
cpu := time.Duration(syscall.GetCPUTime() - cpuBeg)
pprof.StopCPUProfile()
mf, err := os.Create("/tmp/" + *testName + ".mem")
if err != nil {
grpclog.Fatalf("Error creating file: %v", err)
}
defer mf.Close()
runtime.GC() // materialize all statistics
if err := pprof.WriteHeapProfile(mf); err != nil {
grpclog.Fatalf("Error writing memory profile: %v", err)
}
hist := stats.NewHistogram(hopts)
for _, h := range hists {
hist.Merge(h)
}
parseHist(hist)
fmt.Println("Client CPU utilization:", cpu)
fmt.Println("Client CPU profile:", cf.Name())
fmt.Println("Client Mem Profile:", mf.Name())
}
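// buildConnections dials *numConn blocking, insecure connections to the local
// benchmark server and returns them.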
func buildConnections(ctx context.Context) []*grpc.ClientConn {
ccs := make([]*grpc.ClientConn, *numConn)
for i := range ccs {
ccs[i] = benchmark.NewClientConnWithContext(ctx, "localhost:"+*port, grpc.WithInsecure(), grpc.WithBlock())
}
return ccs
}
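// runWithConn starts *numRPC goroutines on the connection; each issues RPCs
// back to back, recording per-call latency into its own histogram once the
// warm-up deadline has passed, and stops at the end deadline.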
func runWithConn(cc *grpc.ClientConn, req *testpb.SimpleRequest, warmDeadline, endDeadline time.Time) {
for i := 0; i < *numRPC; i++ {
wg.Add(1)
go func() {
defer wg.Done()
caller := makeCaller(cc, req)
hist := stats.NewHistogram(hopts)
for {
start := time.Now()
if start.After(endDeadline) {
mu.Lock()
hists = append(hists, hist)
mu.Unlock()
return
}
caller()
elapsed := time.Since(start)
if start.After(warmDeadline) {
hist.Add(elapsed.Nanoseconds())
}
}
}()
}
}
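// makeCaller returns a closure that performs one unary call, or one
// send/receive round trip on a long-lived stream, depending on -rpc_type.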
func makeCaller(cc *grpc.ClientConn, req *testpb.SimpleRequest) func() {
client := testpb.NewBenchmarkServiceClient(cc)
if *rpcType == "unary" {
return func() {
if _, err := client.UnaryCall(context.Background(), req); err != nil {
grpclog.Fatalf("RPC failed: %v", err)
}
}
}
stream, err := client.StreamingCall(context.Background())
if err != nil {
grpclog.Fatalf("RPC failed: %v", err)
}
return func() {
if err := stream.Send(req); err != nil {
grpclog.Fatalf("Streaming RPC failed to send: %v", err)
}
if _, err := stream.Recv(); err != nil {
grpclog.Fatalf("Streaming RPC failed to read: %v", err)
}
}
}
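// parseHist prints the aggregate QPS and the 50/90/99th percentile latencies
// from the merged histogram.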
func parseHist(hist *stats.Histogram) {
fmt.Println("qps:", float64(hist.Count)/float64(*duration))
fmt.Printf("Latency: (50/90/99 %%ile): %v/%v/%v\n",
time.Duration(median(.5, hist)),
time.Duration(median(.9, hist)),
time.Duration(median(.99, hist)))
}
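// median, despite its name, estimates the latency at an arbitrary percentile
// by walking the histogram buckets and interpolating inside the matching
// bucket; the result is in nanoseconds.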
func median(percentile float64, h *stats.Histogram) int64 {
need := int64(float64(h.Count) * percentile)
have := int64(0)
for _, bucket := range h.Buckets {
count := bucket.Count
if have+count >= need {
percent := float64(need-have) / float64(count)
return int64((1.0-percent)*bucket.LowBound + percent*bucket.LowBound*(1.0+hopts.GrowthFactor))
}
have += bucket.Count
}
panic("should have found a bound")
}
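The median helper above is really a percentile estimate: it walks the histogram buckets until the cumulative count covers the requested fraction of samples, then interpolates linearly inside that bucket (bucket widths grow geometrically by the configured growth factor). A self-contained sketch of the same calculation on a toy histogram, with local stand-in types rather than the real stats.Histogram:

package main

import "fmt"

// growthFactor mirrors the geometric bucket growth configured via
// stats.HistogramOptions in the client above.
const growthFactor = 0.01

type bucket struct {
    LowBound float64 // lower bound of the bucket, in nanoseconds
    Count    int64   // number of samples that fell into this bucket
}

// percentileNs walks the buckets until the cumulative count covers the
// requested percentile, then interpolates linearly inside that bucket.
func percentileNs(p float64, total int64, buckets []bucket) int64 {
    need := int64(float64(total) * p)
    have := int64(0)
    for _, b := range buckets {
        if have+b.Count >= need {
            frac := float64(need-have) / float64(b.Count)
            return int64((1.0-frac)*b.LowBound + frac*b.LowBound*(1.0+growthFactor))
        }
        have += b.Count
    }
    panic("percentile outside histogram range")
}

func main() {
    buckets := []bucket{
        {LowBound: 1000, Count: 40},
        {LowBound: 1010, Count: 40},
        {LowBound: 1020.1, Count: 20},
    }
    // The 50th percentile of 100 samples lands a quarter of the way into the
    // second bucket, so this prints roughly 1012 ns.
    fmt.Println("p50 ~", percentileNs(0.5, 100, buckets), "ns")
}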

File diff suppressed because it is too large

View File

@ -1,186 +0,0 @@
// Copyright 2016 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
import "payloads.proto";
import "stats.proto";
package grpc.testing;
enum ClientType {
SYNC_CLIENT = 0;
ASYNC_CLIENT = 1;
}
enum ServerType {
SYNC_SERVER = 0;
ASYNC_SERVER = 1;
ASYNC_GENERIC_SERVER = 2;
}
enum RpcType {
UNARY = 0;
STREAMING = 1;
}
// Parameters of the Poisson process distribution, which is a good representation
// of activity coming in from independent, identical, stationary sources.
message PoissonParams {
// The rate of arrivals (a.k.a. the lambda parameter of the exponential distribution).
double offered_load = 1;
}
message UniformParams {
double interarrival_lo = 1;
double interarrival_hi = 2;
}
message DeterministicParams {
double offered_load = 1;
}
message ParetoParams {
double interarrival_base = 1;
double alpha = 2;
}
// Once an RPC finishes, immediately start a new one.
// No configuration parameters needed.
message ClosedLoopParams {
}
message LoadParams {
oneof load {
ClosedLoopParams closed_loop = 1;
PoissonParams poisson = 2;
UniformParams uniform = 3;
DeterministicParams determ = 4;
ParetoParams pareto = 5;
};
}
// presence of SecurityParams implies use of TLS
message SecurityParams {
bool use_test_ca = 1;
string server_host_override = 2;
}
message ClientConfig {
// List of targets to connect to. At least one target needs to be specified.
repeated string server_targets = 1;
ClientType client_type = 2;
SecurityParams security_params = 3;
// How many concurrent RPCs to start for each channel.
// For synchronous client, use a separate thread for each outstanding RPC.
int32 outstanding_rpcs_per_channel = 4;
// Number of independent client channels to create.
// i-th channel will connect to server_target[i % server_targets.size()]
int32 client_channels = 5;
// Only for async client. Number of threads to use to start/manage RPCs.
int32 async_client_threads = 7;
RpcType rpc_type = 8;
// The requested load for the entire client (aggregated over all the threads).
LoadParams load_params = 10;
PayloadConfig payload_config = 11;
HistogramParams histogram_params = 12;
// Specify the cores we should run the client on, if desired
repeated int32 core_list = 13;
int32 core_limit = 14;
}
message ClientStatus {
ClientStats stats = 1;
}
// Request current stats
message Mark {
// If true, the stats will be reset after the snapshot is taken.
bool reset = 1;
}
message ClientArgs {
oneof argtype {
ClientConfig setup = 1;
Mark mark = 2;
}
}
message ServerConfig {
ServerType server_type = 1;
SecurityParams security_params = 2;
// Port on which to listen. Zero means pick unused port.
int32 port = 4;
// Only for async server. Number of threads used to serve the requests.
int32 async_server_threads = 7;
// Specify the number of cores to limit server to, if desired
int32 core_limit = 8;
// payload config, used in generic server
PayloadConfig payload_config = 9;
// Specify the cores we should run the server on, if desired
repeated int32 core_list = 10;
}
message ServerArgs {
oneof argtype {
ServerConfig setup = 1;
Mark mark = 2;
}
}
message ServerStatus {
ServerStats stats = 1;
// the port bound by the server
int32 port = 2;
// Number of cores available to the server
int32 cores = 3;
}
message CoreRequest {
}
message CoreResponse {
// Number of cores available on the server
int32 cores = 1;
}
message Void {
}
// A single performance scenario: input to qps_json_driver
message Scenario {
// Human readable name for this scenario
string name = 1;
// Client configuration
ClientConfig client_config = 2;
// Number of clients to start for the test
int32 num_clients = 3;
// Server configuration
ServerConfig server_config = 4;
// Number of servers to start for the test
int32 num_servers = 5;
// Warmup period, in seconds
int32 warmup_seconds = 6;
// Benchmark time, in seconds
int32 benchmark_seconds = 7;
// Number of workers to spawn locally (usually zero)
int32 spawn_local_worker_count = 8;
}
// A set of scenarios to be run with qps_json_driver
message Scenarios {
repeated Scenario scenarios = 1;
}
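control.proto above defines the driver-side configuration consumed by qps_json_driver. Assuming the standard protoc-gen-go mapping of these messages into the generated grpc_testing package (ServerTargets for server_targets, and so on), a hypothetical Scenario could be assembled as sketched below; all values are placeholders, not a recommended benchmark configuration.

package main

import (
    "fmt"

    testpb "google.golang.org/grpc/benchmark/grpc_testing"
)

// Sketch only: field and enum names assume the usual protoc-gen-go mapping of
// the control.proto messages shown above.
func main() {
    scenario := &testpb.Scenario{
        Name:             "go_sync_unary_example",
        NumClients:       1,
        NumServers:       1,
        WarmupSeconds:    5,
        BenchmarkSeconds: 30,
        ClientConfig: &testpb.ClientConfig{
            ServerTargets:             []string{"localhost:50051"},
            ClientType:                testpb.ClientType_SYNC_CLIENT,
            RpcType:                   testpb.RpcType_UNARY,
            ClientChannels:            1,
            OutstandingRpcsPerChannel: 1,
        },
        ServerConfig: &testpb.ServerConfig{
            ServerType: testpb.ServerType_SYNC_SERVER,
            Port:       0, // zero lets the server pick an unused port
        },
    }
    fmt.Println("scenario:", scenario.GetName())
}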

View File

@ -1,731 +0,0 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: messages.proto
package grpc_testing
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
// The type of payload that should be returned.
type PayloadType int32
const (
// Compressable text format.
PayloadType_COMPRESSABLE PayloadType = 0
// Uncompressable binary format.
PayloadType_UNCOMPRESSABLE PayloadType = 1
// Randomly chosen from all other formats defined in this enum.
PayloadType_RANDOM PayloadType = 2
)
var PayloadType_name = map[int32]string{
0: "COMPRESSABLE",
1: "UNCOMPRESSABLE",
2: "RANDOM",
}
var PayloadType_value = map[string]int32{
"COMPRESSABLE": 0,
"UNCOMPRESSABLE": 1,
"RANDOM": 2,
}
func (x PayloadType) String() string {
return proto.EnumName(PayloadType_name, int32(x))
}
func (PayloadType) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{0}
}
// Compression algorithms
type CompressionType int32
const (
// No compression
CompressionType_NONE CompressionType = 0
CompressionType_GZIP CompressionType = 1
CompressionType_DEFLATE CompressionType = 2
)
var CompressionType_name = map[int32]string{
0: "NONE",
1: "GZIP",
2: "DEFLATE",
}
var CompressionType_value = map[string]int32{
"NONE": 0,
"GZIP": 1,
"DEFLATE": 2,
}
func (x CompressionType) String() string {
return proto.EnumName(CompressionType_name, int32(x))
}
func (CompressionType) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{1}
}
// A block of data, to simply increase gRPC message size.
type Payload struct {
// The type of data in body.
Type PayloadType `protobuf:"varint,1,opt,name=type,proto3,enum=grpc.testing.PayloadType" json:"type,omitempty"`
// Primary contents of payload.
Body []byte `protobuf:"bytes,2,opt,name=body,proto3" json:"body,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *Payload) Reset() { *m = Payload{} }
func (m *Payload) String() string { return proto.CompactTextString(m) }
func (*Payload) ProtoMessage() {}
func (*Payload) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{0}
}
func (m *Payload) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Payload.Unmarshal(m, b)
}
func (m *Payload) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_Payload.Marshal(b, m, deterministic)
}
func (dst *Payload) XXX_Merge(src proto.Message) {
xxx_messageInfo_Payload.Merge(dst, src)
}
func (m *Payload) XXX_Size() int {
return xxx_messageInfo_Payload.Size(m)
}
func (m *Payload) XXX_DiscardUnknown() {
xxx_messageInfo_Payload.DiscardUnknown(m)
}
var xxx_messageInfo_Payload proto.InternalMessageInfo
func (m *Payload) GetType() PayloadType {
if m != nil {
return m.Type
}
return PayloadType_COMPRESSABLE
}
func (m *Payload) GetBody() []byte {
if m != nil {
return m.Body
}
return nil
}
// A protobuf representation for grpc status. This is used by test
// clients to specify a status that the server should attempt to return.
type EchoStatus struct {
Code int32 `protobuf:"varint,1,opt,name=code,proto3" json:"code,omitempty"`
Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *EchoStatus) Reset() { *m = EchoStatus{} }
func (m *EchoStatus) String() string { return proto.CompactTextString(m) }
func (*EchoStatus) ProtoMessage() {}
func (*EchoStatus) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{1}
}
func (m *EchoStatus) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_EchoStatus.Unmarshal(m, b)
}
func (m *EchoStatus) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_EchoStatus.Marshal(b, m, deterministic)
}
func (dst *EchoStatus) XXX_Merge(src proto.Message) {
xxx_messageInfo_EchoStatus.Merge(dst, src)
}
func (m *EchoStatus) XXX_Size() int {
return xxx_messageInfo_EchoStatus.Size(m)
}
func (m *EchoStatus) XXX_DiscardUnknown() {
xxx_messageInfo_EchoStatus.DiscardUnknown(m)
}
var xxx_messageInfo_EchoStatus proto.InternalMessageInfo
func (m *EchoStatus) GetCode() int32 {
if m != nil {
return m.Code
}
return 0
}
func (m *EchoStatus) GetMessage() string {
if m != nil {
return m.Message
}
return ""
}
// Unary request.
type SimpleRequest struct {
// Desired payload type in the response from the server.
// If response_type is RANDOM, server randomly chooses one from other formats.
ResponseType PayloadType `protobuf:"varint,1,opt,name=response_type,json=responseType,proto3,enum=grpc.testing.PayloadType" json:"response_type,omitempty"`
// Desired payload size in the response from the server.
// If response_type is COMPRESSABLE, this denotes the size before compression.
ResponseSize int32 `protobuf:"varint,2,opt,name=response_size,json=responseSize,proto3" json:"response_size,omitempty"`
// Optional input payload sent along with the request.
Payload *Payload `protobuf:"bytes,3,opt,name=payload,proto3" json:"payload,omitempty"`
// Whether SimpleResponse should include username.
FillUsername bool `protobuf:"varint,4,opt,name=fill_username,json=fillUsername,proto3" json:"fill_username,omitempty"`
// Whether SimpleResponse should include OAuth scope.
FillOauthScope bool `protobuf:"varint,5,opt,name=fill_oauth_scope,json=fillOauthScope,proto3" json:"fill_oauth_scope,omitempty"`
// Compression algorithm to be used by the server for the response (stream)
ResponseCompression CompressionType `protobuf:"varint,6,opt,name=response_compression,json=responseCompression,proto3,enum=grpc.testing.CompressionType" json:"response_compression,omitempty"`
// Whether server should return a given status
ResponseStatus *EchoStatus `protobuf:"bytes,7,opt,name=response_status,json=responseStatus,proto3" json:"response_status,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *SimpleRequest) Reset() { *m = SimpleRequest{} }
func (m *SimpleRequest) String() string { return proto.CompactTextString(m) }
func (*SimpleRequest) ProtoMessage() {}
func (*SimpleRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{2}
}
func (m *SimpleRequest) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_SimpleRequest.Unmarshal(m, b)
}
func (m *SimpleRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_SimpleRequest.Marshal(b, m, deterministic)
}
func (dst *SimpleRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_SimpleRequest.Merge(dst, src)
}
func (m *SimpleRequest) XXX_Size() int {
return xxx_messageInfo_SimpleRequest.Size(m)
}
func (m *SimpleRequest) XXX_DiscardUnknown() {
xxx_messageInfo_SimpleRequest.DiscardUnknown(m)
}
var xxx_messageInfo_SimpleRequest proto.InternalMessageInfo
func (m *SimpleRequest) GetResponseType() PayloadType {
if m != nil {
return m.ResponseType
}
return PayloadType_COMPRESSABLE
}
func (m *SimpleRequest) GetResponseSize() int32 {
if m != nil {
return m.ResponseSize
}
return 0
}
func (m *SimpleRequest) GetPayload() *Payload {
if m != nil {
return m.Payload
}
return nil
}
func (m *SimpleRequest) GetFillUsername() bool {
if m != nil {
return m.FillUsername
}
return false
}
func (m *SimpleRequest) GetFillOauthScope() bool {
if m != nil {
return m.FillOauthScope
}
return false
}
func (m *SimpleRequest) GetResponseCompression() CompressionType {
if m != nil {
return m.ResponseCompression
}
return CompressionType_NONE
}
func (m *SimpleRequest) GetResponseStatus() *EchoStatus {
if m != nil {
return m.ResponseStatus
}
return nil
}
// Unary response, as configured by the request.
type SimpleResponse struct {
// Payload to increase message size.
Payload *Payload `protobuf:"bytes,1,opt,name=payload,proto3" json:"payload,omitempty"`
// The user the request came from, for verifying authentication was
// successful when the client expected it.
Username string `protobuf:"bytes,2,opt,name=username,proto3" json:"username,omitempty"`
// OAuth scope.
OauthScope string `protobuf:"bytes,3,opt,name=oauth_scope,json=oauthScope,proto3" json:"oauth_scope,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *SimpleResponse) Reset() { *m = SimpleResponse{} }
func (m *SimpleResponse) String() string { return proto.CompactTextString(m) }
func (*SimpleResponse) ProtoMessage() {}
func (*SimpleResponse) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{3}
}
func (m *SimpleResponse) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_SimpleResponse.Unmarshal(m, b)
}
func (m *SimpleResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_SimpleResponse.Marshal(b, m, deterministic)
}
func (dst *SimpleResponse) XXX_Merge(src proto.Message) {
xxx_messageInfo_SimpleResponse.Merge(dst, src)
}
func (m *SimpleResponse) XXX_Size() int {
return xxx_messageInfo_SimpleResponse.Size(m)
}
func (m *SimpleResponse) XXX_DiscardUnknown() {
xxx_messageInfo_SimpleResponse.DiscardUnknown(m)
}
var xxx_messageInfo_SimpleResponse proto.InternalMessageInfo
func (m *SimpleResponse) GetPayload() *Payload {
if m != nil {
return m.Payload
}
return nil
}
func (m *SimpleResponse) GetUsername() string {
if m != nil {
return m.Username
}
return ""
}
func (m *SimpleResponse) GetOauthScope() string {
if m != nil {
return m.OauthScope
}
return ""
}
// Client-streaming request.
type StreamingInputCallRequest struct {
// Optional input payload sent along with the request.
Payload *Payload `protobuf:"bytes,1,opt,name=payload,proto3" json:"payload,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *StreamingInputCallRequest) Reset() { *m = StreamingInputCallRequest{} }
func (m *StreamingInputCallRequest) String() string { return proto.CompactTextString(m) }
func (*StreamingInputCallRequest) ProtoMessage() {}
func (*StreamingInputCallRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{4}
}
func (m *StreamingInputCallRequest) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_StreamingInputCallRequest.Unmarshal(m, b)
}
func (m *StreamingInputCallRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_StreamingInputCallRequest.Marshal(b, m, deterministic)
}
func (dst *StreamingInputCallRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_StreamingInputCallRequest.Merge(dst, src)
}
func (m *StreamingInputCallRequest) XXX_Size() int {
return xxx_messageInfo_StreamingInputCallRequest.Size(m)
}
func (m *StreamingInputCallRequest) XXX_DiscardUnknown() {
xxx_messageInfo_StreamingInputCallRequest.DiscardUnknown(m)
}
var xxx_messageInfo_StreamingInputCallRequest proto.InternalMessageInfo
func (m *StreamingInputCallRequest) GetPayload() *Payload {
if m != nil {
return m.Payload
}
return nil
}
// Client-streaming response.
type StreamingInputCallResponse struct {
// Aggregated size of payloads received from the client.
AggregatedPayloadSize int32 `protobuf:"varint,1,opt,name=aggregated_payload_size,json=aggregatedPayloadSize,proto3" json:"aggregated_payload_size,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *StreamingInputCallResponse) Reset() { *m = StreamingInputCallResponse{} }
func (m *StreamingInputCallResponse) String() string { return proto.CompactTextString(m) }
func (*StreamingInputCallResponse) ProtoMessage() {}
func (*StreamingInputCallResponse) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{5}
}
func (m *StreamingInputCallResponse) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_StreamingInputCallResponse.Unmarshal(m, b)
}
func (m *StreamingInputCallResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_StreamingInputCallResponse.Marshal(b, m, deterministic)
}
func (dst *StreamingInputCallResponse) XXX_Merge(src proto.Message) {
xxx_messageInfo_StreamingInputCallResponse.Merge(dst, src)
}
func (m *StreamingInputCallResponse) XXX_Size() int {
return xxx_messageInfo_StreamingInputCallResponse.Size(m)
}
func (m *StreamingInputCallResponse) XXX_DiscardUnknown() {
xxx_messageInfo_StreamingInputCallResponse.DiscardUnknown(m)
}
var xxx_messageInfo_StreamingInputCallResponse proto.InternalMessageInfo
func (m *StreamingInputCallResponse) GetAggregatedPayloadSize() int32 {
if m != nil {
return m.AggregatedPayloadSize
}
return 0
}
// Configuration for a particular response.
type ResponseParameters struct {
// Desired payload sizes in responses from the server.
// If response_type is COMPRESSABLE, this denotes the size before compression.
Size int32 `protobuf:"varint,1,opt,name=size,proto3" json:"size,omitempty"`
// Desired interval between consecutive responses in the response stream in
// microseconds.
IntervalUs int32 `protobuf:"varint,2,opt,name=interval_us,json=intervalUs,proto3" json:"interval_us,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ResponseParameters) Reset() { *m = ResponseParameters{} }
func (m *ResponseParameters) String() string { return proto.CompactTextString(m) }
func (*ResponseParameters) ProtoMessage() {}
func (*ResponseParameters) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{6}
}
func (m *ResponseParameters) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ResponseParameters.Unmarshal(m, b)
}
func (m *ResponseParameters) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ResponseParameters.Marshal(b, m, deterministic)
}
func (dst *ResponseParameters) XXX_Merge(src proto.Message) {
xxx_messageInfo_ResponseParameters.Merge(dst, src)
}
func (m *ResponseParameters) XXX_Size() int {
return xxx_messageInfo_ResponseParameters.Size(m)
}
func (m *ResponseParameters) XXX_DiscardUnknown() {
xxx_messageInfo_ResponseParameters.DiscardUnknown(m)
}
var xxx_messageInfo_ResponseParameters proto.InternalMessageInfo
func (m *ResponseParameters) GetSize() int32 {
if m != nil {
return m.Size
}
return 0
}
func (m *ResponseParameters) GetIntervalUs() int32 {
if m != nil {
return m.IntervalUs
}
return 0
}
// Server-streaming request.
type StreamingOutputCallRequest struct {
// Desired payload type in the response from the server.
// If response_type is RANDOM, the payload from each response in the stream
// might be of different types. This is to simulate a mixed type of payload
// stream.
ResponseType PayloadType `protobuf:"varint,1,opt,name=response_type,json=responseType,proto3,enum=grpc.testing.PayloadType" json:"response_type,omitempty"`
// Configuration for each expected response message.
ResponseParameters []*ResponseParameters `protobuf:"bytes,2,rep,name=response_parameters,json=responseParameters,proto3" json:"response_parameters,omitempty"`
// Optional input payload sent along with the request.
Payload *Payload `protobuf:"bytes,3,opt,name=payload,proto3" json:"payload,omitempty"`
// Compression algorithm to be used by the server for the response (stream)
ResponseCompression CompressionType `protobuf:"varint,6,opt,name=response_compression,json=responseCompression,proto3,enum=grpc.testing.CompressionType" json:"response_compression,omitempty"`
// Whether server should return a given status
ResponseStatus *EchoStatus `protobuf:"bytes,7,opt,name=response_status,json=responseStatus,proto3" json:"response_status,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *StreamingOutputCallRequest) Reset() { *m = StreamingOutputCallRequest{} }
func (m *StreamingOutputCallRequest) String() string { return proto.CompactTextString(m) }
func (*StreamingOutputCallRequest) ProtoMessage() {}
func (*StreamingOutputCallRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{7}
}
func (m *StreamingOutputCallRequest) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_StreamingOutputCallRequest.Unmarshal(m, b)
}
func (m *StreamingOutputCallRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_StreamingOutputCallRequest.Marshal(b, m, deterministic)
}
func (dst *StreamingOutputCallRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_StreamingOutputCallRequest.Merge(dst, src)
}
func (m *StreamingOutputCallRequest) XXX_Size() int {
return xxx_messageInfo_StreamingOutputCallRequest.Size(m)
}
func (m *StreamingOutputCallRequest) XXX_DiscardUnknown() {
xxx_messageInfo_StreamingOutputCallRequest.DiscardUnknown(m)
}
var xxx_messageInfo_StreamingOutputCallRequest proto.InternalMessageInfo
func (m *StreamingOutputCallRequest) GetResponseType() PayloadType {
if m != nil {
return m.ResponseType
}
return PayloadType_COMPRESSABLE
}
func (m *StreamingOutputCallRequest) GetResponseParameters() []*ResponseParameters {
if m != nil {
return m.ResponseParameters
}
return nil
}
func (m *StreamingOutputCallRequest) GetPayload() *Payload {
if m != nil {
return m.Payload
}
return nil
}
func (m *StreamingOutputCallRequest) GetResponseCompression() CompressionType {
if m != nil {
return m.ResponseCompression
}
return CompressionType_NONE
}
func (m *StreamingOutputCallRequest) GetResponseStatus() *EchoStatus {
if m != nil {
return m.ResponseStatus
}
return nil
}
// Server-streaming response, as configured by the request and parameters.
type StreamingOutputCallResponse struct {
// Payload to increase response size.
Payload *Payload `protobuf:"bytes,1,opt,name=payload,proto3" json:"payload,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *StreamingOutputCallResponse) Reset() { *m = StreamingOutputCallResponse{} }
func (m *StreamingOutputCallResponse) String() string { return proto.CompactTextString(m) }
func (*StreamingOutputCallResponse) ProtoMessage() {}
func (*StreamingOutputCallResponse) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{8}
}
func (m *StreamingOutputCallResponse) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_StreamingOutputCallResponse.Unmarshal(m, b)
}
func (m *StreamingOutputCallResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_StreamingOutputCallResponse.Marshal(b, m, deterministic)
}
func (dst *StreamingOutputCallResponse) XXX_Merge(src proto.Message) {
xxx_messageInfo_StreamingOutputCallResponse.Merge(dst, src)
}
func (m *StreamingOutputCallResponse) XXX_Size() int {
return xxx_messageInfo_StreamingOutputCallResponse.Size(m)
}
func (m *StreamingOutputCallResponse) XXX_DiscardUnknown() {
xxx_messageInfo_StreamingOutputCallResponse.DiscardUnknown(m)
}
var xxx_messageInfo_StreamingOutputCallResponse proto.InternalMessageInfo
func (m *StreamingOutputCallResponse) GetPayload() *Payload {
if m != nil {
return m.Payload
}
return nil
}
// For reconnect interop test only.
// Client tells server what reconnection parameters it used.
type ReconnectParams struct {
MaxReconnectBackoffMs int32 `protobuf:"varint,1,opt,name=max_reconnect_backoff_ms,json=maxReconnectBackoffMs,proto3" json:"max_reconnect_backoff_ms,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ReconnectParams) Reset() { *m = ReconnectParams{} }
func (m *ReconnectParams) String() string { return proto.CompactTextString(m) }
func (*ReconnectParams) ProtoMessage() {}
func (*ReconnectParams) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{9}
}
func (m *ReconnectParams) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ReconnectParams.Unmarshal(m, b)
}
func (m *ReconnectParams) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ReconnectParams.Marshal(b, m, deterministic)
}
func (dst *ReconnectParams) XXX_Merge(src proto.Message) {
xxx_messageInfo_ReconnectParams.Merge(dst, src)
}
func (m *ReconnectParams) XXX_Size() int {
return xxx_messageInfo_ReconnectParams.Size(m)
}
func (m *ReconnectParams) XXX_DiscardUnknown() {
xxx_messageInfo_ReconnectParams.DiscardUnknown(m)
}
var xxx_messageInfo_ReconnectParams proto.InternalMessageInfo
func (m *ReconnectParams) GetMaxReconnectBackoffMs() int32 {
if m != nil {
return m.MaxReconnectBackoffMs
}
return 0
}
// For reconnect interop test only.
// Server tells client whether its reconnects are following the spec and the
// reconnect backoffs it saw.
type ReconnectInfo struct {
Passed bool `protobuf:"varint,1,opt,name=passed,proto3" json:"passed,omitempty"`
BackoffMs []int32 `protobuf:"varint,2,rep,packed,name=backoff_ms,json=backoffMs,proto3" json:"backoff_ms,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ReconnectInfo) Reset() { *m = ReconnectInfo{} }
func (m *ReconnectInfo) String() string { return proto.CompactTextString(m) }
func (*ReconnectInfo) ProtoMessage() {}
func (*ReconnectInfo) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{10}
}
func (m *ReconnectInfo) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ReconnectInfo.Unmarshal(m, b)
}
func (m *ReconnectInfo) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ReconnectInfo.Marshal(b, m, deterministic)
}
func (dst *ReconnectInfo) XXX_Merge(src proto.Message) {
xxx_messageInfo_ReconnectInfo.Merge(dst, src)
}
func (m *ReconnectInfo) XXX_Size() int {
return xxx_messageInfo_ReconnectInfo.Size(m)
}
func (m *ReconnectInfo) XXX_DiscardUnknown() {
xxx_messageInfo_ReconnectInfo.DiscardUnknown(m)
}
var xxx_messageInfo_ReconnectInfo proto.InternalMessageInfo
func (m *ReconnectInfo) GetPassed() bool {
if m != nil {
return m.Passed
}
return false
}
func (m *ReconnectInfo) GetBackoffMs() []int32 {
if m != nil {
return m.BackoffMs
}
return nil
}
func init() {
proto.RegisterType((*Payload)(nil), "grpc.testing.Payload")
proto.RegisterType((*EchoStatus)(nil), "grpc.testing.EchoStatus")
proto.RegisterType((*SimpleRequest)(nil), "grpc.testing.SimpleRequest")
proto.RegisterType((*SimpleResponse)(nil), "grpc.testing.SimpleResponse")
proto.RegisterType((*StreamingInputCallRequest)(nil), "grpc.testing.StreamingInputCallRequest")
proto.RegisterType((*StreamingInputCallResponse)(nil), "grpc.testing.StreamingInputCallResponse")
proto.RegisterType((*ResponseParameters)(nil), "grpc.testing.ResponseParameters")
proto.RegisterType((*StreamingOutputCallRequest)(nil), "grpc.testing.StreamingOutputCallRequest")
proto.RegisterType((*StreamingOutputCallResponse)(nil), "grpc.testing.StreamingOutputCallResponse")
proto.RegisterType((*ReconnectParams)(nil), "grpc.testing.ReconnectParams")
proto.RegisterType((*ReconnectInfo)(nil), "grpc.testing.ReconnectInfo")
proto.RegisterEnum("grpc.testing.PayloadType", PayloadType_name, PayloadType_value)
proto.RegisterEnum("grpc.testing.CompressionType", CompressionType_name, CompressionType_value)
}
func init() { proto.RegisterFile("messages.proto", fileDescriptor_messages_5c70222ad96bf232) }
var fileDescriptor_messages_5c70222ad96bf232 = []byte{
// 652 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xcc, 0x55, 0x4d, 0x6f, 0xd3, 0x40,
0x10, 0xc5, 0xf9, 0xee, 0x24, 0x4d, 0xa3, 0x85, 0x82, 0x5b, 0x54, 0x11, 0x99, 0x4b, 0x54, 0x89,
0x20, 0x05, 0x09, 0x24, 0x0e, 0xa0, 0xb4, 0x4d, 0x51, 0x50, 0x9a, 0x84, 0x75, 0x7b, 0xe1, 0x62,
0x6d, 0x9c, 0x8d, 0x6b, 0x11, 0x7b, 0x8d, 0x77, 0x8d, 0x9a, 0x1e, 0xb8, 0xf3, 0x83, 0xb9, 0xa3,
0x5d, 0x7f, 0xc4, 0x69, 0x7b, 0x68, 0xe1, 0xc2, 0x6d, 0xf7, 0xed, 0x9b, 0x97, 0x79, 0x33, 0xcf,
0x0a, 0x34, 0x3d, 0xca, 0x39, 0x71, 0x28, 0xef, 0x06, 0x21, 0x13, 0x0c, 0x35, 0x9c, 0x30, 0xb0,
0xbb, 0x82, 0x72, 0xe1, 0xfa, 0x8e, 0x31, 0x82, 0xea, 0x94, 0xac, 0x96, 0x8c, 0xcc, 0xd1, 0x2b,
0x28, 0x89, 0x55, 0x40, 0x75, 0xad, 0xad, 0x75, 0x9a, 0xbd, 0xbd, 0x6e, 0x9e, 0xd7, 0x4d, 0x48,
0xe7, 0xab, 0x80, 0x62, 0x45, 0x43, 0x08, 0x4a, 0x33, 0x36, 0x5f, 0xe9, 0x85, 0xb6, 0xd6, 0x69,
0x60, 0x75, 0x36, 0xde, 0x03, 0x0c, 0xec, 0x4b, 0x66, 0x0a, 0x22, 0x22, 0x2e, 0x19, 0x36, 0x9b,
0xc7, 0x82, 0x65, 0xac, 0xce, 0x48, 0x87, 0x6a, 0xd2, 0x8f, 0x2a, 0xdc, 0xc2, 0xe9, 0xd5, 0xf8,
0x55, 0x84, 0x6d, 0xd3, 0xf5, 0x82, 0x25, 0xc5, 0xf4, 0x7b, 0x44, 0xb9, 0x40, 0x1f, 0x60, 0x3b,
0xa4, 0x3c, 0x60, 0x3e, 0xa7, 0xd6, 0xfd, 0x3a, 0x6b, 0xa4, 0x7c, 0x79, 0x43, 0x2f, 0x73, 0xf5,
0xdc, 0xbd, 0x8e, 0x7f, 0xb1, 0xbc, 0x26, 0x99, 0xee, 0x35, 0x45, 0xaf, 0xa1, 0x1a, 0xc4, 0x0a,
0x7a, 0xb1, 0xad, 0x75, 0xea, 0xbd, 0xdd, 0x3b, 0xe5, 0x71, 0xca, 0x92, 0xaa, 0x0b, 0x77, 0xb9,
0xb4, 0x22, 0x4e, 0x43, 0x9f, 0x78, 0x54, 0x2f, 0xb5, 0xb5, 0x4e, 0x0d, 0x37, 0x24, 0x78, 0x91,
0x60, 0xa8, 0x03, 0x2d, 0x45, 0x62, 0x24, 0x12, 0x97, 0x16, 0xb7, 0x59, 0x40, 0xf5, 0xb2, 0xe2,
0x35, 0x25, 0x3e, 0x91, 0xb0, 0x29, 0x51, 0x34, 0x85, 0x27, 0x59, 0x93, 0x36, 0xf3, 0x82, 0x90,
0x72, 0xee, 0x32, 0x5f, 0xaf, 0x28, 0xaf, 0x07, 0x9b, 0xcd, 0x1c, 0xaf, 0x09, 0xca, 0xef, 0xe3,
0xb4, 0x34, 0xf7, 0x80, 0xfa, 0xb0, 0xb3, 0xb6, 0xad, 0x36, 0xa1, 0x57, 0x95, 0x33, 0x7d, 0x53,
0x6c, 0xbd, 0x29, 0xdc, 0xcc, 0x46, 0xa2, 0xee, 0xc6, 0x4f, 0x68, 0xa6, 0xab, 0x88, 0xf1, 0xfc,
0x98, 0xb4, 0x7b, 0x8d, 0x69, 0x1f, 0x6a, 0xd9, 0x84, 0xe2, 0x4d, 0x67, 0x77, 0xf4, 0x02, 0xea,
0xf9, 0xc1, 0x14, 0xd5, 0x33, 0xb0, 0x6c, 0x28, 0xc6, 0x08, 0xf6, 0x4c, 0x11, 0x52, 0xe2, 0xb9,
0xbe, 0x33, 0xf4, 0x83, 0x48, 0x1c, 0x93, 0xe5, 0x32, 0x8d, 0xc5, 0x43, 0x5b, 0x31, 0xce, 0x61,
0xff, 0x2e, 0xb5, 0xc4, 0xd9, 0x5b, 0x78, 0x46, 0x1c, 0x27, 0xa4, 0x0e, 0x11, 0x74, 0x6e, 0x25,
0x35, 0x71, 0x5e, 0xe2, 0xe0, 0xee, 0xae, 0x9f, 0x13, 0x69, 0x19, 0x1c, 0x63, 0x08, 0x28, 0xd5,
0x98, 0x92, 0x90, 0x78, 0x54, 0xd0, 0x50, 0x65, 0x3e, 0x57, 0xaa, 0xce, 0xd2, 0xae, 0xeb, 0x0b,
0x1a, 0xfe, 0x20, 0x32, 0x35, 0x49, 0x0a, 0x21, 0x85, 0x2e, 0xb8, 0xf1, 0xbb, 0x90, 0xeb, 0x70,
0x12, 0x89, 0x1b, 0x86, 0xff, 0xf5, 0x3b, 0xf8, 0x02, 0x59, 0x4e, 0xac, 0x20, 0x6b, 0x55, 0x2f,
0xb4, 0x8b, 0x9d, 0x7a, 0xaf, 0xbd, 0xa9, 0x72, 0xdb, 0x12, 0x46, 0xe1, 0x6d, 0x9b, 0x0f, 0xfe,
0x6a, 0xfe, 0xcb, 0x98, 0x8f, 0xe1, 0xf9, 0x9d, 0x63, 0xff, 0xcb, 0xcc, 0x1b, 0x9f, 0x61, 0x07,
0x53, 0x9b, 0xf9, 0x3e, 0xb5, 0x85, 0x1a, 0x16, 0x47, 0xef, 0x40, 0xf7, 0xc8, 0x95, 0x15, 0xa6,
0xb0, 0x35, 0x23, 0xf6, 0x37, 0xb6, 0x58, 0x58, 0x1e, 0x4f, 0xe3, 0xe5, 0x91, 0xab, 0xac, 0xea,
0x28, 0x7e, 0x3d, 0xe3, 0xc6, 0x29, 0x6c, 0x67, 0xe8, 0xd0, 0x5f, 0x30, 0xf4, 0x14, 0x2a, 0x01,
0xe1, 0x9c, 0xc6, 0xcd, 0xd4, 0x70, 0x72, 0x43, 0x07, 0x00, 0x39, 0x4d, 0xb9, 0xd4, 0x32, 0xde,
0x9a, 0xa5, 0x3a, 0x87, 0x1f, 0xa1, 0x9e, 0x4b, 0x06, 0x6a, 0x41, 0xe3, 0x78, 0x72, 0x36, 0xc5,
0x03, 0xd3, 0xec, 0x1f, 0x8d, 0x06, 0xad, 0x47, 0x08, 0x41, 0xf3, 0x62, 0xbc, 0x81, 0x69, 0x08,
0xa0, 0x82, 0xfb, 0xe3, 0x93, 0xc9, 0x59, 0xab, 0x70, 0xd8, 0x83, 0x9d, 0x1b, 0xfb, 0x40, 0x35,
0x28, 0x8d, 0x27, 0x63, 0x59, 0x5c, 0x83, 0xd2, 0xa7, 0xaf, 0xc3, 0x69, 0x4b, 0x43, 0x75, 0xa8,
0x9e, 0x0c, 0x4e, 0x47, 0xfd, 0xf3, 0x41, 0xab, 0x30, 0xab, 0xa8, 0xbf, 0x9a, 0x37, 0x7f, 0x02,
0x00, 0x00, 0xff, 0xff, 0xc2, 0x6a, 0xce, 0x1e, 0x7c, 0x06, 0x00, 0x00,
}

View File

@ -1,157 +0,0 @@
// Copyright 2016 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Message definitions to be used by integration test service definitions.
syntax = "proto3";
package grpc.testing;
// The type of payload that should be returned.
enum PayloadType {
// Compressable text format.
COMPRESSABLE = 0;
// Uncompressable binary format.
UNCOMPRESSABLE = 1;
// Randomly chosen from all other formats defined in this enum.
RANDOM = 2;
}
// Compression algorithms
enum CompressionType {
// No compression
NONE = 0;
GZIP = 1;
DEFLATE = 2;
}
// A block of data, to simply increase gRPC message size.
message Payload {
// The type of data in body.
PayloadType type = 1;
// Primary contents of payload.
bytes body = 2;
}
// A protobuf representation for grpc status. This is used by test
// clients to specify a status that the server should attempt to return.
message EchoStatus {
int32 code = 1;
string message = 2;
}
// Unary request.
message SimpleRequest {
// Desired payload type in the response from the server.
// If response_type is RANDOM, server randomly chooses one from other formats.
PayloadType response_type = 1;
// Desired payload size in the response from the server.
// If response_type is COMPRESSABLE, this denotes the size before compression.
int32 response_size = 2;
// Optional input payload sent along with the request.
Payload payload = 3;
// Whether SimpleResponse should include username.
bool fill_username = 4;
// Whether SimpleResponse should include OAuth scope.
bool fill_oauth_scope = 5;
// Compression algorithm to be used by the server for the response (stream)
CompressionType response_compression = 6;
// Whether server should return a given status
EchoStatus response_status = 7;
}
// Unary response, as configured by the request.
message SimpleResponse {
// Payload to increase message size.
Payload payload = 1;
// The user the request came from, for verifying authentication was
// successful when the client expected it.
string username = 2;
// OAuth scope.
string oauth_scope = 3;
}
// Client-streaming request.
message StreamingInputCallRequest {
// Optional input payload sent along with the request.
Payload payload = 1;
// Not expecting any payload from the response.
}
// Client-streaming response.
message StreamingInputCallResponse {
// Aggregated size of payloads received from the client.
int32 aggregated_payload_size = 1;
}
// Configuration for a particular response.
message ResponseParameters {
// Desired payload sizes in responses from the server.
// If response_type is COMPRESSABLE, this denotes the size before compression.
int32 size = 1;
// Desired interval between consecutive responses in the response stream in
// microseconds.
int32 interval_us = 2;
}
// Server-streaming request.
message StreamingOutputCallRequest {
// Desired payload type in the response from the server.
// If response_type is RANDOM, the payload from each response in the stream
// might be of different types. This is to simulate a mixed type of payload
// stream.
PayloadType response_type = 1;
// Configuration for each expected response message.
repeated ResponseParameters response_parameters = 2;
// Optional input payload sent along with the request.
Payload payload = 3;
// Compression algorithm to be used by the server for the response (stream)
CompressionType response_compression = 6;
// Whether server should return a given status
EchoStatus response_status = 7;
}
// Server-streaming response, as configured by the request and parameters.
message StreamingOutputCallResponse {
// Payload to increase response size.
Payload payload = 1;
}
// For reconnect interop test only.
// Client tells server what reconnection parameters it used.
message ReconnectParams {
int32 max_reconnect_backoff_ms = 1;
}
// For reconnect interop test only.
// Server tells client whether its reconnects are following the spec and the
// reconnect backoffs it saw.
message ReconnectInfo {
bool passed = 1;
repeated int32 backoff_ms = 2;
}
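Of the messages above, the server-streaming request is the only one whose construction is not already illustrated elsewhere in this diff. A minimal sketch using the generated types from messages.pb.go (shown earlier), with placeholder sizes:

package main

import (
    "fmt"

    testpb "google.golang.org/grpc/benchmark/grpc_testing"
)

// Sketch: ask the server for two responses of 1 KiB each, 1 ms apart.
func main() {
    req := &testpb.StreamingOutputCallRequest{
        ResponseType: testpb.PayloadType_COMPRESSABLE,
        ResponseParameters: []*testpb.ResponseParameters{
            {Size: 1024, IntervalUs: 1000},
            {Size: 1024, IntervalUs: 1000},
        },
        Payload: &testpb.Payload{
            Type: testpb.PayloadType_COMPRESSABLE,
            Body: make([]byte, 1),
        },
    }
    fmt.Println(len(req.GetResponseParameters()), "response parameters")
}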

View File

@ -1,348 +0,0 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: payloads.proto
package grpc_testing
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type ByteBufferParams struct {
ReqSize int32 `protobuf:"varint,1,opt,name=req_size,json=reqSize,proto3" json:"req_size,omitempty"`
RespSize int32 `protobuf:"varint,2,opt,name=resp_size,json=respSize,proto3" json:"resp_size,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ByteBufferParams) Reset() { *m = ByteBufferParams{} }
func (m *ByteBufferParams) String() string { return proto.CompactTextString(m) }
func (*ByteBufferParams) ProtoMessage() {}
func (*ByteBufferParams) Descriptor() ([]byte, []int) {
return fileDescriptor_payloads_3abc71de35f06c83, []int{0}
}
func (m *ByteBufferParams) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ByteBufferParams.Unmarshal(m, b)
}
func (m *ByteBufferParams) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ByteBufferParams.Marshal(b, m, deterministic)
}
func (dst *ByteBufferParams) XXX_Merge(src proto.Message) {
xxx_messageInfo_ByteBufferParams.Merge(dst, src)
}
func (m *ByteBufferParams) XXX_Size() int {
return xxx_messageInfo_ByteBufferParams.Size(m)
}
func (m *ByteBufferParams) XXX_DiscardUnknown() {
xxx_messageInfo_ByteBufferParams.DiscardUnknown(m)
}
var xxx_messageInfo_ByteBufferParams proto.InternalMessageInfo
func (m *ByteBufferParams) GetReqSize() int32 {
if m != nil {
return m.ReqSize
}
return 0
}
func (m *ByteBufferParams) GetRespSize() int32 {
if m != nil {
return m.RespSize
}
return 0
}
type SimpleProtoParams struct {
ReqSize int32 `protobuf:"varint,1,opt,name=req_size,json=reqSize,proto3" json:"req_size,omitempty"`
RespSize int32 `protobuf:"varint,2,opt,name=resp_size,json=respSize,proto3" json:"resp_size,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *SimpleProtoParams) Reset() { *m = SimpleProtoParams{} }
func (m *SimpleProtoParams) String() string { return proto.CompactTextString(m) }
func (*SimpleProtoParams) ProtoMessage() {}
func (*SimpleProtoParams) Descriptor() ([]byte, []int) {
return fileDescriptor_payloads_3abc71de35f06c83, []int{1}
}
func (m *SimpleProtoParams) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_SimpleProtoParams.Unmarshal(m, b)
}
func (m *SimpleProtoParams) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_SimpleProtoParams.Marshal(b, m, deterministic)
}
func (dst *SimpleProtoParams) XXX_Merge(src proto.Message) {
xxx_messageInfo_SimpleProtoParams.Merge(dst, src)
}
func (m *SimpleProtoParams) XXX_Size() int {
return xxx_messageInfo_SimpleProtoParams.Size(m)
}
func (m *SimpleProtoParams) XXX_DiscardUnknown() {
xxx_messageInfo_SimpleProtoParams.DiscardUnknown(m)
}
var xxx_messageInfo_SimpleProtoParams proto.InternalMessageInfo
func (m *SimpleProtoParams) GetReqSize() int32 {
if m != nil {
return m.ReqSize
}
return 0
}
func (m *SimpleProtoParams) GetRespSize() int32 {
if m != nil {
return m.RespSize
}
return 0
}
type ComplexProtoParams struct {
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ComplexProtoParams) Reset() { *m = ComplexProtoParams{} }
func (m *ComplexProtoParams) String() string { return proto.CompactTextString(m) }
func (*ComplexProtoParams) ProtoMessage() {}
func (*ComplexProtoParams) Descriptor() ([]byte, []int) {
return fileDescriptor_payloads_3abc71de35f06c83, []int{2}
}
func (m *ComplexProtoParams) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ComplexProtoParams.Unmarshal(m, b)
}
func (m *ComplexProtoParams) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ComplexProtoParams.Marshal(b, m, deterministic)
}
func (dst *ComplexProtoParams) XXX_Merge(src proto.Message) {
xxx_messageInfo_ComplexProtoParams.Merge(dst, src)
}
func (m *ComplexProtoParams) XXX_Size() int {
return xxx_messageInfo_ComplexProtoParams.Size(m)
}
func (m *ComplexProtoParams) XXX_DiscardUnknown() {
xxx_messageInfo_ComplexProtoParams.DiscardUnknown(m)
}
var xxx_messageInfo_ComplexProtoParams proto.InternalMessageInfo
type PayloadConfig struct {
// Types that are valid to be assigned to Payload:
// *PayloadConfig_BytebufParams
// *PayloadConfig_SimpleParams
// *PayloadConfig_ComplexParams
Payload isPayloadConfig_Payload `protobuf_oneof:"payload"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *PayloadConfig) Reset() { *m = PayloadConfig{} }
func (m *PayloadConfig) String() string { return proto.CompactTextString(m) }
func (*PayloadConfig) ProtoMessage() {}
func (*PayloadConfig) Descriptor() ([]byte, []int) {
return fileDescriptor_payloads_3abc71de35f06c83, []int{3}
}
func (m *PayloadConfig) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_PayloadConfig.Unmarshal(m, b)
}
func (m *PayloadConfig) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_PayloadConfig.Marshal(b, m, deterministic)
}
func (dst *PayloadConfig) XXX_Merge(src proto.Message) {
xxx_messageInfo_PayloadConfig.Merge(dst, src)
}
func (m *PayloadConfig) XXX_Size() int {
return xxx_messageInfo_PayloadConfig.Size(m)
}
func (m *PayloadConfig) XXX_DiscardUnknown() {
xxx_messageInfo_PayloadConfig.DiscardUnknown(m)
}
var xxx_messageInfo_PayloadConfig proto.InternalMessageInfo
type isPayloadConfig_Payload interface {
isPayloadConfig_Payload()
}
type PayloadConfig_BytebufParams struct {
BytebufParams *ByteBufferParams `protobuf:"bytes,1,opt,name=bytebuf_params,json=bytebufParams,proto3,oneof"`
}
type PayloadConfig_SimpleParams struct {
SimpleParams *SimpleProtoParams `protobuf:"bytes,2,opt,name=simple_params,json=simpleParams,proto3,oneof"`
}
type PayloadConfig_ComplexParams struct {
ComplexParams *ComplexProtoParams `protobuf:"bytes,3,opt,name=complex_params,json=complexParams,proto3,oneof"`
}
func (*PayloadConfig_BytebufParams) isPayloadConfig_Payload() {}
func (*PayloadConfig_SimpleParams) isPayloadConfig_Payload() {}
func (*PayloadConfig_ComplexParams) isPayloadConfig_Payload() {}
func (m *PayloadConfig) GetPayload() isPayloadConfig_Payload {
if m != nil {
return m.Payload
}
return nil
}
func (m *PayloadConfig) GetBytebufParams() *ByteBufferParams {
if x, ok := m.GetPayload().(*PayloadConfig_BytebufParams); ok {
return x.BytebufParams
}
return nil
}
func (m *PayloadConfig) GetSimpleParams() *SimpleProtoParams {
if x, ok := m.GetPayload().(*PayloadConfig_SimpleParams); ok {
return x.SimpleParams
}
return nil
}
func (m *PayloadConfig) GetComplexParams() *ComplexProtoParams {
if x, ok := m.GetPayload().(*PayloadConfig_ComplexParams); ok {
return x.ComplexParams
}
return nil
}
// XXX_OneofFuncs is for the internal use of the proto package.
func (*PayloadConfig) XXX_OneofFuncs() (func(msg proto.Message, b *proto.Buffer) error, func(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error), func(msg proto.Message) (n int), []interface{}) {
return _PayloadConfig_OneofMarshaler, _PayloadConfig_OneofUnmarshaler, _PayloadConfig_OneofSizer, []interface{}{
(*PayloadConfig_BytebufParams)(nil),
(*PayloadConfig_SimpleParams)(nil),
(*PayloadConfig_ComplexParams)(nil),
}
}
func _PayloadConfig_OneofMarshaler(msg proto.Message, b *proto.Buffer) error {
m := msg.(*PayloadConfig)
// payload
switch x := m.Payload.(type) {
case *PayloadConfig_BytebufParams:
b.EncodeVarint(1<<3 | proto.WireBytes)
if err := b.EncodeMessage(x.BytebufParams); err != nil {
return err
}
case *PayloadConfig_SimpleParams:
b.EncodeVarint(2<<3 | proto.WireBytes)
if err := b.EncodeMessage(x.SimpleParams); err != nil {
return err
}
case *PayloadConfig_ComplexParams:
b.EncodeVarint(3<<3 | proto.WireBytes)
if err := b.EncodeMessage(x.ComplexParams); err != nil {
return err
}
case nil:
default:
return fmt.Errorf("PayloadConfig.Payload has unexpected type %T", x)
}
return nil
}
func _PayloadConfig_OneofUnmarshaler(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error) {
m := msg.(*PayloadConfig)
switch tag {
case 1: // payload.bytebuf_params
if wire != proto.WireBytes {
return true, proto.ErrInternalBadWireType
}
msg := new(ByteBufferParams)
err := b.DecodeMessage(msg)
m.Payload = &PayloadConfig_BytebufParams{msg}
return true, err
case 2: // payload.simple_params
if wire != proto.WireBytes {
return true, proto.ErrInternalBadWireType
}
msg := new(SimpleProtoParams)
err := b.DecodeMessage(msg)
m.Payload = &PayloadConfig_SimpleParams{msg}
return true, err
case 3: // payload.complex_params
if wire != proto.WireBytes {
return true, proto.ErrInternalBadWireType
}
msg := new(ComplexProtoParams)
err := b.DecodeMessage(msg)
m.Payload = &PayloadConfig_ComplexParams{msg}
return true, err
default:
return false, nil
}
}
func _PayloadConfig_OneofSizer(msg proto.Message) (n int) {
m := msg.(*PayloadConfig)
// payload
switch x := m.Payload.(type) {
case *PayloadConfig_BytebufParams:
s := proto.Size(x.BytebufParams)
n += 1 // tag and wire
n += proto.SizeVarint(uint64(s))
n += s
case *PayloadConfig_SimpleParams:
s := proto.Size(x.SimpleParams)
n += 1 // tag and wire
n += proto.SizeVarint(uint64(s))
n += s
case *PayloadConfig_ComplexParams:
s := proto.Size(x.ComplexParams)
n += 1 // tag and wire
n += proto.SizeVarint(uint64(s))
n += s
case nil:
default:
panic(fmt.Sprintf("proto: unexpected type %T in oneof", x))
}
return n
}
func init() {
proto.RegisterType((*ByteBufferParams)(nil), "grpc.testing.ByteBufferParams")
proto.RegisterType((*SimpleProtoParams)(nil), "grpc.testing.SimpleProtoParams")
proto.RegisterType((*ComplexProtoParams)(nil), "grpc.testing.ComplexProtoParams")
proto.RegisterType((*PayloadConfig)(nil), "grpc.testing.PayloadConfig")
}
func init() { proto.RegisterFile("payloads.proto", fileDescriptor_payloads_3abc71de35f06c83) }
var fileDescriptor_payloads_3abc71de35f06c83 = []byte{
// 254 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x2b, 0x48, 0xac, 0xcc,
0xc9, 0x4f, 0x4c, 0x29, 0xd6, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x49, 0x2f, 0x2a, 0x48,
0xd6, 0x2b, 0x49, 0x2d, 0x2e, 0xc9, 0xcc, 0x4b, 0x57, 0xf2, 0xe2, 0x12, 0x70, 0xaa, 0x2c, 0x49,
0x75, 0x2a, 0x4d, 0x4b, 0x4b, 0x2d, 0x0a, 0x48, 0x2c, 0x4a, 0xcc, 0x2d, 0x16, 0x92, 0xe4, 0xe2,
0x28, 0x4a, 0x2d, 0x8c, 0x2f, 0xce, 0xac, 0x4a, 0x95, 0x60, 0x54, 0x60, 0xd4, 0x60, 0x0d, 0x62,
0x2f, 0x4a, 0x2d, 0x0c, 0xce, 0xac, 0x4a, 0x15, 0x92, 0xe6, 0xe2, 0x2c, 0x4a, 0x2d, 0x2e, 0x80,
0xc8, 0x31, 0x81, 0xe5, 0x38, 0x40, 0x02, 0x20, 0x49, 0x25, 0x6f, 0x2e, 0xc1, 0xe0, 0xcc, 0xdc,
0x82, 0x9c, 0xd4, 0x00, 0x90, 0x45, 0x14, 0x1a, 0x26, 0xc2, 0x25, 0xe4, 0x9c, 0x0f, 0x32, 0xac,
0x02, 0xc9, 0x34, 0xa5, 0x6f, 0x8c, 0x5c, 0xbc, 0x01, 0x10, 0xff, 0x38, 0xe7, 0xe7, 0xa5, 0x65,
0xa6, 0x0b, 0xb9, 0x73, 0xf1, 0x25, 0x55, 0x96, 0xa4, 0x26, 0x95, 0xa6, 0xc5, 0x17, 0x80, 0xd5,
0x80, 0x6d, 0xe1, 0x36, 0x92, 0xd3, 0x43, 0xf6, 0xa7, 0x1e, 0xba, 0x27, 0x3d, 0x18, 0x82, 0x78,
0xa1, 0xfa, 0xa0, 0x0e, 0x75, 0xe3, 0xe2, 0x2d, 0x06, 0xbb, 0x1e, 0x66, 0x0e, 0x13, 0xd8, 0x1c,
0x79, 0x54, 0x73, 0x30, 0x3c, 0xe8, 0xc1, 0x10, 0xc4, 0x03, 0xd1, 0x07, 0x35, 0xc7, 0x93, 0x8b,
0x2f, 0x19, 0xe2, 0x70, 0x98, 0x41, 0xcc, 0x60, 0x83, 0x14, 0x50, 0x0d, 0xc2, 0xf4, 0x1c, 0xc8,
0x49, 0x50, 0x9d, 0x10, 0x01, 0x27, 0x4e, 0x2e, 0x76, 0x68, 0xe4, 0x25, 0xb1, 0x81, 0x23, 0xcf,
0x18, 0x10, 0x00, 0x00, 0xff, 0xff, 0xb0, 0x8c, 0x18, 0x4e, 0xce, 0x01, 0x00, 0x00,
}

View File

@ -1,40 +0,0 @@
// Copyright 2016 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package grpc.testing;
message ByteBufferParams {
int32 req_size = 1;
int32 resp_size = 2;
}
message SimpleProtoParams {
int32 req_size = 1;
int32 resp_size = 2;
}
message ComplexProtoParams {
// TODO (vpai): Fill this in once the details of complex, representative
// protos are decided
}
message PayloadConfig {
oneof payload {
ByteBufferParams bytebuf_params = 1;
SimpleProtoParams simple_params = 2;
ComplexProtoParams complex_params = 3;
}
}
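PayloadConfig is a oneof, so exactly one of the three parameter messages can be set at a time. A minimal sketch of selecting the simple_params branch with the wrapper types from payloads.pb.go above; the sizes are arbitrary examples:

package main

import (
    "fmt"

    testpb "google.golang.org/grpc/benchmark/grpc_testing"
)

// Sketch: populate the simple_params branch of the PayloadConfig oneof and
// read it back through the generated getters.
func main() {
    cfg := &testpb.PayloadConfig{
        Payload: &testpb.PayloadConfig_SimpleParams{
            SimpleParams: &testpb.SimpleProtoParams{ReqSize: 1, RespSize: 1024},
        },
    }
    if p := cfg.GetSimpleParams(); p != nil {
        fmt.Println("req:", p.GetReqSize(), "resp:", p.GetRespSize())
    }
}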

View File

@ -1,448 +0,0 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: services.proto
package grpc_testing
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
import (
context "golang.org/x/net/context"
grpc "google.golang.org/grpc"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
// Reference imports to suppress errors if they are not otherwise used.
var _ context.Context
var _ grpc.ClientConn
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
const _ = grpc.SupportPackageIsVersion4
// BenchmarkServiceClient is the client API for BenchmarkService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
type BenchmarkServiceClient interface {
// One request followed by one response.
// The server returns the client payload as-is.
UnaryCall(ctx context.Context, in *SimpleRequest, opts ...grpc.CallOption) (*SimpleResponse, error)
// One request followed by one response.
// The server returns the client payload as-is.
StreamingCall(ctx context.Context, opts ...grpc.CallOption) (BenchmarkService_StreamingCallClient, error)
}
type benchmarkServiceClient struct {
cc *grpc.ClientConn
}
func NewBenchmarkServiceClient(cc *grpc.ClientConn) BenchmarkServiceClient {
return &benchmarkServiceClient{cc}
}
func (c *benchmarkServiceClient) UnaryCall(ctx context.Context, in *SimpleRequest, opts ...grpc.CallOption) (*SimpleResponse, error) {
out := new(SimpleResponse)
err := c.cc.Invoke(ctx, "/grpc.testing.BenchmarkService/UnaryCall", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *benchmarkServiceClient) StreamingCall(ctx context.Context, opts ...grpc.CallOption) (BenchmarkService_StreamingCallClient, error) {
stream, err := c.cc.NewStream(ctx, &_BenchmarkService_serviceDesc.Streams[0], "/grpc.testing.BenchmarkService/StreamingCall", opts...)
if err != nil {
return nil, err
}
x := &benchmarkServiceStreamingCallClient{stream}
return x, nil
}
type BenchmarkService_StreamingCallClient interface {
Send(*SimpleRequest) error
Recv() (*SimpleResponse, error)
grpc.ClientStream
}
type benchmarkServiceStreamingCallClient struct {
grpc.ClientStream
}
func (x *benchmarkServiceStreamingCallClient) Send(m *SimpleRequest) error {
return x.ClientStream.SendMsg(m)
}
func (x *benchmarkServiceStreamingCallClient) Recv() (*SimpleResponse, error) {
m := new(SimpleResponse)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// BenchmarkServiceServer is the server API for BenchmarkService service.
type BenchmarkServiceServer interface {
// One request followed by one response.
// The server returns the client payload as-is.
UnaryCall(context.Context, *SimpleRequest) (*SimpleResponse, error)
// Repeated sequence of one request followed by one response.
// The server returns the client payload as-is on each response.
StreamingCall(BenchmarkService_StreamingCallServer) error
}
func RegisterBenchmarkServiceServer(s *grpc.Server, srv BenchmarkServiceServer) {
s.RegisterService(&_BenchmarkService_serviceDesc, srv)
}
func _BenchmarkService_UnaryCall_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(SimpleRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(BenchmarkServiceServer).UnaryCall(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/grpc.testing.BenchmarkService/UnaryCall",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BenchmarkServiceServer).UnaryCall(ctx, req.(*SimpleRequest))
}
return interceptor(ctx, in, info, handler)
}
func _BenchmarkService_StreamingCall_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(BenchmarkServiceServer).StreamingCall(&benchmarkServiceStreamingCallServer{stream})
}
type BenchmarkService_StreamingCallServer interface {
Send(*SimpleResponse) error
Recv() (*SimpleRequest, error)
grpc.ServerStream
}
type benchmarkServiceStreamingCallServer struct {
grpc.ServerStream
}
func (x *benchmarkServiceStreamingCallServer) Send(m *SimpleResponse) error {
return x.ServerStream.SendMsg(m)
}
func (x *benchmarkServiceStreamingCallServer) Recv() (*SimpleRequest, error) {
m := new(SimpleRequest)
if err := x.ServerStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
var _BenchmarkService_serviceDesc = grpc.ServiceDesc{
ServiceName: "grpc.testing.BenchmarkService",
HandlerType: (*BenchmarkServiceServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "UnaryCall",
Handler: _BenchmarkService_UnaryCall_Handler,
},
},
Streams: []grpc.StreamDesc{
{
StreamName: "StreamingCall",
Handler: _BenchmarkService_StreamingCall_Handler,
ServerStreams: true,
ClientStreams: true,
},
},
Metadata: "services.proto",
}
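// Illustrative usage sketch (not part of the generated file), assuming a
// benchmark server is already listening on the hypothetical address
// "localhost:50051"; SimpleRequest/SimpleResponse come from messages.pb.go in
// this package.
//
//	conn, err := grpc.Dial("localhost:50051", grpc.WithInsecure())
//	if err != nil {
//		log.Fatalf("dial failed: %v", err)
//	}
//	defer conn.Close()
//	client := NewBenchmarkServiceClient(conn)
//	// Unary round trip; the server echoes the payload back.
//	if _, err := client.UnaryCall(context.Background(), &SimpleRequest{}); err != nil {
//		log.Fatalf("UnaryCall failed: %v", err)
//	}
//	// Bidirectional streaming: send one request, then read one response.
//	stream, err := client.StreamingCall(context.Background())
//	if err != nil {
//		log.Fatalf("StreamingCall failed: %v", err)
//	}
//	if err := stream.Send(&SimpleRequest{}); err != nil {
//		log.Fatalf("Send failed: %v", err)
//	}
//	if _, err := stream.Recv(); err != nil {
//		log.Fatalf("Recv failed: %v", err)
//	}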
// WorkerServiceClient is the client API for WorkerService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
type WorkerServiceClient interface {
// Start server with specified workload.
// First request sent specifies the ServerConfig followed by ServerStatus
// response. After that, a "Mark" can be sent anytime to request the latest
// stats. Closing the stream will initiate shutdown of the test server
// and once the shutdown has finished, the OK status is sent to terminate
// this RPC.
RunServer(ctx context.Context, opts ...grpc.CallOption) (WorkerService_RunServerClient, error)
// Start client with specified workload.
// First request sent specifies the ClientConfig followed by ClientStatus
// response. After that, a "Mark" can be sent anytime to request the latest
// stats. Closing the stream will initiate shutdown of the test client
// and once the shutdown has finished, the OK status is sent to terminate
// this RPC.
RunClient(ctx context.Context, opts ...grpc.CallOption) (WorkerService_RunClientClient, error)
// Just return the core count - unary call
CoreCount(ctx context.Context, in *CoreRequest, opts ...grpc.CallOption) (*CoreResponse, error)
// Quit this worker
QuitWorker(ctx context.Context, in *Void, opts ...grpc.CallOption) (*Void, error)
}
type workerServiceClient struct {
cc *grpc.ClientConn
}
func NewWorkerServiceClient(cc *grpc.ClientConn) WorkerServiceClient {
return &workerServiceClient{cc}
}
func (c *workerServiceClient) RunServer(ctx context.Context, opts ...grpc.CallOption) (WorkerService_RunServerClient, error) {
stream, err := c.cc.NewStream(ctx, &_WorkerService_serviceDesc.Streams[0], "/grpc.testing.WorkerService/RunServer", opts...)
if err != nil {
return nil, err
}
x := &workerServiceRunServerClient{stream}
return x, nil
}
type WorkerService_RunServerClient interface {
Send(*ServerArgs) error
Recv() (*ServerStatus, error)
grpc.ClientStream
}
type workerServiceRunServerClient struct {
grpc.ClientStream
}
func (x *workerServiceRunServerClient) Send(m *ServerArgs) error {
return x.ClientStream.SendMsg(m)
}
func (x *workerServiceRunServerClient) Recv() (*ServerStatus, error) {
m := new(ServerStatus)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
func (c *workerServiceClient) RunClient(ctx context.Context, opts ...grpc.CallOption) (WorkerService_RunClientClient, error) {
stream, err := c.cc.NewStream(ctx, &_WorkerService_serviceDesc.Streams[1], "/grpc.testing.WorkerService/RunClient", opts...)
if err != nil {
return nil, err
}
x := &workerServiceRunClientClient{stream}
return x, nil
}
type WorkerService_RunClientClient interface {
Send(*ClientArgs) error
Recv() (*ClientStatus, error)
grpc.ClientStream
}
type workerServiceRunClientClient struct {
grpc.ClientStream
}
func (x *workerServiceRunClientClient) Send(m *ClientArgs) error {
return x.ClientStream.SendMsg(m)
}
func (x *workerServiceRunClientClient) Recv() (*ClientStatus, error) {
m := new(ClientStatus)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
func (c *workerServiceClient) CoreCount(ctx context.Context, in *CoreRequest, opts ...grpc.CallOption) (*CoreResponse, error) {
out := new(CoreResponse)
err := c.cc.Invoke(ctx, "/grpc.testing.WorkerService/CoreCount", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *workerServiceClient) QuitWorker(ctx context.Context, in *Void, opts ...grpc.CallOption) (*Void, error) {
out := new(Void)
err := c.cc.Invoke(ctx, "/grpc.testing.WorkerService/QuitWorker", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// WorkerServiceServer is the server API for WorkerService service.
type WorkerServiceServer interface {
// Start server with specified workload.
// First request sent specifies the ServerConfig followed by ServerStatus
// response. After that, a "Mark" can be sent anytime to request the latest
// stats. Closing the stream will initiate shutdown of the test server
// and once the shutdown has finished, the OK status is sent to terminate
// this RPC.
RunServer(WorkerService_RunServerServer) error
// Start client with specified workload.
// First request sent specifies the ClientConfig followed by ClientStatus
// response. After that, a "Mark" can be sent anytime to request the latest
// stats. Closing the stream will initiate shutdown of the test client
// and once the shutdown has finished, the OK status is sent to terminate
// this RPC.
RunClient(WorkerService_RunClientServer) error
// Just return the core count - unary call
CoreCount(context.Context, *CoreRequest) (*CoreResponse, error)
// Quit this worker
QuitWorker(context.Context, *Void) (*Void, error)
}
func RegisterWorkerServiceServer(s *grpc.Server, srv WorkerServiceServer) {
s.RegisterService(&_WorkerService_serviceDesc, srv)
}
func _WorkerService_RunServer_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(WorkerServiceServer).RunServer(&workerServiceRunServerServer{stream})
}
type WorkerService_RunServerServer interface {
Send(*ServerStatus) error
Recv() (*ServerArgs, error)
grpc.ServerStream
}
type workerServiceRunServerServer struct {
grpc.ServerStream
}
func (x *workerServiceRunServerServer) Send(m *ServerStatus) error {
return x.ServerStream.SendMsg(m)
}
func (x *workerServiceRunServerServer) Recv() (*ServerArgs, error) {
m := new(ServerArgs)
if err := x.ServerStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
func _WorkerService_RunClient_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(WorkerServiceServer).RunClient(&workerServiceRunClientServer{stream})
}
type WorkerService_RunClientServer interface {
Send(*ClientStatus) error
Recv() (*ClientArgs, error)
grpc.ServerStream
}
type workerServiceRunClientServer struct {
grpc.ServerStream
}
func (x *workerServiceRunClientServer) Send(m *ClientStatus) error {
return x.ServerStream.SendMsg(m)
}
func (x *workerServiceRunClientServer) Recv() (*ClientArgs, error) {
m := new(ClientArgs)
if err := x.ServerStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
func _WorkerService_CoreCount_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(CoreRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(WorkerServiceServer).CoreCount(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/grpc.testing.WorkerService/CoreCount",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(WorkerServiceServer).CoreCount(ctx, req.(*CoreRequest))
}
return interceptor(ctx, in, info, handler)
}
func _WorkerService_QuitWorker_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(Void)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(WorkerServiceServer).QuitWorker(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/grpc.testing.WorkerService/QuitWorker",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(WorkerServiceServer).QuitWorker(ctx, req.(*Void))
}
return interceptor(ctx, in, info, handler)
}
var _WorkerService_serviceDesc = grpc.ServiceDesc{
ServiceName: "grpc.testing.WorkerService",
HandlerType: (*WorkerServiceServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "CoreCount",
Handler: _WorkerService_CoreCount_Handler,
},
{
MethodName: "QuitWorker",
Handler: _WorkerService_QuitWorker_Handler,
},
},
Streams: []grpc.StreamDesc{
{
StreamName: "RunServer",
Handler: _WorkerService_RunServer_Handler,
ServerStreams: true,
ClientStreams: true,
},
{
StreamName: "RunClient",
Handler: _WorkerService_RunClient_Handler,
ServerStreams: true,
ClientStreams: true,
},
},
Metadata: "services.proto",
}
func init() { proto.RegisterFile("services.proto", fileDescriptor_services_bf68f4d7cbd0e0a1) }
var fileDescriptor_services_bf68f4d7cbd0e0a1 = []byte{
// 255 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xa4, 0x91, 0xc1, 0x4a, 0xc4, 0x30,
0x10, 0x86, 0xa9, 0x07, 0xa1, 0xc1, 0x2e, 0x92, 0x93, 0x46, 0x1f, 0xc0, 0x53, 0x91, 0xd5, 0x17,
0x70, 0x8b, 0x1e, 0x05, 0xb7, 0xa8, 0xe7, 0x58, 0x87, 0x1a, 0x36, 0xcd, 0xd4, 0x99, 0x89, 0xe0,
0x93, 0xf8, 0x0e, 0x3e, 0xa5, 0xec, 0x66, 0x57, 0xd6, 0x92, 0x9b, 0xc7, 0xf9, 0xbf, 0xe1, 0x23,
0x7f, 0x46, 0xcd, 0x18, 0xe8, 0xc3, 0x75, 0xc0, 0xf5, 0x48, 0x28, 0xa8, 0x8f, 0x7a, 0x1a, 0xbb,
0x5a, 0x80, 0xc5, 0x85, 0xde, 0xcc, 0x06, 0x60, 0xb6, 0xfd, 0x8e, 0x9a, 0xaa, 0xc3, 0x20, 0x84,
0x3e, 0x8d, 0xf3, 0xef, 0x42, 0x1d, 0x2f, 0x20, 0x74, 0x6f, 0x83, 0xa5, 0x55, 0x9b, 0x44, 0xfa,
0x4e, 0x95, 0x8f, 0xc1, 0xd2, 0x67, 0x63, 0xbd, 0xd7, 0x67, 0xf5, 0xbe, 0xaf, 0x6e, 0xdd, 0x30,
0x7a, 0x58, 0xc2, 0x7b, 0x04, 0x16, 0x73, 0x9e, 0x87, 0x3c, 0x62, 0x60, 0xd0, 0xf7, 0xaa, 0x6a,
0x85, 0xc0, 0x0e, 0x2e, 0xf4, 0xff, 0x74, 0x5d, 0x14, 0x97, 0xc5, 0xfc, 0xeb, 0x40, 0x55, 0xcf,
0x48, 0x2b, 0xa0, 0xdd, 0x4b, 0x6f, 0x55, 0xb9, 0x8c, 0x61, 0x3d, 0x01, 0xe9, 0x93, 0x89, 0x60,
0x93, 0xde, 0x50, 0xcf, 0xc6, 0xe4, 0x48, 0x2b, 0x56, 0x22, 0xaf, 0xc5, 0x5b, 0x4d, 0xe3, 0x1d,
0x04, 0x99, 0x6a, 0x52, 0x9a, 0xd3, 0x24, 0xb2, 0xa7, 0x59, 0xa8, 0xb2, 0x41, 0x82, 0x06, 0x63,
0x10, 0x7d, 0x3a, 0x59, 0x46, 0xfa, 0x6d, 0x6a, 0x72, 0x68, 0xfb, 0x67, 0xd7, 0x4a, 0x3d, 0x44,
0x27, 0xa9, 0xa6, 0xd6, 0x7f, 0x37, 0x9f, 0xd0, 0xbd, 0x9a, 0x4c, 0xf6, 0x72, 0xb8, 0xb9, 0xe6,
0xd5, 0x4f, 0x00, 0x00, 0x00, 0xff, 0xff, 0x3b, 0x84, 0x02, 0xe3, 0x0c, 0x02, 0x00, 0x00,
}

View File

@ -1,56 +0,0 @@
// Copyright 2016 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// An integration test service that covers all the method signature permutations
// of unary/streaming requests/responses.
syntax = "proto3";
import "messages.proto";
import "control.proto";
package grpc.testing;
service BenchmarkService {
// One request followed by one response.
// The server returns the client payload as-is.
rpc UnaryCall(SimpleRequest) returns (SimpleResponse);
// Repeated sequence of one request followed by one response.
// The server returns the client payload as-is on each response.
rpc StreamingCall(stream SimpleRequest) returns (stream SimpleResponse);
}
service WorkerService {
// Start server with specified workload.
// First request sent specifies the ServerConfig followed by ServerStatus
// response. After that, a "Mark" can be sent anytime to request the latest
// stats. Closing the stream will initiate shutdown of the test server
// and once the shutdown has finished, the OK status is sent to terminate
// this RPC.
rpc RunServer(stream ServerArgs) returns (stream ServerStatus);
// Start client with specified workload.
// First request sent specifies the ClientConfig followed by ClientStatus
// response. After that, a "Mark" can be sent anytime to request the latest
// stats. Closing the stream will initiate shutdown of the test client
// and once the shutdown has finished, the OK status is sent to terminate
// this RPC.
rpc RunClient(stream ClientArgs) returns (stream ClientStatus);
// Just return the core count - unary call
rpc CoreCount(CoreRequest) returns (CoreResponse);
// Quit this worker
rpc QuitWorker(Void) returns (Void);
}

View File

@ -1,302 +0,0 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: stats.proto
package grpc_testing
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type ServerStats struct {
// wall clock time change in seconds since last reset
TimeElapsed float64 `protobuf:"fixed64,1,opt,name=time_elapsed,json=timeElapsed,proto3" json:"time_elapsed,omitempty"`
// change in user time (in seconds) used by the server since last reset
TimeUser float64 `protobuf:"fixed64,2,opt,name=time_user,json=timeUser,proto3" json:"time_user,omitempty"`
// change in server time (in seconds) used by the server process and all
// threads since last reset
TimeSystem float64 `protobuf:"fixed64,3,opt,name=time_system,json=timeSystem,proto3" json:"time_system,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ServerStats) Reset() { *m = ServerStats{} }
func (m *ServerStats) String() string { return proto.CompactTextString(m) }
func (*ServerStats) ProtoMessage() {}
func (*ServerStats) Descriptor() ([]byte, []int) {
return fileDescriptor_stats_8ba831c0cb3c3440, []int{0}
}
func (m *ServerStats) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ServerStats.Unmarshal(m, b)
}
func (m *ServerStats) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ServerStats.Marshal(b, m, deterministic)
}
func (dst *ServerStats) XXX_Merge(src proto.Message) {
xxx_messageInfo_ServerStats.Merge(dst, src)
}
func (m *ServerStats) XXX_Size() int {
return xxx_messageInfo_ServerStats.Size(m)
}
func (m *ServerStats) XXX_DiscardUnknown() {
xxx_messageInfo_ServerStats.DiscardUnknown(m)
}
var xxx_messageInfo_ServerStats proto.InternalMessageInfo
func (m *ServerStats) GetTimeElapsed() float64 {
if m != nil {
return m.TimeElapsed
}
return 0
}
func (m *ServerStats) GetTimeUser() float64 {
if m != nil {
return m.TimeUser
}
return 0
}
func (m *ServerStats) GetTimeSystem() float64 {
if m != nil {
return m.TimeSystem
}
return 0
}
// Histogram params based on grpc/support/histogram.c
type HistogramParams struct {
Resolution float64 `protobuf:"fixed64,1,opt,name=resolution,proto3" json:"resolution,omitempty"`
MaxPossible float64 `protobuf:"fixed64,2,opt,name=max_possible,json=maxPossible,proto3" json:"max_possible,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *HistogramParams) Reset() { *m = HistogramParams{} }
func (m *HistogramParams) String() string { return proto.CompactTextString(m) }
func (*HistogramParams) ProtoMessage() {}
func (*HistogramParams) Descriptor() ([]byte, []int) {
return fileDescriptor_stats_8ba831c0cb3c3440, []int{1}
}
func (m *HistogramParams) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_HistogramParams.Unmarshal(m, b)
}
func (m *HistogramParams) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_HistogramParams.Marshal(b, m, deterministic)
}
func (dst *HistogramParams) XXX_Merge(src proto.Message) {
xxx_messageInfo_HistogramParams.Merge(dst, src)
}
func (m *HistogramParams) XXX_Size() int {
return xxx_messageInfo_HistogramParams.Size(m)
}
func (m *HistogramParams) XXX_DiscardUnknown() {
xxx_messageInfo_HistogramParams.DiscardUnknown(m)
}
var xxx_messageInfo_HistogramParams proto.InternalMessageInfo
func (m *HistogramParams) GetResolution() float64 {
if m != nil {
return m.Resolution
}
return 0
}
func (m *HistogramParams) GetMaxPossible() float64 {
if m != nil {
return m.MaxPossible
}
return 0
}
// Histogram data based on grpc/support/histogram.c
type HistogramData struct {
Bucket []uint32 `protobuf:"varint,1,rep,packed,name=bucket,proto3" json:"bucket,omitempty"`
MinSeen float64 `protobuf:"fixed64,2,opt,name=min_seen,json=minSeen,proto3" json:"min_seen,omitempty"`
MaxSeen float64 `protobuf:"fixed64,3,opt,name=max_seen,json=maxSeen,proto3" json:"max_seen,omitempty"`
Sum float64 `protobuf:"fixed64,4,opt,name=sum,proto3" json:"sum,omitempty"`
SumOfSquares float64 `protobuf:"fixed64,5,opt,name=sum_of_squares,json=sumOfSquares,proto3" json:"sum_of_squares,omitempty"`
Count float64 `protobuf:"fixed64,6,opt,name=count,proto3" json:"count,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *HistogramData) Reset() { *m = HistogramData{} }
func (m *HistogramData) String() string { return proto.CompactTextString(m) }
func (*HistogramData) ProtoMessage() {}
func (*HistogramData) Descriptor() ([]byte, []int) {
return fileDescriptor_stats_8ba831c0cb3c3440, []int{2}
}
func (m *HistogramData) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_HistogramData.Unmarshal(m, b)
}
func (m *HistogramData) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_HistogramData.Marshal(b, m, deterministic)
}
func (dst *HistogramData) XXX_Merge(src proto.Message) {
xxx_messageInfo_HistogramData.Merge(dst, src)
}
func (m *HistogramData) XXX_Size() int {
return xxx_messageInfo_HistogramData.Size(m)
}
func (m *HistogramData) XXX_DiscardUnknown() {
xxx_messageInfo_HistogramData.DiscardUnknown(m)
}
var xxx_messageInfo_HistogramData proto.InternalMessageInfo
func (m *HistogramData) GetBucket() []uint32 {
if m != nil {
return m.Bucket
}
return nil
}
func (m *HistogramData) GetMinSeen() float64 {
if m != nil {
return m.MinSeen
}
return 0
}
func (m *HistogramData) GetMaxSeen() float64 {
if m != nil {
return m.MaxSeen
}
return 0
}
func (m *HistogramData) GetSum() float64 {
if m != nil {
return m.Sum
}
return 0
}
func (m *HistogramData) GetSumOfSquares() float64 {
if m != nil {
return m.SumOfSquares
}
return 0
}
func (m *HistogramData) GetCount() float64 {
if m != nil {
return m.Count
}
return 0
}
type ClientStats struct {
// Latency histogram. Data points are in nanoseconds.
Latencies *HistogramData `protobuf:"bytes,1,opt,name=latencies,proto3" json:"latencies,omitempty"`
// See ServerStats for details.
TimeElapsed float64 `protobuf:"fixed64,2,opt,name=time_elapsed,json=timeElapsed,proto3" json:"time_elapsed,omitempty"`
TimeUser float64 `protobuf:"fixed64,3,opt,name=time_user,json=timeUser,proto3" json:"time_user,omitempty"`
TimeSystem float64 `protobuf:"fixed64,4,opt,name=time_system,json=timeSystem,proto3" json:"time_system,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ClientStats) Reset() { *m = ClientStats{} }
func (m *ClientStats) String() string { return proto.CompactTextString(m) }
func (*ClientStats) ProtoMessage() {}
func (*ClientStats) Descriptor() ([]byte, []int) {
return fileDescriptor_stats_8ba831c0cb3c3440, []int{3}
}
func (m *ClientStats) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ClientStats.Unmarshal(m, b)
}
func (m *ClientStats) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ClientStats.Marshal(b, m, deterministic)
}
func (dst *ClientStats) XXX_Merge(src proto.Message) {
xxx_messageInfo_ClientStats.Merge(dst, src)
}
func (m *ClientStats) XXX_Size() int {
return xxx_messageInfo_ClientStats.Size(m)
}
func (m *ClientStats) XXX_DiscardUnknown() {
xxx_messageInfo_ClientStats.DiscardUnknown(m)
}
var xxx_messageInfo_ClientStats proto.InternalMessageInfo
func (m *ClientStats) GetLatencies() *HistogramData {
if m != nil {
return m.Latencies
}
return nil
}
func (m *ClientStats) GetTimeElapsed() float64 {
if m != nil {
return m.TimeElapsed
}
return 0
}
func (m *ClientStats) GetTimeUser() float64 {
if m != nil {
return m.TimeUser
}
return 0
}
func (m *ClientStats) GetTimeSystem() float64 {
if m != nil {
return m.TimeSystem
}
return 0
}
func init() {
proto.RegisterType((*ServerStats)(nil), "grpc.testing.ServerStats")
proto.RegisterType((*HistogramParams)(nil), "grpc.testing.HistogramParams")
proto.RegisterType((*HistogramData)(nil), "grpc.testing.HistogramData")
proto.RegisterType((*ClientStats)(nil), "grpc.testing.ClientStats")
}
func init() { proto.RegisterFile("stats.proto", fileDescriptor_stats_8ba831c0cb3c3440) }
var fileDescriptor_stats_8ba831c0cb3c3440 = []byte{
// 341 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x92, 0xc1, 0x4a, 0xeb, 0x40,
0x14, 0x86, 0x49, 0xd3, 0xf6, 0xb6, 0x27, 0xed, 0xbd, 0x97, 0x41, 0x24, 0x52, 0xd0, 0x1a, 0x5c,
0x74, 0x95, 0x85, 0xae, 0x5c, 0xab, 0xe0, 0xce, 0xd2, 0xe8, 0x3a, 0x4c, 0xe3, 0x69, 0x19, 0xcc,
0xcc, 0xc4, 0x39, 0x33, 0x12, 0x1f, 0x49, 0x7c, 0x49, 0xc9, 0x24, 0x68, 0x55, 0xd0, 0x5d, 0xe6,
0xfb, 0x7e, 0xe6, 0xe4, 0xe4, 0x0f, 0x44, 0x64, 0xb9, 0xa5, 0xb4, 0x32, 0xda, 0x6a, 0x36, 0xd9,
0x9a, 0xaa, 0x48, 0x2d, 0x92, 0x15, 0x6a, 0x9b, 0x28, 0x88, 0x32, 0x34, 0x4f, 0x68, 0xb2, 0x26,
0xc2, 0x8e, 0x61, 0x62, 0x85, 0xc4, 0x1c, 0x4b, 0x5e, 0x11, 0xde, 0xc7, 0xc1, 0x3c, 0x58, 0x04,
0xab, 0xa8, 0x61, 0x57, 0x2d, 0x62, 0x33, 0x18, 0xfb, 0x88, 0x23, 0x34, 0x71, 0xcf, 0xfb, 0x51,
0x03, 0xee, 0x08, 0x0d, 0x3b, 0x02, 0x9f, 0xcd, 0xe9, 0x99, 0x2c, 0xca, 0x38, 0xf4, 0x1a, 0x1a,
0x94, 0x79, 0x92, 0xdc, 0xc2, 0xbf, 0x6b, 0x41, 0x56, 0x6f, 0x0d, 0x97, 0x4b, 0x6e, 0xb8, 0x24,
0x76, 0x08, 0x60, 0x90, 0x74, 0xe9, 0xac, 0xd0, 0xaa, 0x9b, 0xb8, 0x43, 0x9a, 0x77, 0x92, 0xbc,
0xce, 0x2b, 0x4d, 0x24, 0xd6, 0x25, 0x76, 0x33, 0x23, 0xc9, 0xeb, 0x65, 0x87, 0x92, 0xd7, 0x00,
0xa6, 0xef, 0xd7, 0x5e, 0x72, 0xcb, 0xd9, 0x3e, 0x0c, 0xd7, 0xae, 0x78, 0x40, 0x1b, 0x07, 0xf3,
0x70, 0x31, 0x5d, 0x75, 0x27, 0x76, 0x00, 0x23, 0x29, 0x54, 0x4e, 0x88, 0xaa, 0xbb, 0xe8, 0x8f,
0x14, 0x2a, 0x43, 0x54, 0x5e, 0xf1, 0xba, 0x55, 0x61, 0xa7, 0x78, 0xed, 0xd5, 0x7f, 0x08, 0xc9,
0xc9, 0xb8, 0xef, 0x69, 0xf3, 0xc8, 0x4e, 0xe0, 0x2f, 0x39, 0x99, 0xeb, 0x4d, 0x4e, 0x8f, 0x8e,
0x1b, 0xa4, 0x78, 0xe0, 0xe5, 0x84, 0x9c, 0xbc, 0xd9, 0x64, 0x2d, 0x63, 0x7b, 0x30, 0x28, 0xb4,
0x53, 0x36, 0x1e, 0x7a, 0xd9, 0x1e, 0x92, 0x97, 0x00, 0xa2, 0x8b, 0x52, 0xa0, 0xb2, 0xed, 0x47,
0x3f, 0x87, 0x71, 0xc9, 0x2d, 0xaa, 0x42, 0x20, 0xf9, 0xfd, 0xa3, 0xd3, 0x59, 0xba, 0xdb, 0x52,
0xfa, 0x69, 0xb7, 0xd5, 0x47, 0xfa, 0x5b, 0x5f, 0xbd, 0x5f, 0xfa, 0x0a, 0x7f, 0xee, 0xab, 0xff,
0xb5, 0xaf, 0xf5, 0xd0, 0xff, 0x34, 0x67, 0x6f, 0x01, 0x00, 0x00, 0xff, 0xff, 0xea, 0x75, 0x34,
0x90, 0x43, 0x02, 0x00, 0x00,
}

View File

@ -1,55 +0,0 @@
// Copyright 2016 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package grpc.testing;
message ServerStats {
// wall clock time change in seconds since last reset
double time_elapsed = 1;
// change in user time (in seconds) used by the server since last reset
double time_user = 2;
// change in server time (in seconds) used by the server process and all
// threads since last reset
double time_system = 3;
}
// Histogram params based on grpc/support/histogram.c
message HistogramParams {
double resolution = 1; // first bucket is [0, 1 + resolution)
double max_possible = 2; // use enough buckets to allow this value
}
// Histogram data based on grpc/support/histogram.c
message HistogramData {
repeated uint32 bucket = 1;
double min_seen = 2;
double max_seen = 3;
double sum = 4;
double sum_of_squares = 5;
double count = 6;
}
message ClientStats {
// Latency histogram. Data points are in nanoseconds.
HistogramData latencies = 1;
// See ServerStats for details.
double time_elapsed = 2;
double time_user = 3;
double time_system = 4;
}

View File

@ -1,315 +0,0 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package latency provides wrappers for net.Conn, net.Listener, and
// net.Dialers, designed to interoperate to inject real-world latency into
// network connections.
package latency
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"io"
"net"
"time"
)
// Dialer is a function matching the signature of net.Dial.
type Dialer func(network, address string) (net.Conn, error)
// TimeoutDialer is a function matching the signature of net.DialTimeout.
type TimeoutDialer func(network, address string, timeout time.Duration) (net.Conn, error)
// ContextDialer is a function matching the signature of
// net.Dialer.DialContext.
type ContextDialer func(ctx context.Context, network, address string) (net.Conn, error)
// Network represents a network with the given bandwidth, latency, and MTU
// (Maximum Transmission Unit) configuration, and can produce wrappers of
// net.Listeners, net.Conn, and various forms of dialing functions. The
// Listeners and Dialers/Conns on both sides of connections must come from this
// package, but need not be created from the same Network. Latency is computed
// when sending (in Write), and is injected when receiving (in Read). This
// allows senders' Write calls to be non-blocking, as in real-world
// applications.
//
// Note: Latency is injected by the sender specifying the absolute time data
// should be available, and the reader delaying until that time arrives to
// provide the data. This package attempts to counter-act the effects of clock
// drift and existing network latency by measuring the delay between the
// sender's transmission time and the receiver's reception time during startup.
// No attempt is made to measure the existing bandwidth of the connection.
type Network struct {
Kbps int // Kilobits per second; if non-positive, infinite
Latency time.Duration // One-way latency (sending); if non-positive, no delay
MTU int // Bytes per packet; if non-positive, infinite
}
var (
// Local simulates a local network.
Local = Network{0, 0, 0}
// LAN simulates a local area network.
LAN = Network{100 * 1024, 2 * time.Millisecond, 1500}
// WAN simulates a wide area network.
WAN = Network{20 * 1024, 30 * time.Millisecond, 1500}
// Longhaul simulates a bad network.
Longhaul = Network{1000 * 1024, 200 * time.Millisecond, 9000}
)
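// Illustrative sketch (not part of the original file): wrapping a TCP
// listener and dialer with one of the predefined networks. Only identifiers
// defined in this file and in the standard library are used; the addresses
// are placeholders.
//
//	lis, err := net.Listen("tcp", "localhost:0")
//	if err != nil {
//		log.Fatal(err)
//	}
//	lis = WAN.Listener(lis) // accepted conns carry WAN latency/bandwidth
//	dial := WAN.TimeoutDialer(net.DialTimeout)
//	conn, err := dial("tcp", lis.Addr().String(), 5*time.Second) // dialed conn carries WAN latency too
//	if err != nil {
//		log.Fatal(err)
//	}
//	_ = conn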
// Conn returns a net.Conn that wraps c and injects n's latency into that
// connection. This function also imposes latency for connection creation.
// If n's Latency is lower than the measured latency in c, an error is
// returned.
func (n *Network) Conn(c net.Conn) (net.Conn, error) {
start := now()
nc := &conn{Conn: c, network: n, readBuf: new(bytes.Buffer)}
if err := nc.sync(); err != nil {
return nil, err
}
sleep(start.Add(nc.delay).Sub(now()))
return nc, nil
}
type conn struct {
net.Conn
network *Network
readBuf *bytes.Buffer // one packet worth of data received
lastSendEnd time.Time // time the previous Write should be fully on the wire
delay time.Duration // desired latency - measured latency
}
// header is sent before all data transmitted by the application.
type header struct {
ReadTime int64 // Time the reader is allowed to read this packet (UnixNano)
Sz int32 // Size of the data in the packet
}
func (c *conn) Write(p []byte) (n int, err error) {
tNow := now()
if c.lastSendEnd.Before(tNow) {
c.lastSendEnd = tNow
}
for len(p) > 0 {
pkt := p
if c.network.MTU > 0 && len(pkt) > c.network.MTU {
pkt = pkt[:c.network.MTU]
p = p[c.network.MTU:]
} else {
p = nil
}
if c.network.Kbps > 0 {
if congestion := c.lastSendEnd.Sub(tNow) - c.delay; congestion > 0 {
// The network is full; sleep until this packet can be sent.
sleep(congestion)
tNow = tNow.Add(congestion)
}
}
c.lastSendEnd = c.lastSendEnd.Add(c.network.pktTime(len(pkt)))
hdr := header{ReadTime: c.lastSendEnd.Add(c.delay).UnixNano(), Sz: int32(len(pkt))}
if err := binary.Write(c.Conn, binary.BigEndian, hdr); err != nil {
return n, err
}
x, err := c.Conn.Write(pkt)
n += x
if err != nil {
return n, err
}
}
return n, nil
}
func (c *conn) Read(p []byte) (n int, err error) {
if c.readBuf.Len() == 0 {
var hdr header
if err := binary.Read(c.Conn, binary.BigEndian, &hdr); err != nil {
return 0, err
}
defer func() { sleep(time.Unix(0, hdr.ReadTime).Sub(now())) }()
if _, err := io.CopyN(c.readBuf, c.Conn, int64(hdr.Sz)); err != nil {
return 0, err
}
}
// Read from readBuf.
return c.readBuf.Read(p)
}
// sync does a handshake and then measures the latency on the network in
// coordination with the other side.
func (c *conn) sync() error {
const (
pingMsg = "syncPing"
warmup = 10 // minimum number of iterations to measure latency
giveUp = 50 // maximum number of iterations to measure latency
accuracy = time.Millisecond // req'd accuracy to stop early
goodRun = 3 // stop early if latency within accuracy this many times
)
type syncMsg struct {
SendT int64 // Time sent. If zero, stop.
RecvT int64 // Time received. If zero, fill in and respond.
}
// A trivial handshake
if err := binary.Write(c.Conn, binary.BigEndian, []byte(pingMsg)); err != nil {
return err
}
var ping [8]byte
if err := binary.Read(c.Conn, binary.BigEndian, &ping); err != nil {
return err
} else if string(ping[:]) != pingMsg {
return fmt.Errorf("malformed handshake message: %v (want %q)", ping, pingMsg)
}
// Both sides are alive and syncing. Calculate network delay / clock skew.
att := 0
good := 0
var latency time.Duration
localDone, remoteDone := false, false
send := true
for !localDone || !remoteDone {
if send {
if err := binary.Write(c.Conn, binary.BigEndian, syncMsg{SendT: now().UnixNano()}); err != nil {
return err
}
att++
send = false
}
// Block until we get a syncMsg
m := syncMsg{}
if err := binary.Read(c.Conn, binary.BigEndian, &m); err != nil {
return err
}
if m.RecvT == 0 {
// Message initiated from other side.
if m.SendT == 0 {
remoteDone = true
continue
}
// Send response.
m.RecvT = now().UnixNano()
if err := binary.Write(c.Conn, binary.BigEndian, m); err != nil {
return err
}
continue
}
lag := time.Duration(m.RecvT - m.SendT)
latency += lag
avgLatency := latency / time.Duration(att)
if e := lag - avgLatency; e > -accuracy && e < accuracy {
good++
} else {
good = 0
}
if att < giveUp && (att < warmup || good < goodRun) {
send = true
continue
}
localDone = true
latency = avgLatency
// Tell the other side we're done.
if err := binary.Write(c.Conn, binary.BigEndian, syncMsg{}); err != nil {
return err
}
}
if c.network.Latency <= 0 {
return nil
}
c.delay = c.network.Latency - latency
if c.delay < 0 {
return fmt.Errorf("measured network latency (%v) higher than desired latency (%v)", latency, c.network.Latency)
}
return nil
}
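// Worked example (illustrative): wrapping a connection whose measured latency
// is 20ms with a Network{Latency: 30 * time.Millisecond} yields a delay of
// 10ms per packet, while wrapping it with Latency: 5 * time.Millisecond fails
// with the "measured network latency ... higher than desired latency" error above.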
// Listener returns a net.Listener that wraps l and injects n's latency in its
// connections.
func (n *Network) Listener(l net.Listener) net.Listener {
return &listener{Listener: l, network: n}
}
type listener struct {
net.Listener
network *Network
}
func (l *listener) Accept() (net.Conn, error) {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return l.network.Conn(c)
}
// Dialer returns a Dialer that wraps d and injects n's latency in its
// connections. n's Latency is also injected to the connection's creation.
func (n *Network) Dialer(d Dialer) Dialer {
return func(network, address string) (net.Conn, error) {
conn, err := d(network, address)
if err != nil {
return nil, err
}
return n.Conn(conn)
}
}
// TimeoutDialer returns a TimeoutDialer that wraps d and injects n's latency
// in its connections. n's Latency is also injected to the connection's
// creation.
func (n *Network) TimeoutDialer(d TimeoutDialer) TimeoutDialer {
return func(network, address string, timeout time.Duration) (net.Conn, error) {
conn, err := d(network, address, timeout)
if err != nil {
return nil, err
}
return n.Conn(conn)
}
}
// ContextDialer returns a ContextDialer that wraps d and injects n's latency
// in its connections. n's Latency is also injected to the connection's
// creation.
func (n *Network) ContextDialer(d ContextDialer) ContextDialer {
return func(ctx context.Context, network, address string) (net.Conn, error) {
conn, err := d(ctx, network, address)
if err != nil {
return nil, err
}
return n.Conn(conn)
}
}
// pktTime returns the time it takes to transmit one packet of data of size b
// in bytes.
func (n *Network) pktTime(b int) time.Duration {
if n.Kbps <= 0 {
return time.Duration(0)
}
return time.Duration(b) * time.Second / time.Duration(n.Kbps*(1024/8))
}
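// Worked example (illustrative): with Kbps = 1 the link carries 1*1024/8 = 128
// bytes per second, so pktTime(8) = 8s/128 = 62.5ms and pktTime(128) = 1s.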
// Wrappers for testing
var now = time.Now
var sleep = time.Sleep

View File

@ -1,353 +0,0 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package latency
import (
"bytes"
"fmt"
"net"
"reflect"
"sync"
"testing"
"time"
)
// bufConn is a net.Conn implemented by a bytes.Buffer (which is a ReadWriter).
type bufConn struct {
*bytes.Buffer
}
func (bufConn) Close() error { panic("unimplemented") }
func (bufConn) LocalAddr() net.Addr { panic("unimplemented") }
func (bufConn) RemoteAddr() net.Addr { panic("unimplemented") }
func (bufConn) SetDeadline(t time.Time) error { panic("unimplemented") }
func (bufConn) SetReadDeadline(t time.Time) error { panic("unimplemented") }
func (bufConn) SetWriteDeadline(t time.Time) error { panic("unimplemented") }
func restoreHooks() func() {
s := sleep
n := now
return func() {
sleep = s
now = n
}
}
func TestConn(t *testing.T) {
defer restoreHooks()()
// Constant time.
now = func() time.Time { return time.Unix(123, 456) }
// Capture sleep times for checking later.
var sleepTimes []time.Duration
sleep = func(t time.Duration) { sleepTimes = append(sleepTimes, t) }
wantSleeps := func(want ...time.Duration) {
if !reflect.DeepEqual(want, sleepTimes) {
t.Fatalf("sleepTimes = %v; want %v", sleepTimes, want)
}
sleepTimes = nil
}
// Use a fairly high latency to cause a large BDP and avoid sleeps while
// writing due to simulation of full buffers.
latency := 1 * time.Second
c, err := (&Network{Kbps: 1, Latency: latency, MTU: 5}).Conn(bufConn{&bytes.Buffer{}})
if err != nil {
t.Fatalf("Unexpected error creating connection: %v", err)
}
wantSleeps(latency) // Connection creation delay.
// 1 kbps = 128 Bps. Divides evenly by 1 second using nanos.
byteLatency := time.Duration(time.Second / 128)
write := func(b []byte) {
n, err := c.Write(b)
if n != len(b) || err != nil {
t.Fatalf("c.Write(%v) = %v, %v; want %v, nil", b, n, err, len(b))
}
}
write([]byte{1, 2, 3, 4, 5}) // One full packet
pkt1Time := latency + byteLatency*5
write([]byte{6}) // One partial packet
pkt2Time := pkt1Time + byteLatency
write([]byte{7, 8, 9, 10, 11, 12, 13}) // Two packets
pkt3Time := pkt2Time + byteLatency*5
pkt4Time := pkt3Time + byteLatency*2
// No reads, so no sleeps yet.
wantSleeps()
read := func(n int, want []byte) {
b := make([]byte, n)
if rd, err := c.Read(b); err != nil || rd != len(want) {
t.Fatalf("c.Read(<%v bytes>) = %v, %v; want %v, nil", n, rd, err, len(want))
}
if !reflect.DeepEqual(b[:len(want)], want) {
t.Fatalf("read %v; want %v", b, want)
}
}
read(1, []byte{1})
wantSleeps(pkt1Time)
read(1, []byte{2})
wantSleeps()
read(3, []byte{3, 4, 5})
wantSleeps()
read(2, []byte{6})
wantSleeps(pkt2Time)
read(2, []byte{7, 8})
wantSleeps(pkt3Time)
read(10, []byte{9, 10, 11})
wantSleeps()
read(10, []byte{12, 13})
wantSleeps(pkt4Time)
}
func TestSync(t *testing.T) {
defer restoreHooks()()
// Infinitely fast CPU: time doesn't pass unless sleep is called.
tn := time.Unix(123, 0)
now = func() time.Time { return tn }
sleep = func(d time.Duration) { tn = tn.Add(d) }
// Simulate a 20ms latency network, then run sync across that and expect to
// measure 20ms latency, or 10ms additional delay for a 30ms network.
slowConn, err := (&Network{Kbps: 0, Latency: 20 * time.Millisecond, MTU: 5}).Conn(bufConn{&bytes.Buffer{}})
if err != nil {
t.Fatalf("Unexpected error creating connection: %v", err)
}
c, err := (&Network{Latency: 30 * time.Millisecond}).Conn(slowConn)
if err != nil {
t.Fatalf("Unexpected error creating connection: %v", err)
}
if c.(*conn).delay != 10*time.Millisecond {
t.Fatalf("c.delay = %v; want 10ms", c.(*conn).delay)
}
}
func TestSyncTooSlow(t *testing.T) {
defer restoreHooks()()
// Infinitely fast CPU: time doesn't pass unless sleep is called.
tn := time.Unix(123, 0)
now = func() time.Time { return tn }
sleep = func(d time.Duration) { tn = tn.Add(d) }
// Simulate a 10ms latency network, then attempt to simulate a 5ms latency
// network and expect an error.
slowConn, err := (&Network{Kbps: 0, Latency: 10 * time.Millisecond, MTU: 5}).Conn(bufConn{&bytes.Buffer{}})
if err != nil {
t.Fatalf("Unexpected error creating connection: %v", err)
}
errWant := "measured network latency (10ms) higher than desired latency (5ms)"
if _, err := (&Network{Latency: 5 * time.Millisecond}).Conn(slowConn); err == nil || err.Error() != errWant {
t.Fatalf("Conn() = _, %q; want _, %q", err, errWant)
}
}
func TestListenerAndDialer(t *testing.T) {
defer restoreHooks()()
tn := time.Unix(123, 0)
startTime := tn
mu := &sync.Mutex{}
now = func() time.Time {
mu.Lock()
defer mu.Unlock()
return tn
}
// Use a fairly high latency to cause a large BDP and avoid sleeps while
// writing due to simulation of full buffers.
n := &Network{Kbps: 2, Latency: 1 * time.Second, MTU: 10}
// 2 kbps = .25 kBps = 256 Bps
byteLatency := func(n int) time.Duration {
return time.Duration(n) * time.Second / 256
}
// Create a real listener and wrap it.
l, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Unexpected error creating listener: %v", err)
}
defer l.Close()
l = n.Listener(l)
var serverConn net.Conn
var scErr error
scDone := make(chan struct{})
go func() {
serverConn, scErr = l.Accept()
close(scDone)
}()
// Create a dialer and use it.
clientConn, err := n.TimeoutDialer(net.DialTimeout)("tcp", l.Addr().String(), 2*time.Second)
if err != nil {
t.Fatalf("Unexpected error dialing: %v", err)
}
defer clientConn.Close()
// Block until server's Conn is available.
<-scDone
if scErr != nil {
t.Fatalf("Unexpected error listening: %v", scErr)
}
defer serverConn.Close()
// sleep (only) advances tn. Done after connections established so sync detects zero delay.
sleep = func(d time.Duration) {
mu.Lock()
defer mu.Unlock()
if d > 0 {
tn = tn.Add(d)
}
}
seq := func(a, b int) []byte {
buf := make([]byte, b-a)
for i := 0; i < b-a; i++ {
buf[i] = byte(i + a)
}
return buf
}
pkt1 := seq(0, 10)
pkt2 := seq(10, 30)
pkt3 := seq(30, 35)
write := func(c net.Conn, b []byte) {
n, err := c.Write(b)
if n != len(b) || err != nil {
t.Fatalf("c.Write(%v) = %v, %v; want %v, nil", b, n, err, len(b))
}
}
write(serverConn, pkt1)
write(serverConn, pkt2)
write(serverConn, pkt3)
write(clientConn, pkt3)
write(clientConn, pkt1)
write(clientConn, pkt2)
if tn != startTime {
t.Fatalf("unexpected sleep in write; tn = %v; want %v", tn, startTime)
}
read := func(c net.Conn, n int, want []byte, timeWant time.Time) {
b := make([]byte, n)
if rd, err := c.Read(b); err != nil || rd != len(want) {
t.Fatalf("c.Read(<%v bytes>) = %v, %v; want %v, nil (read: %v)", n, rd, err, len(want), b[:rd])
}
if !reflect.DeepEqual(b[:len(want)], want) {
t.Fatalf("read %v; want %v", b, want)
}
if !tn.Equal(timeWant) {
t.Errorf("tn after read(%v) = %v; want %v", want, tn, timeWant)
}
}
read(clientConn, len(pkt1)+1, pkt1, startTime.Add(n.Latency+byteLatency(len(pkt1))))
read(serverConn, len(pkt3)+1, pkt3, tn) // tn was advanced by the above read; pkt3 is shorter than pkt1
read(clientConn, len(pkt2), pkt2[:10], startTime.Add(n.Latency+byteLatency(len(pkt1)+10)))
read(clientConn, len(pkt2), pkt2[10:], startTime.Add(n.Latency+byteLatency(len(pkt1)+len(pkt2))))
read(clientConn, len(pkt3), pkt3, startTime.Add(n.Latency+byteLatency(len(pkt1)+len(pkt2)+len(pkt3))))
read(serverConn, len(pkt1), pkt1, tn) // tn already past the arrival time due to prior reads
read(serverConn, len(pkt2), pkt2[:10], tn)
read(serverConn, len(pkt2), pkt2[10:], tn)
// Sleep awhile and make sure the read happens disregarding previous writes
// (lastSendEnd handling).
sleep(10 * time.Second)
write(clientConn, pkt1)
read(serverConn, len(pkt1), pkt1, tn.Add(n.Latency+byteLatency(len(pkt1))))
// Send, sleep longer than the network delay, then make sure the read happens
// instantly.
write(serverConn, pkt1)
sleep(10 * time.Second)
read(clientConn, len(pkt1), pkt1, tn)
}
func TestBufferBloat(t *testing.T) {
defer restoreHooks()()
// Infinitely fast CPU: time doesn't pass unless sleep is called.
tn := time.Unix(123, 0)
now = func() time.Time { return tn }
// Capture sleep times for checking later.
var sleepTimes []time.Duration
sleep = func(d time.Duration) {
sleepTimes = append(sleepTimes, d)
tn = tn.Add(d)
}
wantSleeps := func(want ...time.Duration) error {
if !reflect.DeepEqual(want, sleepTimes) {
return fmt.Errorf("sleepTimes = %v; want %v", sleepTimes, want)
}
sleepTimes = nil
return nil
}
n := &Network{Kbps: 8 /* 1KBps */, Latency: time.Second, MTU: 8}
bdpBytes := (n.Kbps * 1024 / 8) * int(n.Latency/time.Second) // 1024
c, err := n.Conn(bufConn{&bytes.Buffer{}})
if err != nil {
t.Fatalf("Unexpected error creating connection: %v", err)
}
wantSleeps(n.Latency) // Connection creation delay.
write := func(n int, sleeps ...time.Duration) {
if wt, err := c.Write(make([]byte, n)); err != nil || wt != n {
t.Fatalf("c.Write(<%v bytes>) = %v, %v; want %v, nil", n, wt, err, n)
}
if err := wantSleeps(sleeps...); err != nil {
t.Fatalf("After writing %v bytes: %v", n, err)
}
}
read := func(n int, sleeps ...time.Duration) {
if rd, err := c.Read(make([]byte, n)); err != nil || rd != n {
t.Fatalf("c.Read(_) = %v, %v; want %v, nil", rd, err, n)
}
if err := wantSleeps(sleeps...); err != nil {
t.Fatalf("After reading %v bytes: %v", n, err)
}
}
write(8) // No reads and buffer not full, so no sleeps yet.
read(8, time.Second+n.pktTime(8))
write(bdpBytes) // Fill the buffer.
write(1) // We can send one extra packet even when the buffer is full.
write(n.MTU, n.pktTime(1)) // Make sure we sleep to clear the previous write.
write(1, n.pktTime(n.MTU))
write(n.MTU+1, n.pktTime(1), n.pktTime(n.MTU))
tn = tn.Add(10 * time.Second) // Wait long enough for the buffer to clear.
write(bdpBytes) // No sleeps required.
}

View File

@ -1,135 +0,0 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package primitives_test
import (
"strconv"
"testing"
"google.golang.org/grpc/codes"
)
type codeBench uint32
const (
OK codeBench = iota
Canceled
Unknown
InvalidArgument
DeadlineExceeded
NotFound
AlreadyExists
PermissionDenied
ResourceExhausted
FailedPrecondition
Aborted
OutOfRange
Unimplemented
Internal
Unavailable
DataLoss
Unauthenticated
)
// The following String() function was generated by stringer.
const _Code_name = "OKCanceledUnknownInvalidArgumentDeadlineExceededNotFoundAlreadyExistsPermissionDeniedResourceExhaustedFailedPreconditionAbortedOutOfRangeUnimplementedInternalUnavailableDataLossUnauthenticated"
var _Code_index = [...]uint8{0, 2, 10, 17, 32, 48, 56, 69, 85, 102, 120, 127, 137, 150, 158, 169, 177, 192}
func (i codeBench) String() string {
if i >= codeBench(len(_Code_index)-1) {
return "Code(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _Code_name[_Code_index[i]:_Code_index[i+1]]
}
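// Worked example (illustrative): for i = 2 (Unknown), _Code_index[2] = 10 and
// _Code_index[3] = 17, so String() returns _Code_name[10:17] == "Unknown".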
var nameMap = map[codeBench]string{
OK: "OK",
Canceled: "Canceled",
Unknown: "Unknown",
InvalidArgument: "InvalidArgument",
DeadlineExceeded: "DeadlineExceeded",
NotFound: "NotFound",
AlreadyExists: "AlreadyExists",
PermissionDenied: "PermissionDenied",
ResourceExhausted: "ResourceExhausted",
FailedPrecondition: "FailedPrecondition",
Aborted: "Aborted",
OutOfRange: "OutOfRange",
Unimplemented: "Unimplemented",
Internal: "Internal",
Unavailable: "Unavailable",
DataLoss: "DataLoss",
Unauthenticated: "Unauthenticated",
}
func (i codeBench) StringUsingMap() string {
if s, ok := nameMap[i]; ok {
return s
}
return "Code(" + strconv.FormatInt(int64(i), 10) + ")"
}
func BenchmarkCodeStringStringer(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
c := codeBench(uint32(i % 17))
_ = c.String()
}
b.StopTimer()
}
func BenchmarkCodeStringMap(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
c := codeBench(uint32(i % 17))
_ = c.StringUsingMap()
}
b.StopTimer()
}
// codes.Code.String() does a switch.
func BenchmarkCodeStringSwitch(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
c := codes.Code(uint32(i % 17))
_ = c.String()
}
b.StopTimer()
}
// Testing all codes (0<=c<=16) and also one overflow (17).
func BenchmarkCodeStringStringerWithOverflow(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
c := codeBench(uint32(i % 18))
_ = c.String()
}
b.StopTimer()
}
// Testing all codes (0<=c<=16) and also one overflow (17).
func BenchmarkCodeStringSwitchWithOverflow(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
c := codes.Code(uint32(i % 18))
_ = c.String()
}
b.StopTimer()
}

View File

@ -1,119 +0,0 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package primitives_test
import (
"context"
"testing"
"time"
)
func BenchmarkCancelContextErrNoErr(b *testing.B) {
ctx, cancel := context.WithCancel(context.Background())
for i := 0; i < b.N; i++ {
if err := ctx.Err(); err != nil {
b.Fatal("error")
}
}
cancel()
}
func BenchmarkCancelContextErrGotErr(b *testing.B) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
for i := 0; i < b.N; i++ {
if err := ctx.Err(); err == nil {
b.Fatal("error")
}
}
}
func BenchmarkCancelContextChannelNoErr(b *testing.B) {
ctx, cancel := context.WithCancel(context.Background())
for i := 0; i < b.N; i++ {
select {
case <-ctx.Done():
b.Fatal("error: ctx.Done():", ctx.Err())
default:
}
}
cancel()
}
func BenchmarkCancelContextChannelGotErr(b *testing.B) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
for i := 0; i < b.N; i++ {
select {
case <-ctx.Done():
if err := ctx.Err(); err == nil {
b.Fatal("error")
}
default:
b.Fatal("error: !ctx.Done()")
}
}
}
func BenchmarkTimerContextErrNoErr(b *testing.B) {
ctx, cancel := context.WithTimeout(context.Background(), 24*time.Hour)
for i := 0; i < b.N; i++ {
if err := ctx.Err(); err != nil {
b.Fatal("error")
}
}
cancel()
}
func BenchmarkTimerContextErrGotErr(b *testing.B) {
ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
cancel()
for i := 0; i < b.N; i++ {
if err := ctx.Err(); err == nil {
b.Fatal("error")
}
}
}
func BenchmarkTimerContextChannelNoErr(b *testing.B) {
ctx, cancel := context.WithTimeout(context.Background(), 24*time.Hour)
for i := 0; i < b.N; i++ {
select {
case <-ctx.Done():
b.Fatal("error: ctx.Done():", ctx.Err())
default:
}
}
cancel()
}
func BenchmarkTimerContextChannelGotErr(b *testing.B) {
ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
cancel()
for i := 0; i < b.N; i++ {
select {
case <-ctx.Done():
if err := ctx.Err(); err == nil {
b.Fatal("error")
}
default:
b.Fatal("error: !ctx.Done()")
}
}
}

View File

@ -1,401 +0,0 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package primitives_test contains benchmarks for various synchronization primitives
// available in Go.
package primitives_test
import (
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
"unsafe"
)
func BenchmarkSelectClosed(b *testing.B) {
c := make(chan struct{})
close(c)
x := 0
b.ResetTimer()
for i := 0; i < b.N; i++ {
select {
case <-c:
x++
default:
}
}
b.StopTimer()
if x != b.N {
b.Fatal("error")
}
}
func BenchmarkSelectOpen(b *testing.B) {
c := make(chan struct{})
x := 0
b.ResetTimer()
for i := 0; i < b.N; i++ {
select {
case <-c:
default:
x++
}
}
b.StopTimer()
if x != b.N {
b.Fatal("error")
}
}
func BenchmarkAtomicBool(b *testing.B) {
c := int32(0)
x := 0
b.ResetTimer()
for i := 0; i < b.N; i++ {
if atomic.LoadInt32(&c) == 0 {
x++
}
}
b.StopTimer()
if x != b.N {
b.Fatal("error")
}
}
func BenchmarkAtomicValueLoad(b *testing.B) {
c := atomic.Value{}
c.Store(0)
x := 0
b.ResetTimer()
for i := 0; i < b.N; i++ {
if c.Load().(int) == 0 {
x++
}
}
b.StopTimer()
if x != b.N {
b.Fatal("error")
}
}
func BenchmarkAtomicValueStore(b *testing.B) {
c := atomic.Value{}
v := 123
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.Store(v)
}
b.StopTimer()
}
func BenchmarkMutex(b *testing.B) {
c := sync.Mutex{}
x := 0
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.Lock()
x++
c.Unlock()
}
b.StopTimer()
if x != b.N {
b.Fatal("error")
}
}
func BenchmarkRWMutex(b *testing.B) {
c := sync.RWMutex{}
x := 0
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.RLock()
x++
c.RUnlock()
}
b.StopTimer()
if x != b.N {
b.Fatal("error")
}
}
func BenchmarkRWMutexW(b *testing.B) {
c := sync.RWMutex{}
x := 0
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.Lock()
x++
c.Unlock()
}
b.StopTimer()
if x != b.N {
b.Fatal("error")
}
}
func BenchmarkMutexWithDefer(b *testing.B) {
c := sync.Mutex{}
x := 0
b.ResetTimer()
for i := 0; i < b.N; i++ {
func() {
c.Lock()
defer c.Unlock()
x++
}()
}
b.StopTimer()
if x != b.N {
b.Fatal("error")
}
}
func BenchmarkMutexWithClosureDefer(b *testing.B) {
c := sync.Mutex{}
x := 0
b.ResetTimer()
for i := 0; i < b.N; i++ {
func() {
c.Lock()
defer func() { c.Unlock() }()
x++
}()
}
b.StopTimer()
if x != b.N {
b.Fatal("error")
}
}
func BenchmarkMutexWithoutDefer(b *testing.B) {
c := sync.Mutex{}
x := 0
b.ResetTimer()
for i := 0; i < b.N; i++ {
func() {
c.Lock()
x++
c.Unlock()
}()
}
b.StopTimer()
if x != b.N {
b.Fatal("error")
}
}
func BenchmarkAtomicAddInt64(b *testing.B) {
var c int64
b.ResetTimer()
for i := 0; i < b.N; i++ {
atomic.AddInt64(&c, 1)
}
b.StopTimer()
if c != int64(b.N) {
b.Fatal("error")
}
}
func BenchmarkAtomicTimeValueStore(b *testing.B) {
var c atomic.Value
t := time.Now()
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.Store(t)
}
b.StopTimer()
}
func BenchmarkAtomic16BValueStore(b *testing.B) {
var c atomic.Value
t := struct {
a int64
b int64
}{
123, 123,
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.Store(t)
}
b.StopTimer()
}
func BenchmarkAtomic32BValueStore(b *testing.B) {
var c atomic.Value
t := struct {
a int64
b int64
c int64
d int64
}{
123, 123, 123, 123,
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.Store(t)
}
b.StopTimer()
}
func BenchmarkAtomicPointerStore(b *testing.B) {
t := 123
var up unsafe.Pointer
b.ResetTimer()
for i := 0; i < b.N; i++ {
atomic.StorePointer(&up, unsafe.Pointer(&t))
}
b.StopTimer()
}
func BenchmarkAtomicTimePointerStore(b *testing.B) {
t := time.Now()
var up unsafe.Pointer
b.ResetTimer()
for i := 0; i < b.N; i++ {
atomic.StorePointer(&up, unsafe.Pointer(&t))
}
b.StopTimer()
}
func BenchmarkStoreContentionWithAtomic(b *testing.B) {
t := 123
var c unsafe.Pointer
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
atomic.StorePointer(&c, unsafe.Pointer(&t))
}
})
}
func BenchmarkStoreContentionWithMutex(b *testing.B) {
t := 123
var mu sync.Mutex
var c int
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
mu.Lock()
c = t
mu.Unlock()
}
})
_ = c
}
type dummyStruct struct {
a int64
b time.Time
}
func BenchmarkStructStoreContention(b *testing.B) {
d := dummyStruct{}
dp := unsafe.Pointer(&d)
t := time.Now()
for _, j := range []int{100000000, 10000, 0} {
for _, i := range []int{100000, 10} {
b.Run(fmt.Sprintf("CAS/%v/%v", j, i), func(b *testing.B) {
b.SetParallelism(i)
b.RunParallel(func(pb *testing.PB) {
n := &dummyStruct{
b: t,
}
for pb.Next() {
for y := 0; y < j; y++ {
}
for {
v := (*dummyStruct)(atomic.LoadPointer(&dp))
n.a = v.a + 1
if atomic.CompareAndSwapPointer(&dp, unsafe.Pointer(v), unsafe.Pointer(n)) {
n = v
break
}
}
}
})
})
}
}
var mu sync.Mutex
for _, j := range []int{100000000, 10000, 0} {
for _, i := range []int{100000, 10} {
b.Run(fmt.Sprintf("Mutex/%v/%v", j, i), func(b *testing.B) {
b.SetParallelism(i)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
for y := 0; y < j; y++ {
}
mu.Lock()
d.a++
d.b = t
mu.Unlock()
}
})
})
}
}
}
type myFooer struct{}
func (myFooer) Foo() {}
type fooer interface {
Foo()
}
func BenchmarkInterfaceTypeAssertion(b *testing.B) {
// Call a separate function to avoid compiler optimizations.
runInterfaceTypeAssertion(b, myFooer{})
}
func runInterfaceTypeAssertion(b *testing.B, fer interface{}) {
x := 0
b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, ok := fer.(fooer); ok {
x++
}
}
b.StopTimer()
if x != b.N {
b.Fatal("error")
}
}
func BenchmarkStructTypeAssertion(b *testing.B) {
// Call a separate function to avoid compiler optimizations.
runStructTypeAssertion(b, myFooer{})
}
func runStructTypeAssertion(b *testing.B, fer interface{}) {
x := 0
b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, ok := fer.(myFooer); ok {
x++
}
}
b.StopTimer()
if x != b.N {
b.Fatal("error")
}
}
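For readers skimming these numbers, here is a minimal, self-contained sketch (standard library only; the config type and values are illustrative) of the two publication patterns most of the benchmarks above compare: reading a value behind a sync.RWMutex versus loading it from an atomic.Value.
package main
import (
	"fmt"
	"sync"
	"sync/atomic"
)
// config is a small illustrative payload type.
type config struct{ limit int }
func main() {
	// Pattern 1: mutex-guarded access; every reader takes and releases the lock.
	var mu sync.RWMutex
	cfg := &config{limit: 10}
	mu.RLock()
	limit := cfg.limit
	mu.RUnlock()
	fmt.Println("mutex read:", limit)
	// Pattern 2: atomic.Value; readers do a single atomic load,
	// writers publish a brand-new copy with Store.
	var v atomic.Value
	v.Store(&config{limit: 20})
	fmt.Println("atomic read:", v.Load().(*config).limit)
}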

View File

@ -1,187 +0,0 @@
#!/bin/bash
rpcs=(1)
conns=(1)
warmup=10
dur=10
reqs=(1)
resps=(1)
rpc_types=(unary)
# idx[0] = idx value for rpcs
# idx[1] = idx value for conns
# idx[2] = idx value for reqs
# idx[3] = idx value for resps
# idx[4] = idx value for rpc_types
idx=(0 0 0 0 0)
idx_max=(1 1 1 1 1)
inc()
{
for i in $(seq $((${#idx[@]}-1)) -1 0); do
idx[${i}]=$((${idx[${i}]}+1))
if [ ${idx[${i}]} == ${idx_max[${i}]} ]; then
idx[${i}]=0
else
break
fi
done
local fin
fin=1
# Check to see if we have looped back to the beginning.
for v in ${idx[@]}; do
if [ ${v} != 0 ]; then
fin=0
break
fi
done
if [ ${fin} == 1 ]; then
rm -Rf ${out_dir}
clean_and_die 0
fi
}
clean_and_die() {
rm -Rf ${out_dir}
exit $1
}
run(){
local nr
nr=${rpcs[${idx[0]}]}
local nc
nc=${conns[${idx[1]}]}
req_sz=${reqs[${idx[2]}]}
resp_sz=${resps[${idx[3]}]}
r_type=${rpc_types[${idx[4]}]}
# The following runs one benchmark.
base_port=50051
delta=0
test_name="r_"${nr}"_c_"${nc}"_req_"${req_sz}"_resp_"${resp_sz}"_"${r_type}"_"$(date +%s)
echo "================================================================================"
echo ${test_name}
while :
do
port=$((${base_port}+${delta}))
# Launch the server in background
${out_dir}/server --port=${port} --test_name="Server_"${test_name}&
server_pid=$(echo $!)
# Launch the client
${out_dir}/client --port=${port} --d=${dur} --w=${warmup} --r=${nr} --c=${nc} --req=${req_sz} --resp=${resp_sz} --rpc_type=${r_type} --test_name="client_"${test_name}
client_status=$(echo $?)
kill -INT ${server_pid}
wait ${server_pid}
if [ ${client_status} == 0 ]; then
break
fi
delta=$((${delta}+1))
if [ ${delta} == 10 ]; then
echo "Continuous 10 failed runs. Exiting now."
rm -Rf ${out_dir}
clean_and_die 1
fi
done
}
set_param(){
local argname=$1
shift
local idx=$1
shift
if [ $# -eq 0 ]; then
echo "${argname} not specified"
exit 1
fi
PARAM=($(echo $1 | sed 's/,/ /g'))
if [ ${idx} -lt 0 ]; then
return
fi
idx_max[${idx}]=${#PARAM[@]}
}
while [ $# -gt 0 ]; do
case "$1" in
-r)
shift
set_param "number of rpcs" 0 $1
rpcs=(${PARAM[@]})
shift
;;
-c)
shift
set_param "number of connections" 1 $1
conns=(${PARAM[@]})
shift
;;
-w)
shift
set_param "warm-up period" -1 $1
warmup=${PARAM}
shift
;;
-d)
shift
set_param "duration" -1 $1
dur=${PARAM}
shift
;;
-req)
shift
set_param "request size" 2 $1
reqs=(${PARAM[@]})
shift
;;
-resp)
shift
set_param "response size" 3 $1
resps=(${PARAM[@]})
shift
;;
-rpc_type)
shift
set_param "rpc type" 4 $1
rpc_types=(${PARAM[@]})
shift
;;
-h|--help)
echo "Following are valid options:"
echo
echo "-h, --help show brief help"
echo "-w warm-up duration in seconds, default value is 10"
echo "-d benchmark duration in seconds, default value is 60"
echo ""
echo "Each of the following can have multiple comma separated values."
echo ""
echo "-r number of RPCs, default value is 1"
echo "-c number of Connections, default value is 1"
echo "-req req size in bytes, default value is 1"
echo "-resp resp size in bytes, default value is 1"
echo "-rpc_type valid values are unary|streaming, default is unary"
;;
*)
echo "Incorrect option $1"
exit 1
;;
esac
done
# Build server and client
out_dir=$(mktemp -d oss_benchXXX)
go build -o ${out_dir}/server $GOPATH/src/google.golang.org/grpc/benchmark/server/main.go && go build -o ${out_dir}/client $GOPATH/src/google.golang.org/grpc/benchmark/client/main.go
if [ $? != 0 ]; then
clean_and_die 1
fi
while :
do
run
inc
done

View File

@ -1,81 +0,0 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package main
import (
"flag"
"fmt"
"net"
_ "net/http/pprof"
"os"
"os/signal"
"runtime"
"runtime/pprof"
"time"
"google.golang.org/grpc/benchmark"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/syscall"
)
var (
port = flag.String("port", "50051", "Localhost port to listen on.")
testName = flag.String("test_name", "", "Name of the test used for creating profiles.")
)
func main() {
flag.Parse()
if *testName == "" {
grpclog.Fatalf("test name not set")
}
lis, err := net.Listen("tcp", ":"+*port)
if err != nil {
grpclog.Fatalf("Failed to listen: %v", err)
}
defer lis.Close()
cf, err := os.Create("/tmp/" + *testName + ".cpu")
if err != nil {
grpclog.Fatalf("Failed to create file: %v", err)
}
defer cf.Close()
pprof.StartCPUProfile(cf)
cpuBeg := syscall.GetCPUTime()
// Launch server in a separate goroutine.
stop := benchmark.StartServer(benchmark.ServerInfo{Type: "protobuf", Listener: lis})
// Wait on OS terminate signal.
ch := make(chan os.Signal, 1)
signal.Notify(ch, os.Interrupt)
<-ch
cpu := time.Duration(syscall.GetCPUTime() - cpuBeg)
stop()
pprof.StopCPUProfile()
mf, err := os.Create("/tmp/" + *testName + ".mem")
if err != nil {
grpclog.Fatalf("Failed to create file: %v", err)
}
defer mf.Close()
runtime.GC() // materialize all statistics
if err := pprof.WriteHeapProfile(mf); err != nil {
grpclog.Fatalf("Failed to write memory profile: %v", err)
}
fmt.Println("Server CPU utilization:", cpu)
fmt.Println("Server CPU profile:", cf.Name())
fmt.Println("Server Mem Profile:", mf.Name())
}

View File

@ -1,222 +0,0 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package stats
import (
"bytes"
"fmt"
"io"
"log"
"math"
"strconv"
"strings"
)
// Histogram accumulates values in the form of a histogram with
// exponentially increasing bucket sizes.
type Histogram struct {
// Count is the total number of values added to the histogram.
Count int64
// Sum is the sum of all the values added to the histogram.
Sum int64
// SumOfSquares is the sum of squares of all values.
SumOfSquares int64
// Min is the minimum of all the values added to the histogram.
Min int64
// Max is the maximum of all the values added to the histogram.
Max int64
// Buckets contains all the buckets of the histogram.
Buckets []HistogramBucket
opts HistogramOptions
logBaseBucketSize float64
oneOverLogOnePlusGrowthFactor float64
}
// HistogramOptions contains the parameters that define the histogram's buckets.
// The first bucket of the created histogram (with index 0) contains [min, min+n)
// where n = BaseBucketSize, min = MinValue.
// Bucket i (i>=1) contains [min + n * m^(i-1), min + n * m^i), where m = 1+GrowthFactor.
// The type of the values is int64.
type HistogramOptions struct {
// NumBuckets is the number of buckets.
NumBuckets int
// GrowthFactor is the growth factor of the buckets. A value of 0.1
// indicates that bucket N+1 will be 10% larger than bucket N.
GrowthFactor float64
// BaseBucketSize is the size of the first bucket.
BaseBucketSize float64
// MinValue is the lower bound of the first bucket.
MinValue int64
}
// HistogramBucket represents one histogram bucket.
type HistogramBucket struct {
// LowBound is the lower bound of the bucket.
LowBound float64
// Count is the number of values in the bucket.
Count int64
}
// NewHistogram returns a pointer to a new Histogram object that was created
// with the provided options.
func NewHistogram(opts HistogramOptions) *Histogram {
if opts.NumBuckets == 0 {
opts.NumBuckets = 32
}
if opts.BaseBucketSize == 0.0 {
opts.BaseBucketSize = 1.0
}
h := Histogram{
Buckets: make([]HistogramBucket, opts.NumBuckets),
Min: math.MaxInt64,
Max: math.MinInt64,
opts: opts,
logBaseBucketSize: math.Log(opts.BaseBucketSize),
oneOverLogOnePlusGrowthFactor: 1 / math.Log(1+opts.GrowthFactor),
}
m := 1.0 + opts.GrowthFactor
delta := opts.BaseBucketSize
h.Buckets[0].LowBound = float64(opts.MinValue)
for i := 1; i < opts.NumBuckets; i++ {
h.Buckets[i].LowBound = float64(opts.MinValue) + delta
delta = delta * m
}
return &h
}
// Print writes textual output of the histogram values.
func (h *Histogram) Print(w io.Writer) {
h.PrintWithUnit(w, 1)
}
// PrintWithUnit writes textual output of the histogram values.
// Data in the histogram is divided by unit before being printed.
func (h *Histogram) PrintWithUnit(w io.Writer, unit float64) {
avg := float64(h.Sum) / float64(h.Count)
fmt.Fprintf(w, "Count: %d Min: %5.1f Max: %5.1f Avg: %.2f\n", h.Count, float64(h.Min)/unit, float64(h.Max)/unit, avg/unit)
fmt.Fprintf(w, "%s\n", strings.Repeat("-", 60))
if h.Count <= 0 {
return
}
maxBucketDigitLen := len(strconv.FormatFloat(h.Buckets[len(h.Buckets)-1].LowBound, 'f', 6, 64))
if maxBucketDigitLen < 3 {
// For "inf".
maxBucketDigitLen = 3
}
maxCountDigitLen := len(strconv.FormatInt(h.Count, 10))
percentMulti := 100 / float64(h.Count)
accCount := int64(0)
for i, b := range h.Buckets {
fmt.Fprintf(w, "[%*f, ", maxBucketDigitLen, b.LowBound/unit)
if i+1 < len(h.Buckets) {
fmt.Fprintf(w, "%*f)", maxBucketDigitLen, h.Buckets[i+1].LowBound/unit)
} else {
fmt.Fprintf(w, "%*s)", maxBucketDigitLen, "inf")
}
accCount += b.Count
fmt.Fprintf(w, " %*d %5.1f%% %5.1f%%", maxCountDigitLen, b.Count, float64(b.Count)*percentMulti, float64(accCount)*percentMulti)
const barScale = 0.1
barLength := int(float64(b.Count)*percentMulti*barScale + 0.5)
fmt.Fprintf(w, " %s\n", strings.Repeat("#", barLength))
}
}
// String returns the textual output of the histogram values as a string.
func (h *Histogram) String() string {
var b bytes.Buffer
h.Print(&b)
return b.String()
}
// Clear resets all the content of histogram.
func (h *Histogram) Clear() {
h.Count = 0
h.Sum = 0
h.SumOfSquares = 0
h.Min = math.MaxInt64
h.Max = math.MinInt64
for i := range h.Buckets {
h.Buckets[i].Count = 0
}
}
// Opts returns a copy of the options used to create the Histogram.
func (h *Histogram) Opts() HistogramOptions {
return h.opts
}
// Add adds a value to the histogram.
func (h *Histogram) Add(value int64) error {
bucket, err := h.findBucket(value)
if err != nil {
return err
}
h.Buckets[bucket].Count++
h.Count++
h.Sum += value
h.SumOfSquares += value * value
if value < h.Min {
h.Min = value
}
if value > h.Max {
h.Max = value
}
return nil
}
func (h *Histogram) findBucket(value int64) (int, error) {
delta := float64(value - h.opts.MinValue)
var b int
if delta >= h.opts.BaseBucketSize {
// b = log_{1+growthFactor} (delta / baseBucketSize) + 1
// = log(delta / baseBucketSize) / log(1+growthFactor) + 1
// = (log(delta) - log(baseBucketSize)) * (1 / log(1+growthFactor)) + 1
b = int((math.Log(delta)-h.logBaseBucketSize)*h.oneOverLogOnePlusGrowthFactor + 1)
}
if b >= len(h.Buckets) {
return 0, fmt.Errorf("no bucket for value: %d", value)
}
return b, nil
}
// Merge takes another histogram h2, and merges its content into h.
// The two histograms must have been created with equivalent HistogramOptions.
func (h *Histogram) Merge(h2 *Histogram) {
if h.opts != h2.opts {
log.Fatalf("failed to merge histograms, created by inequivalent options")
}
h.Count += h2.Count
h.Sum += h2.Sum
h.SumOfSquares += h2.SumOfSquares
if h2.Min < h.Min {
h.Min = h2.Min
}
if h2.Max > h.Max {
h.Max = h2.Max
}
for i, b := range h2.Buckets {
h.Buckets[i].Count += b.Count
}
}
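A minimal usage sketch for the Histogram type above; the options and values are illustrative, and the import path is assumed to be the package's upstream location (google.golang.org/grpc/benchmark/stats). Buckets grow by GrowthFactor starting at BaseBucketSize, and Add reports an error when a value falls past the last bucket.
package main
import (
	"os"
	"google.golang.org/grpc/benchmark/stats"
)
func main() {
	h := stats.NewHistogram(stats.HistogramOptions{
		NumBuckets:     16,
		GrowthFactor:   0.2, // each bucket is 20% wider than the previous one
		BaseBucketSize: 1.0,
		MinValue:       0,
	})
	for _, v := range []int64{1, 2, 4, 8, 12} {
		if err := h.Add(v); err != nil {
			// Add fails only when the value is beyond the last bucket's range.
			panic(err)
		}
	}
	h.Print(os.Stdout) // count, min, max, avg, then one line per bucket
}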

View File

@ -1,302 +0,0 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package stats
import (
"bytes"
"fmt"
"io"
"math"
"sort"
"strconv"
"time"
)
// Features contains the fields that configure a benchmark run.
type Features struct {
NetworkMode string
EnableTrace bool
Latency time.Duration
Kbps int
Mtu int
MaxConcurrentCalls int
ReqSizeBytes int
RespSizeBytes int
EnableCompressor bool
EnableChannelz bool
}
// String returns the textual output of the Features as a string.
func (f Features) String() string {
return fmt.Sprintf("traceMode_%t-latency_%s-kbps_%#v-MTU_%#v-maxConcurrentCalls_"+
"%#v-reqSize_%#vB-respSize_%#vB-Compressor_%t", f.EnableTrace,
f.Latency.String(), f.Kbps, f.Mtu, f.MaxConcurrentCalls, f.ReqSizeBytes, f.RespSizeBytes, f.EnableCompressor)
}
// ConciseString returns the concise textual output of the Features as a string,
// skipping settings that still have their default value.
func (f Features) ConciseString() string {
noneEmptyPos := []bool{f.EnableTrace, f.Latency != 0, f.Kbps != 0, f.Mtu != 0, true, true, true, f.EnableCompressor, f.EnableChannelz}
return PartialPrintString(noneEmptyPos, f, false)
}
// PartialPrintString prints the selected features, using a different format depending on whether the settings are shared.
func PartialPrintString(noneEmptyPos []bool, f Features, shared bool) string {
s := ""
var (
prefix, suffix, linker string
isNetwork bool
)
if shared {
suffix = "\n"
linker = ": "
} else {
prefix = "-"
linker = "_"
}
if noneEmptyPos[0] {
s += fmt.Sprintf("%sTrace%s%t%s", prefix, linker, f.EnableTrace, suffix)
}
if shared && f.NetworkMode != "" {
s += fmt.Sprintf("Network: %s \n", f.NetworkMode)
isNetwork = true
}
if !isNetwork {
if noneEmptyPos[1] {
s += fmt.Sprintf("%slatency%s%s%s", prefix, linker, f.Latency.String(), suffix)
}
if noneEmptyPos[2] {
s += fmt.Sprintf("%skbps%s%#v%s", prefix, linker, f.Kbps, suffix)
}
if noneEmptyPos[3] {
s += fmt.Sprintf("%sMTU%s%#v%s", prefix, linker, f.Mtu, suffix)
}
}
if noneEmptyPos[4] {
s += fmt.Sprintf("%sCallers%s%#v%s", prefix, linker, f.MaxConcurrentCalls, suffix)
}
if noneEmptyPos[5] {
s += fmt.Sprintf("%sreqSize%s%#vB%s", prefix, linker, f.ReqSizeBytes, suffix)
}
if noneEmptyPos[6] {
s += fmt.Sprintf("%srespSize%s%#vB%s", prefix, linker, f.RespSizeBytes, suffix)
}
if noneEmptyPos[7] {
s += fmt.Sprintf("%sCompressor%s%t%s", prefix, linker, f.EnableCompressor, suffix)
}
if noneEmptyPos[8] {
s += fmt.Sprintf("%sChannelz%s%t%s", prefix, linker, f.EnableChannelz, suffix)
}
return s
}
type percentLatency struct {
Percent int
Value time.Duration
}
// BenchResults records features and result of a benchmark.
type BenchResults struct {
RunMode string
Features Features
Latency []percentLatency
Operations int
NsPerOp int64
AllocedBytesPerOp int64
AllocsPerOp int64
SharedPosion []bool
}
// SetBenchmarkResult sets features of benchmark and basic results.
func (stats *Stats) SetBenchmarkResult(mode string, features Features, o int, allocdBytes, allocs int64, sharedPos []bool) {
stats.result.RunMode = mode
stats.result.Features = features
stats.result.Operations = o
stats.result.AllocedBytesPerOp = allocdBytes
stats.result.AllocsPerOp = allocs
stats.result.SharedPosion = sharedPos
}
// GetBenchmarkResults returns the result of the benchmark including features and result.
func (stats *Stats) GetBenchmarkResults() BenchResults {
return stats.result
}
// BenchString outputs the latency stats formatted as time + unit.
func (stats *Stats) BenchString() string {
stats.maybeUpdate()
s := stats.result
res := s.RunMode + "-" + s.Features.String() + ": \n"
if len(s.Latency) != 0 {
var statsUnit = s.Latency[0].Value
var timeUnit = fmt.Sprintf("%v", statsUnit)[1:]
for i := 1; i < len(s.Latency)-1; i++ {
res += fmt.Sprintf("%d_Latency: %s %s \t", s.Latency[i].Percent,
strconv.FormatFloat(float64(s.Latency[i].Value)/float64(statsUnit), 'f', 4, 64), timeUnit)
}
res += fmt.Sprintf("Avg latency: %s %s \t",
strconv.FormatFloat(float64(s.Latency[len(s.Latency)-1].Value)/float64(statsUnit), 'f', 4, 64), timeUnit)
}
res += fmt.Sprintf("Count: %v \t", s.Operations)
res += fmt.Sprintf("%v Bytes/op\t", s.AllocedBytesPerOp)
res += fmt.Sprintf("%v Allocs/op\t", s.AllocsPerOp)
return res
}
// Stats is a simple helper for gathering additional statistics, such as a
// histogram of latencies, during benchmarks. It is not thread safe.
type Stats struct {
numBuckets int
unit time.Duration
min, max int64
histogram *Histogram
durations durationSlice
dirty bool
sortLatency bool
result BenchResults
}
type durationSlice []time.Duration
// NewStats creates a new Stats instance. If numBuckets is not positive,
// the default value (16) will be used.
func NewStats(numBuckets int) *Stats {
if numBuckets <= 0 {
numBuckets = 16
}
return &Stats{
// Use one more bucket for the last unbounded bucket.
numBuckets: numBuckets + 1,
durations: make(durationSlice, 0, 100000),
}
}
// Add adds an elapsed time per operation to the stats.
func (stats *Stats) Add(d time.Duration) {
stats.durations = append(stats.durations, d)
stats.dirty = true
}
// Clear resets the stats, removing all values.
func (stats *Stats) Clear() {
stats.durations = stats.durations[:0]
stats.histogram = nil
stats.dirty = false
stats.result = BenchResults{}
}
// durationSlice implements sort.Interface so durations can be sorted.
func (a durationSlice) Len() int { return len(a) }
func (a durationSlice) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a durationSlice) Less(i, j int) bool { return a[i] < a[j] }
func max(a, b int64) int64 {
if a > b {
return a
}
return b
}
// maybeUpdate updates the internal stat data if any new stats have been
// added since the last update.
func (stats *Stats) maybeUpdate() {
if !stats.dirty {
return
}
if stats.sortLatency {
sort.Sort(stats.durations)
stats.min = int64(stats.durations[0])
stats.max = int64(stats.durations[len(stats.durations)-1])
}
stats.min = math.MaxInt64
stats.max = 0
for _, d := range stats.durations {
if stats.min > int64(d) {
stats.min = int64(d)
}
if stats.max < int64(d) {
stats.max = int64(d)
}
}
// Use the largest unit that can represent the minimum time duration.
stats.unit = time.Nanosecond
for _, u := range []time.Duration{time.Microsecond, time.Millisecond, time.Second} {
if stats.min <= int64(u) {
break
}
stats.unit = u
}
numBuckets := stats.numBuckets
if n := int(stats.max - stats.min + 1); n < numBuckets {
numBuckets = n
}
stats.histogram = NewHistogram(HistogramOptions{
NumBuckets: numBuckets,
// max-min(lower bound of last bucket) = (1 + growthFactor)^(numBuckets-2) * baseBucketSize.
GrowthFactor: math.Pow(float64(stats.max-stats.min), 1/float64(numBuckets-2)) - 1,
BaseBucketSize: 1.0,
MinValue: stats.min})
for _, d := range stats.durations {
stats.histogram.Add(int64(d))
}
stats.dirty = false
if stats.durations.Len() != 0 {
var percentToObserve = []int{50, 90, 99}
// The first record stores the time unit used for the latency results.
stats.result.Latency = append(stats.result.Latency, percentLatency{Percent: -1, Value: stats.unit})
for _, position := range percentToObserve {
stats.result.Latency = append(stats.result.Latency, percentLatency{Percent: position, Value: stats.durations[max(stats.histogram.Count*int64(position)/100-1, 0)]})
}
// The last record stores the average latency.
avg := float64(stats.histogram.Sum) / float64(stats.histogram.Count)
stats.result.Latency = append(stats.result.Latency, percentLatency{Percent: -1, Value: time.Duration(avg)})
}
}
// SortLatency causes the recorded durations to be sorted before the stats are computed.
func (stats *Stats) SortLatency() {
stats.sortLatency = true
}
// Print writes textual output of the Stats.
func (stats *Stats) Print(w io.Writer) {
stats.maybeUpdate()
if stats.histogram == nil {
fmt.Fprint(w, "Histogram (empty)\n")
} else {
fmt.Fprintf(w, "Histogram (unit: %s)\n", fmt.Sprintf("%v", stats.unit)[1:])
stats.histogram.PrintWithUnit(w, float64(stats.unit))
}
}
// String returns the textual output of the Stats as string.
func (stats *Stats) String() string {
var b bytes.Buffer
stats.Print(&b)
return b.String()
}
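A corresponding sketch for the Stats helper above (the durations are illustrative, and the import path is again assumed to be google.golang.org/grpc/benchmark/stats): Add records per-operation durations, and String lazily builds the histogram via maybeUpdate.
package main
import (
	"fmt"
	"time"
	"google.golang.org/grpc/benchmark/stats"
)
func main() {
	st := stats.NewStats(16) // 16 buckets plus one unbounded bucket
	for _, d := range []time.Duration{
		120 * time.Microsecond,
		250 * time.Microsecond,
		3 * time.Millisecond,
	} {
		st.Add(d)
	}
	// String triggers maybeUpdate, which picks a unit and builds the histogram.
	fmt.Print(st.String())
}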

View File

@ -1,208 +0,0 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package stats
import (
"bufio"
"bytes"
"fmt"
"os"
"runtime"
"sort"
"strings"
"sync"
"testing"
)
var (
curB *testing.B
curBenchName string
curStats map[string]*Stats
orgStdout *os.File
nextOutPos int
injectCond *sync.Cond
injectDone chan struct{}
)
// AddStats adds a new unnamed Stats instance to the current benchmark. You need
// to run benchmarks by calling RunTestMain() to inject the stats into the
// benchmark results. If numBuckets is not positive, the default value (16) will
// be used. Please note that this calls b.ResetTimer() since it may be blocked
// until the previous benchmark's stats are printed out. So AddStats() should
// typically be called at the very beginning of each benchmark function.
func AddStats(b *testing.B, numBuckets int) *Stats {
return AddStatsWithName(b, "", numBuckets)
}
// AddStatsWithName adds a new named Stats instance to the current benchmark.
// With this, you can add multiple stats in a single benchmark. You need
// to run benchmarks by calling RunTestMain() to inject the stats into the
// benchmark results. If numBuckets is not positive, the default value (16) will
// be used. Please note that this calls b.ResetTimer() since it may be blocked
// until the previous benchmark's stats are printed out. So AddStatsWithName()
// should typically be called at the very beginning of each benchmark function.
func AddStatsWithName(b *testing.B, name string, numBuckets int) *Stats {
var benchName string
for i := 1; ; i++ {
pc, _, _, ok := runtime.Caller(i)
if !ok {
panic("benchmark function not found")
}
p := strings.Split(runtime.FuncForPC(pc).Name(), ".")
benchName = p[len(p)-1]
if strings.HasPrefix(benchName, "run") {
break
}
}
procs := runtime.GOMAXPROCS(-1)
if procs != 1 {
benchName = fmt.Sprintf("%s-%d", benchName, procs)
}
stats := NewStats(numBuckets)
if injectCond != nil {
// We need to wait until the previous benchmark's stats are printed out.
injectCond.L.Lock()
for curB != nil && curBenchName != benchName {
injectCond.Wait()
}
curB = b
curBenchName = benchName
curStats[name] = stats
injectCond.L.Unlock()
}
b.ResetTimer()
return stats
}
// RunTestMain runs the tests with enabling injection of benchmark stats. It
// returns an exit code to pass to os.Exit.
func RunTestMain(m *testing.M) int {
startStatsInjector()
defer stopStatsInjector()
return m.Run()
}
// startStatsInjector starts stats injection to benchmark results.
func startStatsInjector() {
orgStdout = os.Stdout
r, w, _ := os.Pipe()
os.Stdout = w
nextOutPos = 0
resetCurBenchStats()
injectCond = sync.NewCond(&sync.Mutex{})
injectDone = make(chan struct{})
go func() {
defer close(injectDone)
scanner := bufio.NewScanner(r)
scanner.Split(splitLines)
for scanner.Scan() {
injectStatsIfFinished(scanner.Text())
}
if err := scanner.Err(); err != nil {
panic(err)
}
}()
}
// stopStatsInjector stops stats injection and restores os.Stdout.
func stopStatsInjector() {
os.Stdout.Close()
<-injectDone
injectCond = nil
os.Stdout = orgStdout
}
// splitLines is a split function for a bufio.Scanner that returns each line
// of text, teeing the text to the original stdout even before each line ends.
func splitLines(data []byte, eof bool) (advance int, token []byte, err error) {
if eof && len(data) == 0 {
return 0, nil, nil
}
if i := bytes.IndexByte(data, '\n'); i >= 0 {
orgStdout.Write(data[nextOutPos : i+1])
nextOutPos = 0
return i + 1, data[0:i], nil
}
orgStdout.Write(data[nextOutPos:])
nextOutPos = len(data)
if eof {
// This is a final, non-terminated line. Return it.
return len(data), data, nil
}
return 0, nil, nil
}
// injectStatsIfFinished prints out the stats if the current benchmark finishes.
func injectStatsIfFinished(line string) {
injectCond.L.Lock()
defer injectCond.L.Unlock()
// We assume that the benchmark results start with "Benchmark".
if curB == nil || !strings.HasPrefix(line, "Benchmark") {
return
}
if !curB.Failed() {
// Output all stats in alphabetical order.
names := make([]string, 0, len(curStats))
for name := range curStats {
names = append(names, name)
}
sort.Strings(names)
for _, name := range names {
stats := curStats[name]
// The output of stats starts with a header like "Histogram (unit: ms)"
// followed by statistical properties and the buckets. Add the stats name
// if it is a named stats instance, and indent the lines as Go testing output does.
lines := strings.Split(stats.String(), "\n")
if n := len(lines); n > 0 {
if name != "" {
name = ": " + name
}
fmt.Fprintf(orgStdout, "--- %s%s\n", lines[0], name)
for _, line := range lines[1 : n-1] {
fmt.Fprintf(orgStdout, "\t%s\n", line)
}
}
}
}
resetCurBenchStats()
injectCond.Signal()
}
// resetCurBenchStats resets the current benchmark stats.
func resetCurBenchStats() {
curB = nil
curBenchName = ""
curStats = make(map[string]*Stats)
}

View File

@ -1,386 +0,0 @@
/*
*
* Copyright 2016 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package main
import (
"context"
"flag"
"math"
"runtime"
"sync"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/benchmark"
testpb "google.golang.org/grpc/benchmark/grpc_testing"
"google.golang.org/grpc/benchmark/stats"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/syscall"
"google.golang.org/grpc/status"
"google.golang.org/grpc/testdata"
)
var caFile = flag.String("ca_file", "", "The file containing the CA root cert file")
type lockingHistogram struct {
mu sync.Mutex
histogram *stats.Histogram
}
func (h *lockingHistogram) add(value int64) {
h.mu.Lock()
defer h.mu.Unlock()
h.histogram.Add(value)
}
// swap sets h.histogram to o and returns its old value.
func (h *lockingHistogram) swap(o *stats.Histogram) *stats.Histogram {
h.mu.Lock()
defer h.mu.Unlock()
old := h.histogram
h.histogram = o
return old
}
func (h *lockingHistogram) mergeInto(merged *stats.Histogram) {
h.mu.Lock()
defer h.mu.Unlock()
merged.Merge(h.histogram)
}
type benchmarkClient struct {
closeConns func()
stop chan bool
lastResetTime time.Time
histogramOptions stats.HistogramOptions
lockingHistograms []lockingHistogram
rusageLastReset *syscall.Rusage
}
func printClientConfig(config *testpb.ClientConfig) {
// Some config options are ignored:
// - client type:
// will always create sync client
// - async client threads.
// - core list
grpclog.Infof(" * client type: %v (ignored, always creates sync client)", config.ClientType)
grpclog.Infof(" * async client threads: %v (ignored)", config.AsyncClientThreads)
// TODO: use cores specified by CoreList when setting list of cores is supported in go.
grpclog.Infof(" * core list: %v (ignored)", config.CoreList)
grpclog.Infof(" - security params: %v", config.SecurityParams)
grpclog.Infof(" - core limit: %v", config.CoreLimit)
grpclog.Infof(" - payload config: %v", config.PayloadConfig)
grpclog.Infof(" - rpcs per chann: %v", config.OutstandingRpcsPerChannel)
grpclog.Infof(" - channel number: %v", config.ClientChannels)
grpclog.Infof(" - load params: %v", config.LoadParams)
grpclog.Infof(" - rpc type: %v", config.RpcType)
grpclog.Infof(" - histogram params: %v", config.HistogramParams)
grpclog.Infof(" - server targets: %v", config.ServerTargets)
}
func setupClientEnv(config *testpb.ClientConfig) {
// Use all CPU cores available on the machine by default.
// TODO: Revisit this for the optimal default setup.
if config.CoreLimit > 0 {
runtime.GOMAXPROCS(int(config.CoreLimit))
} else {
runtime.GOMAXPROCS(runtime.NumCPU())
}
}
// createConns creates connections according to given config.
// It returns the connections and corresponding function to close them.
// It returns non-nil error if there is anything wrong.
func createConns(config *testpb.ClientConfig) ([]*grpc.ClientConn, func(), error) {
var opts []grpc.DialOption
// Sanity check for client type.
switch config.ClientType {
case testpb.ClientType_SYNC_CLIENT:
case testpb.ClientType_ASYNC_CLIENT:
default:
return nil, nil, status.Errorf(codes.InvalidArgument, "unknown client type: %v", config.ClientType)
}
// Check and set security options.
if config.SecurityParams != nil {
if *caFile == "" {
*caFile = testdata.Path("ca.pem")
}
creds, err := credentials.NewClientTLSFromFile(*caFile, config.SecurityParams.ServerHostOverride)
if err != nil {
return nil, nil, status.Errorf(codes.InvalidArgument, "failed to create TLS credentials %v", err)
}
opts = append(opts, grpc.WithTransportCredentials(creds))
} else {
opts = append(opts, grpc.WithInsecure())
}
// Use byteBufCodec if it is required.
if config.PayloadConfig != nil {
switch config.PayloadConfig.Payload.(type) {
case *testpb.PayloadConfig_BytebufParams:
opts = append(opts, grpc.WithDefaultCallOptions(grpc.CallCustomCodec(byteBufCodec{})))
case *testpb.PayloadConfig_SimpleParams:
default:
return nil, nil, status.Errorf(codes.InvalidArgument, "unknown payload config: %v", config.PayloadConfig)
}
}
// Create connections.
connCount := int(config.ClientChannels)
conns := make([]*grpc.ClientConn, connCount)
for connIndex := 0; connIndex < connCount; connIndex++ {
conns[connIndex] = benchmark.NewClientConn(config.ServerTargets[connIndex%len(config.ServerTargets)], opts...)
}
return conns, func() {
for _, conn := range conns {
conn.Close()
}
}, nil
}
func performRPCs(config *testpb.ClientConfig, conns []*grpc.ClientConn, bc *benchmarkClient) error {
// Read payload size and type from config.
var (
payloadReqSize, payloadRespSize int
payloadType string
)
if config.PayloadConfig != nil {
switch c := config.PayloadConfig.Payload.(type) {
case *testpb.PayloadConfig_BytebufParams:
payloadReqSize = int(c.BytebufParams.ReqSize)
payloadRespSize = int(c.BytebufParams.RespSize)
payloadType = "bytebuf"
case *testpb.PayloadConfig_SimpleParams:
payloadReqSize = int(c.SimpleParams.ReqSize)
payloadRespSize = int(c.SimpleParams.RespSize)
payloadType = "protobuf"
default:
return status.Errorf(codes.InvalidArgument, "unknown payload config: %v", config.PayloadConfig)
}
}
// TODO add open loop distribution.
switch config.LoadParams.Load.(type) {
case *testpb.LoadParams_ClosedLoop:
case *testpb.LoadParams_Poisson:
return status.Errorf(codes.Unimplemented, "unsupported load params: %v", config.LoadParams)
default:
return status.Errorf(codes.InvalidArgument, "unknown load params: %v", config.LoadParams)
}
rpcCountPerConn := int(config.OutstandingRpcsPerChannel)
switch config.RpcType {
case testpb.RpcType_UNARY:
bc.doCloseLoopUnary(conns, rpcCountPerConn, payloadReqSize, payloadRespSize)
// TODO open loop.
case testpb.RpcType_STREAMING:
bc.doCloseLoopStreaming(conns, rpcCountPerConn, payloadReqSize, payloadRespSize, payloadType)
// TODO open loop.
default:
return status.Errorf(codes.InvalidArgument, "unknown rpc type: %v", config.RpcType)
}
return nil
}
func startBenchmarkClient(config *testpb.ClientConfig) (*benchmarkClient, error) {
printClientConfig(config)
// Set running environment like how many cores to use.
setupClientEnv(config)
conns, closeConns, err := createConns(config)
if err != nil {
return nil, err
}
rpcCountPerConn := int(config.OutstandingRpcsPerChannel)
bc := &benchmarkClient{
histogramOptions: stats.HistogramOptions{
NumBuckets: int(math.Log(config.HistogramParams.MaxPossible)/math.Log(1+config.HistogramParams.Resolution)) + 1,
GrowthFactor: config.HistogramParams.Resolution,
BaseBucketSize: (1 + config.HistogramParams.Resolution),
MinValue: 0,
},
lockingHistograms: make([]lockingHistogram, rpcCountPerConn*len(conns)),
stop: make(chan bool),
lastResetTime: time.Now(),
closeConns: closeConns,
rusageLastReset: syscall.GetRusage(),
}
if err = performRPCs(config, conns, bc); err != nil {
// Close all connections if performRPCs failed.
closeConns()
return nil, err
}
return bc, nil
}
func (bc *benchmarkClient) doCloseLoopUnary(conns []*grpc.ClientConn, rpcCountPerConn int, reqSize int, respSize int) {
for ic, conn := range conns {
client := testpb.NewBenchmarkServiceClient(conn)
// For each connection, create rpcCountPerConn goroutines to do rpc.
for j := 0; j < rpcCountPerConn; j++ {
// Create histogram for each goroutine.
idx := ic*rpcCountPerConn + j
bc.lockingHistograms[idx].histogram = stats.NewHistogram(bc.histogramOptions)
// Start a goroutine that records latencies into the histogram created above.
go func(idx int) {
// TODO: do warm up if necessary.
// For now we rely on the worker client to reserve time for warm-up:
// it needs to wait for some time after the client is created before
// starting the benchmark.
done := make(chan bool)
for {
go func() {
start := time.Now()
if err := benchmark.DoUnaryCall(client, reqSize, respSize); err != nil {
select {
case <-bc.stop:
case done <- false:
}
return
}
elapse := time.Since(start)
bc.lockingHistograms[idx].add(int64(elapse))
select {
case <-bc.stop:
case done <- true:
}
}()
select {
case <-bc.stop:
return
case <-done:
}
}
}(idx)
}
}
}
func (bc *benchmarkClient) doCloseLoopStreaming(conns []*grpc.ClientConn, rpcCountPerConn int, reqSize int, respSize int, payloadType string) {
var doRPC func(testpb.BenchmarkService_StreamingCallClient, int, int) error
if payloadType == "bytebuf" {
doRPC = benchmark.DoByteBufStreamingRoundTrip
} else {
doRPC = benchmark.DoStreamingRoundTrip
}
for ic, conn := range conns {
// For each connection, create rpcCountPerConn goroutines to do rpc.
for j := 0; j < rpcCountPerConn; j++ {
c := testpb.NewBenchmarkServiceClient(conn)
stream, err := c.StreamingCall(context.Background())
if err != nil {
grpclog.Fatalf("%v.StreamingCall(_) = _, %v", c, err)
}
// Create histogram for each goroutine.
idx := ic*rpcCountPerConn + j
bc.lockingHistograms[idx].histogram = stats.NewHistogram(bc.histogramOptions)
// Start a goroutine that records latencies into the histogram created above.
go func(idx int) {
// TODO: do warm up if necessary.
// For now we rely on the worker client to reserve time for warm-up:
// it needs to wait for some time after the client is created before
// starting the benchmark.
for {
start := time.Now()
if err := doRPC(stream, reqSize, respSize); err != nil {
return
}
elapse := time.Since(start)
bc.lockingHistograms[idx].add(int64(elapse))
select {
case <-bc.stop:
return
default:
}
}
}(idx)
}
}
}
// getStats returns the stats for the benchmark client.
// It resets lastResetTime and all histograms if argument reset is true.
func (bc *benchmarkClient) getStats(reset bool) *testpb.ClientStats {
var wallTimeElapsed, uTimeElapsed, sTimeElapsed float64
mergedHistogram := stats.NewHistogram(bc.histogramOptions)
if reset {
// Merging histogram may take some time.
// Put all histograms aside and merge later.
toMerge := make([]*stats.Histogram, len(bc.lockingHistograms))
for i := range bc.lockingHistograms {
toMerge[i] = bc.lockingHistograms[i].swap(stats.NewHistogram(bc.histogramOptions))
}
for i := 0; i < len(toMerge); i++ {
mergedHistogram.Merge(toMerge[i])
}
wallTimeElapsed = time.Since(bc.lastResetTime).Seconds()
latestRusage := syscall.GetRusage()
uTimeElapsed, sTimeElapsed = syscall.CPUTimeDiff(bc.rusageLastReset, latestRusage)
bc.rusageLastReset = latestRusage
bc.lastResetTime = time.Now()
} else {
// Merge only, not reset.
for i := range bc.lockingHistograms {
bc.lockingHistograms[i].mergeInto(mergedHistogram)
}
wallTimeElapsed = time.Since(bc.lastResetTime).Seconds()
uTimeElapsed, sTimeElapsed = syscall.CPUTimeDiff(bc.rusageLastReset, syscall.GetRusage())
}
b := make([]uint32, len(mergedHistogram.Buckets))
for i, v := range mergedHistogram.Buckets {
b[i] = uint32(v.Count)
}
return &testpb.ClientStats{
Latencies: &testpb.HistogramData{
Bucket: b,
MinSeen: float64(mergedHistogram.Min),
MaxSeen: float64(mergedHistogram.Max),
Sum: float64(mergedHistogram.Sum),
SumOfSquares: float64(mergedHistogram.SumOfSquares),
Count: float64(mergedHistogram.Count),
},
TimeElapsed: wallTimeElapsed,
TimeUser: uTimeElapsed,
TimeSystem: sTimeElapsed,
}
}
func (bc *benchmarkClient) shutdown() {
close(bc.stop)
bc.closeConns()
}
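The getStats/swap logic above boils down to the following pattern; this is a simplified sketch with illustrative option values (import path again assumed to be google.golang.org/grpc/benchmark/stats). Each goroutine records into its own mutex-guarded histogram, and the reader swaps in a fresh histogram so merging happens off the hot path.
package main
import (
	"fmt"
	"sync"
	"google.golang.org/grpc/benchmark/stats"
)
func main() {
	opts := stats.HistogramOptions{NumBuckets: 16, GrowthFactor: 0.2, BaseBucketSize: 1, MinValue: 0}
	var mu sync.Mutex
	h := stats.NewHistogram(opts)
	// Hot path: record one latency sample.
	mu.Lock()
	h.Add(10)
	mu.Unlock()
	// Reader: swap in a fresh histogram, then merge the old one at leisure.
	mu.Lock()
	old := h
	h = stats.NewHistogram(opts)
	mu.Unlock()
	merged := stats.NewHistogram(opts)
	merged.Merge(old) // requires both histograms to use equivalent options
	fmt.Println("merged count:", merged.Count)
}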

View File

@ -1,184 +0,0 @@
/*
*
* Copyright 2016 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package main
import (
"flag"
"fmt"
"net"
"runtime"
"strconv"
"strings"
"sync"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/benchmark"
testpb "google.golang.org/grpc/benchmark/grpc_testing"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/syscall"
"google.golang.org/grpc/status"
"google.golang.org/grpc/testdata"
)
var (
certFile = flag.String("tls_cert_file", "", "The TLS cert file")
keyFile = flag.String("tls_key_file", "", "The TLS key file")
)
type benchmarkServer struct {
port int
cores int
closeFunc func()
mu sync.RWMutex
lastResetTime time.Time
rusageLastReset *syscall.Rusage
}
func printServerConfig(config *testpb.ServerConfig) {
// Some config options are ignored:
// - server type:
// will always start sync server
// - async server threads
// - core list
grpclog.Infof(" * server type: %v (ignored, always starts sync server)", config.ServerType)
grpclog.Infof(" * async server threads: %v (ignored)", config.AsyncServerThreads)
// TODO: use cores specified by CoreList when setting list of cores is supported in go.
grpclog.Infof(" * core list: %v (ignored)", config.CoreList)
grpclog.Infof(" - security params: %v", config.SecurityParams)
grpclog.Infof(" - core limit: %v", config.CoreLimit)
grpclog.Infof(" - port: %v", config.Port)
grpclog.Infof(" - payload config: %v", config.PayloadConfig)
}
func startBenchmarkServer(config *testpb.ServerConfig, serverPort int) (*benchmarkServer, error) {
printServerConfig(config)
// Use all CPU cores available on the machine by default.
// TODO: Revisit this for the optimal default setup.
numOfCores := runtime.NumCPU()
if config.CoreLimit > 0 {
numOfCores = int(config.CoreLimit)
}
runtime.GOMAXPROCS(numOfCores)
var opts []grpc.ServerOption
// Sanity check for server type.
switch config.ServerType {
case testpb.ServerType_SYNC_SERVER:
case testpb.ServerType_ASYNC_SERVER:
case testpb.ServerType_ASYNC_GENERIC_SERVER:
default:
return nil, status.Errorf(codes.InvalidArgument, "unknown server type: %v", config.ServerType)
}
// Set security options.
if config.SecurityParams != nil {
if *certFile == "" {
*certFile = testdata.Path("server1.pem")
}
if *keyFile == "" {
*keyFile = testdata.Path("server1.key")
}
creds, err := credentials.NewServerTLSFromFile(*certFile, *keyFile)
if err != nil {
grpclog.Fatalf("failed to generate credentials %v", err)
}
opts = append(opts, grpc.Creds(creds))
}
// Priority: config.Port > serverPort > default (0).
port := int(config.Port)
if port == 0 {
port = serverPort
}
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
grpclog.Fatalf("Failed to listen: %v", err)
}
addr := lis.Addr().String()
// Create different benchmark server according to config.
var closeFunc func()
if config.PayloadConfig != nil {
switch payload := config.PayloadConfig.Payload.(type) {
case *testpb.PayloadConfig_BytebufParams:
opts = append(opts, grpc.CustomCodec(byteBufCodec{}))
closeFunc = benchmark.StartServer(benchmark.ServerInfo{
Type: "bytebuf",
Metadata: payload.BytebufParams.RespSize,
Listener: lis,
}, opts...)
case *testpb.PayloadConfig_SimpleParams:
closeFunc = benchmark.StartServer(benchmark.ServerInfo{
Type: "protobuf",
Listener: lis,
}, opts...)
case *testpb.PayloadConfig_ComplexParams:
return nil, status.Errorf(codes.Unimplemented, "unsupported payload config: %v", config.PayloadConfig)
default:
return nil, status.Errorf(codes.InvalidArgument, "unknown payload config: %v", config.PayloadConfig)
}
} else {
// Start protobuf server if payload config is nil.
closeFunc = benchmark.StartServer(benchmark.ServerInfo{
Type: "protobuf",
Listener: lis,
}, opts...)
}
grpclog.Infof("benchmark server listening at %v", addr)
addrSplitted := strings.Split(addr, ":")
p, err := strconv.Atoi(addrSplitted[len(addrSplitted)-1])
if err != nil {
grpclog.Fatalf("failed to get port number from server address: %v", err)
}
return &benchmarkServer{
port: p,
cores: numOfCores,
closeFunc: closeFunc,
lastResetTime: time.Now(),
rusageLastReset: syscall.GetRusage(),
}, nil
}
// getStats returns the stats for the benchmark server.
// It resets lastResetTime if argument reset is true.
func (bs *benchmarkServer) getStats(reset bool) *testpb.ServerStats {
bs.mu.RLock()
defer bs.mu.RUnlock()
wallTimeElapsed := time.Since(bs.lastResetTime).Seconds()
rusageLatest := syscall.GetRusage()
uTimeElapsed, sTimeElapsed := syscall.CPUTimeDiff(bs.rusageLastReset, rusageLatest)
if reset {
bs.lastResetTime = time.Now()
bs.rusageLastReset = rusageLatest
}
return &testpb.ServerStats{
TimeElapsed: wallTimeElapsed,
TimeUser: uTimeElapsed,
TimeSystem: sTimeElapsed,
}
}

View File

@ -1,230 +0,0 @@
/*
*
* Copyright 2016 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package main
import (
"context"
"flag"
"fmt"
"io"
"net"
"net/http"
_ "net/http/pprof"
"runtime"
"strconv"
"time"
"google.golang.org/grpc"
testpb "google.golang.org/grpc/benchmark/grpc_testing"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/status"
)
var (
driverPort = flag.Int("driver_port", 10000, "port for communication with driver")
serverPort = flag.Int("server_port", 0, "port for benchmark server if not specified by server config message")
pprofPort = flag.Int("pprof_port", -1, "Port for pprof debug server to listen on. Pprof server doesn't start if unset")
blockProfRate = flag.Int("block_prof_rate", 0, "fraction of goroutine blocking events to report in blocking profile")
)
type byteBufCodec struct {
}
func (byteBufCodec) Marshal(v interface{}) ([]byte, error) {
b, ok := v.(*[]byte)
if !ok {
return nil, fmt.Errorf("failed to marshal: %v is not type of *[]byte", v)
}
return *b, nil
}
func (byteBufCodec) Unmarshal(data []byte, v interface{}) error {
b, ok := v.(*[]byte)
if !ok {
return fmt.Errorf("failed to marshal: %v is not type of *[]byte", v)
}
*b = data
return nil
}
func (byteBufCodec) String() string {
return "bytebuffer"
}
// workerServer implements WorkerService rpc handlers.
// It can create benchmarkServer or benchmarkClient on demand.
type workerServer struct {
stop chan<- bool
serverPort int
}
func (s *workerServer) RunServer(stream testpb.WorkerService_RunServerServer) error {
var bs *benchmarkServer
defer func() {
// Close benchmark server when stream ends.
grpclog.Infof("closing benchmark server")
if bs != nil {
bs.closeFunc()
}
}()
for {
in, err := stream.Recv()
if err == io.EOF {
return nil
}
if err != nil {
return err
}
var out *testpb.ServerStatus
switch argtype := in.Argtype.(type) {
case *testpb.ServerArgs_Setup:
grpclog.Infof("server setup received:")
if bs != nil {
grpclog.Infof("server setup received when server already exists, closing the existing server")
bs.closeFunc()
}
bs, err = startBenchmarkServer(argtype.Setup, s.serverPort)
if err != nil {
return err
}
out = &testpb.ServerStatus{
Stats: bs.getStats(false),
Port: int32(bs.port),
Cores: int32(bs.cores),
}
case *testpb.ServerArgs_Mark:
grpclog.Infof("server mark received:")
grpclog.Infof(" - %v", argtype)
if bs == nil {
return status.Error(codes.InvalidArgument, "server does not exist when mark received")
}
out = &testpb.ServerStatus{
Stats: bs.getStats(argtype.Mark.Reset_),
Port: int32(bs.port),
Cores: int32(bs.cores),
}
}
if err := stream.Send(out); err != nil {
return err
}
}
}
func (s *workerServer) RunClient(stream testpb.WorkerService_RunClientServer) error {
var bc *benchmarkClient
defer func() {
// Shut down benchmark client when stream ends.
grpclog.Infof("shuting down benchmark client")
if bc != nil {
bc.shutdown()
}
}()
for {
in, err := stream.Recv()
if err == io.EOF {
return nil
}
if err != nil {
return err
}
var out *testpb.ClientStatus
switch t := in.Argtype.(type) {
case *testpb.ClientArgs_Setup:
grpclog.Infof("client setup received:")
if bc != nil {
grpclog.Infof("client setup received when client already exists, shuting down the existing client")
bc.shutdown()
}
bc, err = startBenchmarkClient(t.Setup)
if err != nil {
return err
}
out = &testpb.ClientStatus{
Stats: bc.getStats(false),
}
case *testpb.ClientArgs_Mark:
grpclog.Infof("client mark received:")
grpclog.Infof(" - %v", t)
if bc == nil {
return status.Error(codes.InvalidArgument, "client does not exist when mark received")
}
out = &testpb.ClientStatus{
Stats: bc.getStats(t.Mark.Reset_),
}
}
if err := stream.Send(out); err != nil {
return err
}
}
}
func (s *workerServer) CoreCount(ctx context.Context, in *testpb.CoreRequest) (*testpb.CoreResponse, error) {
grpclog.Infof("core count: %v", runtime.NumCPU())
return &testpb.CoreResponse{Cores: int32(runtime.NumCPU())}, nil
}
func (s *workerServer) QuitWorker(ctx context.Context, in *testpb.Void) (*testpb.Void, error) {
grpclog.Infof("quitting worker")
s.stop <- true
return &testpb.Void{}, nil
}
func main() {
grpc.EnableTracing = false
flag.Parse()
lis, err := net.Listen("tcp", ":"+strconv.Itoa(*driverPort))
if err != nil {
grpclog.Fatalf("failed to listen: %v", err)
}
grpclog.Infof("worker listening at port %v", *driverPort)
s := grpc.NewServer()
stop := make(chan bool)
testpb.RegisterWorkerServiceServer(s, &workerServer{
stop: stop,
serverPort: *serverPort,
})
go func() {
<-stop
// Wait for 1 second before stopping the server to make sure the return value of QuitWorker is sent to the client.
// TODO revise this once server graceful stop is supported in gRPC.
time.Sleep(time.Second)
s.Stop()
}()
runtime.SetBlockProfileRate(*blockProfRate)
if *pprofPort >= 0 {
go func() {
grpclog.Infoln("Starting pprof server on port " + strconv.Itoa(*pprofPort))
grpclog.Infoln(http.ListenAndServe("localhost:"+strconv.Itoa(*pprofPort), nil))
}()
}
s.Serve(lis)
}

View File

@ -1,293 +0,0 @@
/*
*
* Copyright 2014 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"context"
"fmt"
"io"
"math"
"net"
"strconv"
"strings"
"sync"
"testing"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/internal/leakcheck"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/status"
)
var (
expectedRequest = "ping"
expectedResponse = "pong"
weirdError = "format verbs: %v%s"
sizeLargeErr = 1024 * 1024
canceled = 0
)
type testCodec struct {
}
func (testCodec) Marshal(v interface{}) ([]byte, error) {
return []byte(*(v.(*string))), nil
}
func (testCodec) Unmarshal(data []byte, v interface{}) error {
*(v.(*string)) = string(data)
return nil
}
func (testCodec) String() string {
return "test"
}
type testStreamHandler struct {
port string
t transport.ServerTransport
}
func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
p := &parser{r: s}
for {
pf, req, err := p.recvMsg(math.MaxInt32)
if err == io.EOF {
break
}
if err != nil {
return
}
if pf != compressionNone {
t.Errorf("Received the mistaken message format %d, want %d", pf, compressionNone)
return
}
var v string
codec := testCodec{}
if err := codec.Unmarshal(req, &v); err != nil {
t.Errorf("Failed to unmarshal the received message: %v", err)
return
}
if v == "weird error" {
h.t.WriteStatus(s, status.New(codes.Internal, weirdError))
return
}
if v == "canceled" {
canceled++
h.t.WriteStatus(s, status.New(codes.Internal, ""))
return
}
if v == "port" {
h.t.WriteStatus(s, status.New(codes.Internal, h.port))
return
}
if v != expectedRequest {
h.t.WriteStatus(s, status.New(codes.Internal, strings.Repeat("A", sizeLargeErr)))
return
}
}
// send a response back to end the stream.
data, err := encode(testCodec{}, &expectedResponse)
if err != nil {
t.Errorf("Failed to encode the response: %v", err)
return
}
hdr, payload := msgHeader(data, nil)
h.t.Write(s, hdr, payload, &transport.Options{})
h.t.WriteStatus(s, status.New(codes.OK, ""))
}
type server struct {
lis net.Listener
port string
addr string
startedErr chan error // sent nil or an error after server starts
mu sync.Mutex
conns map[transport.ServerTransport]bool
}
func newTestServer() *server {
return &server{startedErr: make(chan error, 1)}
}
// start starts the server. Other goroutines should block on s.startedErr for further operations.
func (s *server) start(t *testing.T, port int, maxStreams uint32) {
var err error
if port == 0 {
s.lis, err = net.Listen("tcp", "localhost:0")
} else {
s.lis, err = net.Listen("tcp", "localhost:"+strconv.Itoa(port))
}
if err != nil {
s.startedErr <- fmt.Errorf("failed to listen: %v", err)
return
}
s.addr = s.lis.Addr().String()
_, p, err := net.SplitHostPort(s.addr)
if err != nil {
s.startedErr <- fmt.Errorf("failed to parse listener address: %v", err)
return
}
s.port = p
s.conns = make(map[transport.ServerTransport]bool)
s.startedErr <- nil
for {
conn, err := s.lis.Accept()
if err != nil {
return
}
config := &transport.ServerConfig{
MaxStreams: maxStreams,
}
st, err := transport.NewServerTransport("http2", conn, config)
if err != nil {
continue
}
s.mu.Lock()
if s.conns == nil {
s.mu.Unlock()
st.Close()
return
}
s.conns[st] = true
s.mu.Unlock()
h := &testStreamHandler{
port: s.port,
t: st,
}
go st.HandleStreams(func(s *transport.Stream) {
go h.handleStream(t, s)
}, func(ctx context.Context, method string) context.Context {
return ctx
})
}
}
func (s *server) wait(t *testing.T, timeout time.Duration) {
select {
case err := <-s.startedErr:
if err != nil {
t.Fatal(err)
}
case <-time.After(timeout):
t.Fatalf("Timed out after %v waiting for server to be ready", timeout)
}
}
func (s *server) stop() {
s.lis.Close()
s.mu.Lock()
for c := range s.conns {
c.Close()
}
s.conns = nil
s.mu.Unlock()
}
func setUp(t *testing.T, port int, maxStreams uint32) (*server, *ClientConn) {
server := newTestServer()
go server.start(t, port, maxStreams)
server.wait(t, 2*time.Second)
addr := "localhost:" + server.port
cc, err := Dial(addr, WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
return server, cc
}
func TestInvoke(t *testing.T) {
defer leakcheck.Check(t)
server, cc := setUp(t, 0, math.MaxUint32)
var reply string
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
}
cc.Close()
server.stop()
}
func TestInvokeLargeErr(t *testing.T) {
defer leakcheck.Check(t)
server, cc := setUp(t, 0, math.MaxUint32)
var reply string
req := "hello"
err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply)
if _, ok := status.FromError(err); !ok {
t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
}
if status.Code(err) != codes.Internal || len(errorDesc(err)) != sizeLargeErr {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want an error of code %d and desc size %d", err, codes.Internal, sizeLargeErr)
}
cc.Close()
server.stop()
}
// TestInvokeErrorSpecialChars checks that error messages don't get mangled.
func TestInvokeErrorSpecialChars(t *testing.T) {
defer leakcheck.Check(t)
server, cc := setUp(t, 0, math.MaxUint32)
var reply string
req := "weird error"
err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply)
if _, ok := status.FromError(err); !ok {
t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
}
if got, want := errorDesc(err), weirdError; got != want {
t.Fatalf("grpc.Invoke(_, _, _, _, _) error = %q, want %q", got, want)
}
cc.Close()
server.stop()
}
// TestInvokeCancel checks that an Invoke with a canceled context is not sent.
func TestInvokeCancel(t *testing.T) {
defer leakcheck.Check(t)
server, cc := setUp(t, 0, math.MaxUint32)
var reply string
req := "canceled"
for i := 0; i < 100; i++ {
ctx, cancel := context.WithCancel(context.Background())
cancel()
cc.Invoke(ctx, "/foo/bar", &req, &reply)
}
if canceled != 0 {
t.Fatalf("received %d of 100 canceled requests", canceled)
}
cc.Close()
server.stop()
}
// TestInvokeCancelClosedNonFailFast checks that a canceled non-failfast RPC
// on a closed client will terminate.
func TestInvokeCancelClosedNonFailFast(t *testing.T) {
defer leakcheck.Check(t)
server, cc := setUp(t, 0, math.MaxUint32)
var reply string
cc.Close()
req := "hello"
ctx, cancel := context.WithCancel(context.Background())
cancel()
if err := cc.Invoke(ctx, "/foo/bar", &req, &reply, FailFast(false)); err == nil {
t.Fatalf("canceled invoke on closed connection should fail")
}
server.stop()
}

File diff suppressed because it is too large Load Diff

View File

@ -1,105 +0,0 @@
// +build !appengine
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package service
import (
"github.com/golang/protobuf/ptypes"
channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1"
"google.golang.org/grpc/internal/channelz"
)
func sockoptToProto(skopts *channelz.SocketOptionData) []*channelzpb.SocketOption {
var opts []*channelzpb.SocketOption
if skopts.Linger != nil {
additional, err := ptypes.MarshalAny(&channelzpb.SocketOptionLinger{
Active: skopts.Linger.Onoff != 0,
Duration: convertToPtypesDuration(int64(skopts.Linger.Linger), 0),
})
if err == nil {
opts = append(opts, &channelzpb.SocketOption{
Name: "SO_LINGER",
Additional: additional,
})
}
}
if skopts.RecvTimeout != nil {
additional, err := ptypes.MarshalAny(&channelzpb.SocketOptionTimeout{
Duration: convertToPtypesDuration(int64(skopts.RecvTimeout.Sec), int64(skopts.RecvTimeout.Usec)),
})
if err == nil {
opts = append(opts, &channelzpb.SocketOption{
Name: "SO_RCVTIMEO",
Additional: additional,
})
}
}
if skopts.SendTimeout != nil {
additional, err := ptypes.MarshalAny(&channelzpb.SocketOptionTimeout{
Duration: convertToPtypesDuration(int64(skopts.SendTimeout.Sec), int64(skopts.SendTimeout.Usec)),
})
if err == nil {
opts = append(opts, &channelzpb.SocketOption{
Name: "SO_SNDTIMEO",
Additional: additional,
})
}
}
if skopts.TCPInfo != nil {
additional, err := ptypes.MarshalAny(&channelzpb.SocketOptionTcpInfo{
TcpiState: uint32(skopts.TCPInfo.State),
TcpiCaState: uint32(skopts.TCPInfo.Ca_state),
TcpiRetransmits: uint32(skopts.TCPInfo.Retransmits),
TcpiProbes: uint32(skopts.TCPInfo.Probes),
TcpiBackoff: uint32(skopts.TCPInfo.Backoff),
TcpiOptions: uint32(skopts.TCPInfo.Options),
// https://golang.org/pkg/syscall/#TCPInfo
// TCPInfo struct does not contain info about TcpiSndWscale and TcpiRcvWscale.
TcpiRto: skopts.TCPInfo.Rto,
TcpiAto: skopts.TCPInfo.Ato,
TcpiSndMss: skopts.TCPInfo.Snd_mss,
TcpiRcvMss: skopts.TCPInfo.Rcv_mss,
TcpiUnacked: skopts.TCPInfo.Unacked,
TcpiSacked: skopts.TCPInfo.Sacked,
TcpiLost: skopts.TCPInfo.Lost,
TcpiRetrans: skopts.TCPInfo.Retrans,
TcpiFackets: skopts.TCPInfo.Fackets,
TcpiLastDataSent: skopts.TCPInfo.Last_data_sent,
TcpiLastAckSent: skopts.TCPInfo.Last_ack_sent,
TcpiLastDataRecv: skopts.TCPInfo.Last_data_recv,
TcpiLastAckRecv: skopts.TCPInfo.Last_ack_recv,
TcpiPmtu: skopts.TCPInfo.Pmtu,
TcpiRcvSsthresh: skopts.TCPInfo.Rcv_ssthresh,
TcpiRtt: skopts.TCPInfo.Rtt,
TcpiRttvar: skopts.TCPInfo.Rttvar,
TcpiSndSsthresh: skopts.TCPInfo.Snd_ssthresh,
TcpiSndCwnd: skopts.TCPInfo.Snd_cwnd,
TcpiAdvmss: skopts.TCPInfo.Advmss,
TcpiReordering: skopts.TCPInfo.Reordering,
})
if err == nil {
opts = append(opts, &channelzpb.SocketOption{
Name: "TCP_INFO",
Additional: additional,
})
}
}
return opts
}


@ -1,30 +0,0 @@
// +build !linux appengine
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package service
import (
channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1"
"google.golang.org/grpc/internal/channelz"
)
func sockoptToProto(skopts *channelz.SocketOptionData) []*channelzpb.SocketOption {
return nil
}


@ -1,33 +0,0 @@
#!/bin/bash
# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -eux -o pipefail
TMP=$(mktemp -d)
function finish {
rm -rf "$TMP"
}
trap finish EXIT
pushd "$TMP"
mkdir -p grpc/channelz/v1
curl https://raw.githubusercontent.com/grpc/grpc-proto/master/grpc/channelz/v1/channelz.proto > grpc/channelz/v1/channelz.proto
protoc --go_out=plugins=grpc,paths=source_relative:. -I. grpc/channelz/v1/*.proto
popd
rm -f ../grpc_channelz_v1/*.pb.go
cp "$TMP"/grpc/channelz/v1/*.pb.go ../grpc_channelz_v1/


@ -1,341 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
//go:generate ./regenerate.sh
// Package service provides an implementation of the channelz service server.
package service
import (
"context"
"net"
"time"
"github.com/golang/protobuf/ptypes"
durpb "github.com/golang/protobuf/ptypes/duration"
wrpb "github.com/golang/protobuf/ptypes/wrappers"
"google.golang.org/grpc"
channelzgrpc "google.golang.org/grpc/channelz/grpc_channelz_v1"
channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/status"
)
func init() {
channelz.TurnOn()
}
func convertToPtypesDuration(sec int64, usec int64) *durpb.Duration {
return ptypes.DurationProto(time.Duration(sec*1e9 + usec*1e3))
}
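// As a quick worked example (illustrative values only): sec=1, usec=500 maps to
// 1*1e9 + 500*1e3 = 1,000,500,000ns, i.e. a proto Duration of 1.0005s.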
// RegisterChannelzServiceToServer registers the channelz service to the given server.
func RegisterChannelzServiceToServer(s *grpc.Server) {
channelzgrpc.RegisterChannelzServer(s, newCZServer())
}
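// A minimal sketch of typical usage (the listener address below is illustrative,
// not part of this package):
//
// s := grpc.NewServer()
// service.RegisterChannelzServiceToServer(s)
// lis, err := net.Listen("tcp", ":50051")
// if err != nil {
// 	log.Fatal(err)
// }
// if err := s.Serve(lis); err != nil {
// 	log.Fatal(err)
// }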
func newCZServer() channelzgrpc.ChannelzServer {
return &serverImpl{}
}
type serverImpl struct{}
func connectivityStateToProto(s connectivity.State) *channelzpb.ChannelConnectivityState {
switch s {
case connectivity.Idle:
return &channelzpb.ChannelConnectivityState{State: channelzpb.ChannelConnectivityState_IDLE}
case connectivity.Connecting:
return &channelzpb.ChannelConnectivityState{State: channelzpb.ChannelConnectivityState_CONNECTING}
case connectivity.Ready:
return &channelzpb.ChannelConnectivityState{State: channelzpb.ChannelConnectivityState_READY}
case connectivity.TransientFailure:
return &channelzpb.ChannelConnectivityState{State: channelzpb.ChannelConnectivityState_TRANSIENT_FAILURE}
case connectivity.Shutdown:
return &channelzpb.ChannelConnectivityState{State: channelzpb.ChannelConnectivityState_SHUTDOWN}
default:
return &channelzpb.ChannelConnectivityState{State: channelzpb.ChannelConnectivityState_UNKNOWN}
}
}
func channelTraceToProto(ct *channelz.ChannelTrace) *channelzpb.ChannelTrace {
pbt := &channelzpb.ChannelTrace{}
pbt.NumEventsLogged = ct.EventNum
if ts, err := ptypes.TimestampProto(ct.CreationTime); err == nil {
pbt.CreationTimestamp = ts
}
var events []*channelzpb.ChannelTraceEvent
for _, e := range ct.Events {
cte := &channelzpb.ChannelTraceEvent{
Description: e.Desc,
Severity: channelzpb.ChannelTraceEvent_Severity(e.Severity),
}
if ts, err := ptypes.TimestampProto(e.Timestamp); err == nil {
cte.Timestamp = ts
}
if e.RefID != 0 {
switch e.RefType {
case channelz.RefChannel:
cte.ChildRef = &channelzpb.ChannelTraceEvent_ChannelRef{ChannelRef: &channelzpb.ChannelRef{ChannelId: e.RefID, Name: e.RefName}}
case channelz.RefSubChannel:
cte.ChildRef = &channelzpb.ChannelTraceEvent_SubchannelRef{SubchannelRef: &channelzpb.SubchannelRef{SubchannelId: e.RefID, Name: e.RefName}}
}
}
events = append(events, cte)
}
pbt.Events = events
return pbt
}
func channelMetricToProto(cm *channelz.ChannelMetric) *channelzpb.Channel {
c := &channelzpb.Channel{}
c.Ref = &channelzpb.ChannelRef{ChannelId: cm.ID, Name: cm.RefName}
c.Data = &channelzpb.ChannelData{
State: connectivityStateToProto(cm.ChannelData.State),
Target: cm.ChannelData.Target,
CallsStarted: cm.ChannelData.CallsStarted,
CallsSucceeded: cm.ChannelData.CallsSucceeded,
CallsFailed: cm.ChannelData.CallsFailed,
}
if ts, err := ptypes.TimestampProto(cm.ChannelData.LastCallStartedTimestamp); err == nil {
c.Data.LastCallStartedTimestamp = ts
}
nestedChans := make([]*channelzpb.ChannelRef, 0, len(cm.NestedChans))
for id, ref := range cm.NestedChans {
nestedChans = append(nestedChans, &channelzpb.ChannelRef{ChannelId: id, Name: ref})
}
c.ChannelRef = nestedChans
subChans := make([]*channelzpb.SubchannelRef, 0, len(cm.SubChans))
for id, ref := range cm.SubChans {
subChans = append(subChans, &channelzpb.SubchannelRef{SubchannelId: id, Name: ref})
}
c.SubchannelRef = subChans
sockets := make([]*channelzpb.SocketRef, 0, len(cm.Sockets))
for id, ref := range cm.Sockets {
sockets = append(sockets, &channelzpb.SocketRef{SocketId: id, Name: ref})
}
c.SocketRef = sockets
c.Data.Trace = channelTraceToProto(cm.Trace)
return c
}
func subChannelMetricToProto(cm *channelz.SubChannelMetric) *channelzpb.Subchannel {
sc := &channelzpb.Subchannel{}
sc.Ref = &channelzpb.SubchannelRef{SubchannelId: cm.ID, Name: cm.RefName}
sc.Data = &channelzpb.ChannelData{
State: connectivityStateToProto(cm.ChannelData.State),
Target: cm.ChannelData.Target,
CallsStarted: cm.ChannelData.CallsStarted,
CallsSucceeded: cm.ChannelData.CallsSucceeded,
CallsFailed: cm.ChannelData.CallsFailed,
}
if ts, err := ptypes.TimestampProto(cm.ChannelData.LastCallStartedTimestamp); err == nil {
sc.Data.LastCallStartedTimestamp = ts
}
nestedChans := make([]*channelzpb.ChannelRef, 0, len(cm.NestedChans))
for id, ref := range cm.NestedChans {
nestedChans = append(nestedChans, &channelzpb.ChannelRef{ChannelId: id, Name: ref})
}
sc.ChannelRef = nestedChans
subChans := make([]*channelzpb.SubchannelRef, 0, len(cm.SubChans))
for id, ref := range cm.SubChans {
subChans = append(subChans, &channelzpb.SubchannelRef{SubchannelId: id, Name: ref})
}
sc.SubchannelRef = subChans
sockets := make([]*channelzpb.SocketRef, 0, len(cm.Sockets))
for id, ref := range cm.Sockets {
sockets = append(sockets, &channelzpb.SocketRef{SocketId: id, Name: ref})
}
sc.SocketRef = sockets
sc.Data.Trace = channelTraceToProto(cm.Trace)
return sc
}
func securityToProto(se credentials.ChannelzSecurityValue) *channelzpb.Security {
switch v := se.(type) {
case *credentials.TLSChannelzSecurityValue:
return &channelzpb.Security{Model: &channelzpb.Security_Tls_{Tls: &channelzpb.Security_Tls{
CipherSuite: &channelzpb.Security_Tls_StandardName{StandardName: v.StandardName},
LocalCertificate: v.LocalCertificate,
RemoteCertificate: v.RemoteCertificate,
}}}
case *credentials.OtherChannelzSecurityValue:
otherSecurity := &channelzpb.Security_OtherSecurity{
Name: v.Name,
}
if anyval, err := ptypes.MarshalAny(v.Value); err == nil {
otherSecurity.Value = anyval
}
return &channelzpb.Security{Model: &channelzpb.Security_Other{Other: otherSecurity}}
}
return nil
}
func addrToProto(a net.Addr) *channelzpb.Address {
switch a.Network() {
case "udp":
// TODO: Address_OtherAddress{}. Need proto def for Value.
case "ip":
// Note zone info is discarded through the conversion.
return &channelzpb.Address{Address: &channelzpb.Address_TcpipAddress{TcpipAddress: &channelzpb.Address_TcpIpAddress{IpAddress: a.(*net.IPAddr).IP}}}
case "ip+net":
// Note mask info is discarded through the conversion.
return &channelzpb.Address{Address: &channelzpb.Address_TcpipAddress{TcpipAddress: &channelzpb.Address_TcpIpAddress{IpAddress: a.(*net.IPNet).IP}}}
case "tcp":
// Note zone info is discarded through the conversion.
return &channelzpb.Address{Address: &channelzpb.Address_TcpipAddress{TcpipAddress: &channelzpb.Address_TcpIpAddress{IpAddress: a.(*net.TCPAddr).IP, Port: int32(a.(*net.TCPAddr).Port)}}}
case "unix", "unixgram", "unixpacket":
return &channelzpb.Address{Address: &channelzpb.Address_UdsAddress_{UdsAddress: &channelzpb.Address_UdsAddress{Filename: a.String()}}}
default:
}
return &channelzpb.Address{}
}
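// For example (illustrative values): a *net.TCPAddr{IP: 127.0.0.1, Port: 8080} maps to a
// TcpIpAddress carrying that IP and Port, while a unix address maps to a UdsAddress whose
// Filename is the socket path returned by a.String().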
func socketMetricToProto(sm *channelz.SocketMetric) *channelzpb.Socket {
s := &channelzpb.Socket{}
s.Ref = &channelzpb.SocketRef{SocketId: sm.ID, Name: sm.RefName}
s.Data = &channelzpb.SocketData{
StreamsStarted: sm.SocketData.StreamsStarted,
StreamsSucceeded: sm.SocketData.StreamsSucceeded,
StreamsFailed: sm.SocketData.StreamsFailed,
MessagesSent: sm.SocketData.MessagesSent,
MessagesReceived: sm.SocketData.MessagesReceived,
KeepAlivesSent: sm.SocketData.KeepAlivesSent,
}
if ts, err := ptypes.TimestampProto(sm.SocketData.LastLocalStreamCreatedTimestamp); err == nil {
s.Data.LastLocalStreamCreatedTimestamp = ts
}
if ts, err := ptypes.TimestampProto(sm.SocketData.LastRemoteStreamCreatedTimestamp); err == nil {
s.Data.LastRemoteStreamCreatedTimestamp = ts
}
if ts, err := ptypes.TimestampProto(sm.SocketData.LastMessageSentTimestamp); err == nil {
s.Data.LastMessageSentTimestamp = ts
}
if ts, err := ptypes.TimestampProto(sm.SocketData.LastMessageReceivedTimestamp); err == nil {
s.Data.LastMessageReceivedTimestamp = ts
}
s.Data.LocalFlowControlWindow = &wrpb.Int64Value{Value: sm.SocketData.LocalFlowControlWindow}
s.Data.RemoteFlowControlWindow = &wrpb.Int64Value{Value: sm.SocketData.RemoteFlowControlWindow}
if sm.SocketData.SocketOptions != nil {
s.Data.Option = sockoptToProto(sm.SocketData.SocketOptions)
}
if sm.SocketData.Security != nil {
s.Security = securityToProto(sm.SocketData.Security)
}
if sm.SocketData.LocalAddr != nil {
s.Local = addrToProto(sm.SocketData.LocalAddr)
}
if sm.SocketData.RemoteAddr != nil {
s.Remote = addrToProto(sm.SocketData.RemoteAddr)
}
s.RemoteName = sm.SocketData.RemoteName
return s
}
func (s *serverImpl) GetTopChannels(ctx context.Context, req *channelzpb.GetTopChannelsRequest) (*channelzpb.GetTopChannelsResponse, error) {
metrics, end := channelz.GetTopChannels(req.GetStartChannelId())
resp := &channelzpb.GetTopChannelsResponse{}
for _, m := range metrics {
resp.Channel = append(resp.Channel, channelMetricToProto(m))
}
resp.End = end
return resp, nil
}
func serverMetricToProto(sm *channelz.ServerMetric) *channelzpb.Server {
s := &channelzpb.Server{}
s.Ref = &channelzpb.ServerRef{ServerId: sm.ID, Name: sm.RefName}
s.Data = &channelzpb.ServerData{
CallsStarted: sm.ServerData.CallsStarted,
CallsSucceeded: sm.ServerData.CallsSucceeded,
CallsFailed: sm.ServerData.CallsFailed,
}
if ts, err := ptypes.TimestampProto(sm.ServerData.LastCallStartedTimestamp); err == nil {
s.Data.LastCallStartedTimestamp = ts
}
sockets := make([]*channelzpb.SocketRef, 0, len(sm.ListenSockets))
for id, ref := range sm.ListenSockets {
sockets = append(sockets, &channelzpb.SocketRef{SocketId: id, Name: ref})
}
s.ListenSocket = sockets
return s
}
func (s *serverImpl) GetServers(ctx context.Context, req *channelzpb.GetServersRequest) (*channelzpb.GetServersResponse, error) {
metrics, end := channelz.GetServers(req.GetStartServerId())
resp := &channelzpb.GetServersResponse{}
for _, m := range metrics {
resp.Server = append(resp.Server, serverMetricToProto(m))
}
resp.End = end
return resp, nil
}
func (s *serverImpl) GetServerSockets(ctx context.Context, req *channelzpb.GetServerSocketsRequest) (*channelzpb.GetServerSocketsResponse, error) {
metrics, end := channelz.GetServerSockets(req.GetServerId(), req.GetStartSocketId())
resp := &channelzpb.GetServerSocketsResponse{}
for _, m := range metrics {
resp.SocketRef = append(resp.SocketRef, &channelzpb.SocketRef{SocketId: m.ID, Name: m.RefName})
}
resp.End = end
return resp, nil
}
func (s *serverImpl) GetChannel(ctx context.Context, req *channelzpb.GetChannelRequest) (*channelzpb.GetChannelResponse, error) {
var metric *channelz.ChannelMetric
if metric = channelz.GetChannel(req.GetChannelId()); metric == nil {
return &channelzpb.GetChannelResponse{}, nil
}
resp := &channelzpb.GetChannelResponse{Channel: channelMetricToProto(metric)}
return resp, nil
}
func (s *serverImpl) GetSubchannel(ctx context.Context, req *channelzpb.GetSubchannelRequest) (*channelzpb.GetSubchannelResponse, error) {
var metric *channelz.SubChannelMetric
if metric = channelz.GetSubChannel(req.GetSubchannelId()); metric == nil {
return &channelzpb.GetSubchannelResponse{}, nil
}
resp := &channelzpb.GetSubchannelResponse{Subchannel: subChannelMetricToProto(metric)}
return resp, nil
}
func (s *serverImpl) GetSocket(ctx context.Context, req *channelzpb.GetSocketRequest) (*channelzpb.GetSocketResponse, error) {
var metric *channelz.SocketMetric
if metric = channelz.GetSocket(req.GetSocketId()); metric == nil {
return &channelzpb.GetSocketResponse{}, nil
}
resp := &channelzpb.GetSocketResponse{Socket: socketMetricToProto(metric)}
return resp, nil
}
func (s *serverImpl) GetServer(ctx context.Context, req *channelzpb.GetServerRequest) (*channelzpb.GetServerResponse, error) {
return nil, status.Error(codes.Unimplemented, "GetServer not implemented")
}


@ -1,152 +0,0 @@
// +build linux,!appengine
// +build 386 amd64
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// SocketOptions is only supported on linux systems. The functions defined in
// this file parse the socket option field, and the test specifically verifies
// the behavior of socket option parsing.
package service
import (
"context"
"reflect"
"strconv"
"testing"
"github.com/golang/protobuf/ptypes"
durpb "github.com/golang/protobuf/ptypes/duration"
"golang.org/x/sys/unix"
channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1"
"google.golang.org/grpc/internal/channelz"
)
func init() {
// Assign protoToSocketOption to protoToSocketOpt in order to enable socket option
// data conversion from proto message to channelz defined struct.
protoToSocketOpt = protoToSocketOption
}
func convertToDuration(d *durpb.Duration) (sec int64, usec int64) {
if d != nil {
if dur, err := ptypes.Duration(d); err == nil {
sec = int64(int64(dur) / 1e9)
usec = (int64(dur) - sec*1e9) / 1e3
}
}
return
}
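// This is the inverse of convertToPtypesDuration in the service package: e.g. a proto
// Duration of 1.0005s comes back out as sec=1, usec=500 (illustrative values only).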
func protoToLinger(protoLinger *channelzpb.SocketOptionLinger) *unix.Linger {
linger := &unix.Linger{}
if protoLinger.GetActive() {
linger.Onoff = 1
}
lv, _ := convertToDuration(protoLinger.GetDuration())
linger.Linger = int32(lv)
return linger
}
func protoToSocketOption(skopts []*channelzpb.SocketOption) *channelz.SocketOptionData {
skdata := &channelz.SocketOptionData{}
for _, opt := range skopts {
switch opt.GetName() {
case "SO_LINGER":
protoLinger := &channelzpb.SocketOptionLinger{}
err := ptypes.UnmarshalAny(opt.GetAdditional(), protoLinger)
if err == nil {
skdata.Linger = protoToLinger(protoLinger)
}
case "SO_RCVTIMEO":
protoTimeout := &channelzpb.SocketOptionTimeout{}
err := ptypes.UnmarshalAny(opt.GetAdditional(), protoTimeout)
if err == nil {
skdata.RecvTimeout = protoToTime(protoTimeout)
}
case "SO_SNDTIMEO":
protoTimeout := &channelzpb.SocketOptionTimeout{}
err := ptypes.UnmarshalAny(opt.GetAdditional(), protoTimeout)
if err == nil {
skdata.SendTimeout = protoToTime(protoTimeout)
}
case "TCP_INFO":
tcpi := &channelzpb.SocketOptionTcpInfo{}
err := ptypes.UnmarshalAny(opt.GetAdditional(), tcpi)
if err == nil {
skdata.TCPInfo = &unix.TCPInfo{
State: uint8(tcpi.TcpiState),
Ca_state: uint8(tcpi.TcpiCaState),
Retransmits: uint8(tcpi.TcpiRetransmits),
Probes: uint8(tcpi.TcpiProbes),
Backoff: uint8(tcpi.TcpiBackoff),
Options: uint8(tcpi.TcpiOptions),
Rto: tcpi.TcpiRto,
Ato: tcpi.TcpiAto,
Snd_mss: tcpi.TcpiSndMss,
Rcv_mss: tcpi.TcpiRcvMss,
Unacked: tcpi.TcpiUnacked,
Sacked: tcpi.TcpiSacked,
Lost: tcpi.TcpiLost,
Retrans: tcpi.TcpiRetrans,
Fackets: tcpi.TcpiFackets,
Last_data_sent: tcpi.TcpiLastDataSent,
Last_ack_sent: tcpi.TcpiLastAckSent,
Last_data_recv: tcpi.TcpiLastDataRecv,
Last_ack_recv: tcpi.TcpiLastAckRecv,
Pmtu: tcpi.TcpiPmtu,
Rcv_ssthresh: tcpi.TcpiRcvSsthresh,
Rtt: tcpi.TcpiRtt,
Rttvar: tcpi.TcpiRttvar,
Snd_ssthresh: tcpi.TcpiSndSsthresh,
Snd_cwnd: tcpi.TcpiSndCwnd,
Advmss: tcpi.TcpiAdvmss,
Reordering: tcpi.TcpiReordering}
}
}
}
return skdata
}
func TestGetSocketOptions(t *testing.T) {
channelz.NewChannelzStorage()
ss := []*dummySocket{
{
socketOptions: &channelz.SocketOptionData{
Linger: &unix.Linger{Onoff: 1, Linger: 2},
RecvTimeout: &unix.Timeval{Sec: 10, Usec: 1},
SendTimeout: &unix.Timeval{},
TCPInfo: &unix.TCPInfo{State: 1},
},
},
}
svr := newCZServer()
ids := make([]int64, len(ss))
svrID := channelz.RegisterServer(&dummyServer{}, "")
for i, s := range ss {
ids[i] = channelz.RegisterNormalSocket(s, svrID, strconv.Itoa(i))
}
for i, s := range ss {
resp, _ := svr.GetSocket(context.Background(), &channelzpb.GetSocketRequest{SocketId: ids[i]})
metrics := resp.GetSocket()
if !reflect.DeepEqual(metrics.GetRef(), &channelzpb.SocketRef{SocketId: ids[i], Name: strconv.Itoa(i)}) || !reflect.DeepEqual(socketProtoToStruct(metrics), s) {
t.Fatalf("resp.GetSocket() want: metrics.GetRef() = %#v and %#v, got: metrics.GetRef() = %#v and %#v", &channelzpb.SocketRef{SocketId: ids[i], Name: strconv.Itoa(i)}, s, metrics.GetRef(), socketProtoToStruct(metrics))
}
}
}


@ -1,686 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package service
import (
"context"
"fmt"
"net"
"reflect"
"strconv"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes"
channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/channelz"
)
func init() {
channelz.TurnOn()
}
type protoToSocketOptFunc func([]*channelzpb.SocketOption) *channelz.SocketOptionData
// protoToSocketOpt is used in socketProtoToStruct to extract socket option
// data from the unmarshaled proto message.
// It is only defined in linux, non-appengine environments on x86 architectures.
var protoToSocketOpt protoToSocketOptFunc
// emptyTime is used for detecting an unset value of type time.Time.
// For go1.7 and earlier, ptypes.Timestamp fills in the loc field of time.Time
// with &utcLoc, whereas the loc field of a zero-value time.Time is nil.
// This behavior makes reflect.DeepEqual fail on unset time.Time fields,
// causing false-positive fatal errors.
// TODO: Go1.7 is no longer supported - does this need a change?
var emptyTime time.Time
type dummyChannel struct {
state connectivity.State
target string
callsStarted int64
callsSucceeded int64
callsFailed int64
lastCallStartedTimestamp time.Time
}
func (d *dummyChannel) ChannelzMetric() *channelz.ChannelInternalMetric {
return &channelz.ChannelInternalMetric{
State: d.state,
Target: d.target,
CallsStarted: d.callsStarted,
CallsSucceeded: d.callsSucceeded,
CallsFailed: d.callsFailed,
LastCallStartedTimestamp: d.lastCallStartedTimestamp,
}
}
type dummyServer struct {
callsStarted int64
callsSucceeded int64
callsFailed int64
lastCallStartedTimestamp time.Time
}
func (d *dummyServer) ChannelzMetric() *channelz.ServerInternalMetric {
return &channelz.ServerInternalMetric{
CallsStarted: d.callsStarted,
CallsSucceeded: d.callsSucceeded,
CallsFailed: d.callsFailed,
LastCallStartedTimestamp: d.lastCallStartedTimestamp,
}
}
type dummySocket struct {
streamsStarted int64
streamsSucceeded int64
streamsFailed int64
messagesSent int64
messagesReceived int64
keepAlivesSent int64
lastLocalStreamCreatedTimestamp time.Time
lastRemoteStreamCreatedTimestamp time.Time
lastMessageSentTimestamp time.Time
lastMessageReceivedTimestamp time.Time
localFlowControlWindow int64
remoteFlowControlWindow int64
socketOptions *channelz.SocketOptionData
localAddr net.Addr
remoteAddr net.Addr
security credentials.ChannelzSecurityValue
remoteName string
}
func (d *dummySocket) ChannelzMetric() *channelz.SocketInternalMetric {
return &channelz.SocketInternalMetric{
StreamsStarted: d.streamsStarted,
StreamsSucceeded: d.streamsSucceeded,
StreamsFailed: d.streamsFailed,
MessagesSent: d.messagesSent,
MessagesReceived: d.messagesReceived,
KeepAlivesSent: d.keepAlivesSent,
LastLocalStreamCreatedTimestamp: d.lastLocalStreamCreatedTimestamp,
LastRemoteStreamCreatedTimestamp: d.lastRemoteStreamCreatedTimestamp,
LastMessageSentTimestamp: d.lastMessageSentTimestamp,
LastMessageReceivedTimestamp: d.lastMessageReceivedTimestamp,
LocalFlowControlWindow: d.localFlowControlWindow,
RemoteFlowControlWindow: d.remoteFlowControlWindow,
SocketOptions: d.socketOptions,
LocalAddr: d.localAddr,
RemoteAddr: d.remoteAddr,
Security: d.security,
RemoteName: d.remoteName,
}
}
func channelProtoToStruct(c *channelzpb.Channel) *dummyChannel {
dc := &dummyChannel{}
pdata := c.GetData()
switch pdata.GetState().GetState() {
case channelzpb.ChannelConnectivityState_UNKNOWN:
// TODO: what should we set here?
case channelzpb.ChannelConnectivityState_IDLE:
dc.state = connectivity.Idle
case channelzpb.ChannelConnectivityState_CONNECTING:
dc.state = connectivity.Connecting
case channelzpb.ChannelConnectivityState_READY:
dc.state = connectivity.Ready
case channelzpb.ChannelConnectivityState_TRANSIENT_FAILURE:
dc.state = connectivity.TransientFailure
case channelzpb.ChannelConnectivityState_SHUTDOWN:
dc.state = connectivity.Shutdown
}
dc.target = pdata.GetTarget()
dc.callsStarted = pdata.CallsStarted
dc.callsSucceeded = pdata.CallsSucceeded
dc.callsFailed = pdata.CallsFailed
if t, err := ptypes.Timestamp(pdata.GetLastCallStartedTimestamp()); err == nil {
if !t.Equal(emptyTime) {
dc.lastCallStartedTimestamp = t
}
}
return dc
}
func serverProtoToStruct(s *channelzpb.Server) *dummyServer {
ds := &dummyServer{}
pdata := s.GetData()
ds.callsStarted = pdata.CallsStarted
ds.callsSucceeded = pdata.CallsSucceeded
ds.callsFailed = pdata.CallsFailed
if t, err := ptypes.Timestamp(pdata.GetLastCallStartedTimestamp()); err == nil {
if !t.Equal(emptyTime) {
ds.lastCallStartedTimestamp = t
}
}
return ds
}
func socketProtoToStruct(s *channelzpb.Socket) *dummySocket {
ds := &dummySocket{}
pdata := s.GetData()
ds.streamsStarted = pdata.GetStreamsStarted()
ds.streamsSucceeded = pdata.GetStreamsSucceeded()
ds.streamsFailed = pdata.GetStreamsFailed()
ds.messagesSent = pdata.GetMessagesSent()
ds.messagesReceived = pdata.GetMessagesReceived()
ds.keepAlivesSent = pdata.GetKeepAlivesSent()
if t, err := ptypes.Timestamp(pdata.GetLastLocalStreamCreatedTimestamp()); err == nil {
if !t.Equal(emptyTime) {
ds.lastLocalStreamCreatedTimestamp = t
}
}
if t, err := ptypes.Timestamp(pdata.GetLastRemoteStreamCreatedTimestamp()); err == nil {
if !t.Equal(emptyTime) {
ds.lastRemoteStreamCreatedTimestamp = t
}
}
if t, err := ptypes.Timestamp(pdata.GetLastMessageSentTimestamp()); err == nil {
if !t.Equal(emptyTime) {
ds.lastMessageSentTimestamp = t
}
}
if t, err := ptypes.Timestamp(pdata.GetLastMessageReceivedTimestamp()); err == nil {
if !t.Equal(emptyTime) {
ds.lastMessageReceivedTimestamp = t
}
}
if v := pdata.GetLocalFlowControlWindow(); v != nil {
ds.localFlowControlWindow = v.Value
}
if v := pdata.GetRemoteFlowControlWindow(); v != nil {
ds.remoteFlowControlWindow = v.Value
}
if v := pdata.GetOption(); v != nil && protoToSocketOpt != nil {
ds.socketOptions = protoToSocketOpt(v)
}
if v := s.GetSecurity(); v != nil {
ds.security = protoToSecurity(v)
}
if local := s.GetLocal(); local != nil {
ds.localAddr = protoToAddr(local)
}
if remote := s.GetRemote(); remote != nil {
ds.remoteAddr = protoToAddr(remote)
}
ds.remoteName = s.GetRemoteName()
return ds
}
func protoToSecurity(protoSecurity *channelzpb.Security) credentials.ChannelzSecurityValue {
switch v := protoSecurity.Model.(type) {
case *channelzpb.Security_Tls_:
return &credentials.TLSChannelzSecurityValue{StandardName: v.Tls.GetStandardName(), LocalCertificate: v.Tls.GetLocalCertificate(), RemoteCertificate: v.Tls.GetRemoteCertificate()}
case *channelzpb.Security_Other:
sv := &credentials.OtherChannelzSecurityValue{Name: v.Other.GetName()}
var x ptypes.DynamicAny
if err := ptypes.UnmarshalAny(v.Other.GetValue(), &x); err == nil {
sv.Value = x.Message
}
return sv
}
return nil
}
func protoToAddr(a *channelzpb.Address) net.Addr {
switch v := a.Address.(type) {
case *channelzpb.Address_TcpipAddress:
if port := v.TcpipAddress.GetPort(); port != 0 {
return &net.TCPAddr{IP: v.TcpipAddress.GetIpAddress(), Port: int(port)}
}
return &net.IPAddr{IP: v.TcpipAddress.GetIpAddress()}
case *channelzpb.Address_UdsAddress_:
return &net.UnixAddr{Name: v.UdsAddress.GetFilename(), Net: "unix"}
case *channelzpb.Address_OtherAddress_:
// TODO:
}
return nil
}
func convertSocketRefSliceToMap(sktRefs []*channelzpb.SocketRef) map[int64]string {
m := make(map[int64]string)
for _, sr := range sktRefs {
m[sr.SocketId] = sr.Name
}
return m
}
type OtherSecurityValue struct {
LocalCertificate []byte `protobuf:"bytes,1,opt,name=local_certificate,json=localCertificate,proto3" json:"local_certificate,omitempty"`
RemoteCertificate []byte `protobuf:"bytes,2,opt,name=remote_certificate,json=remoteCertificate,proto3" json:"remote_certificate,omitempty"`
}
func (m *OtherSecurityValue) Reset() { *m = OtherSecurityValue{} }
func (m *OtherSecurityValue) String() string { return proto.CompactTextString(m) }
func (*OtherSecurityValue) ProtoMessage() {}
func init() {
// Ad-hoc registering the proto type here to facilitate UnmarshalAny of OtherSecurityValue.
proto.RegisterType((*OtherSecurityValue)(nil), "grpc.credentials.OtherChannelzSecurityValue")
}
func TestGetTopChannels(t *testing.T) {
tcs := []*dummyChannel{
{
state: connectivity.Connecting,
target: "test.channelz:1234",
callsStarted: 6,
callsSucceeded: 2,
callsFailed: 3,
lastCallStartedTimestamp: time.Now().UTC(),
},
{
state: connectivity.Connecting,
target: "test.channelz:1234",
callsStarted: 1,
callsSucceeded: 2,
callsFailed: 3,
lastCallStartedTimestamp: time.Now().UTC(),
},
{
state: connectivity.Shutdown,
target: "test.channelz:8888",
callsStarted: 0,
callsSucceeded: 0,
callsFailed: 0,
},
{},
}
channelz.NewChannelzStorage()
for _, c := range tcs {
channelz.RegisterChannel(c, 0, "")
}
s := newCZServer()
resp, _ := s.GetTopChannels(context.Background(), &channelzpb.GetTopChannelsRequest{StartChannelId: 0})
if !resp.GetEnd() {
t.Fatalf("resp.GetEnd() want true, got %v", resp.GetEnd())
}
for i, c := range resp.GetChannel() {
if !reflect.DeepEqual(channelProtoToStruct(c), tcs[i]) {
t.Fatalf("dummyChannel: %d, want: %#v, got: %#v", i, tcs[i], channelProtoToStruct(c))
}
}
for i := 0; i < 50; i++ {
channelz.RegisterChannel(tcs[0], 0, "")
}
resp, _ = s.GetTopChannels(context.Background(), &channelzpb.GetTopChannelsRequest{StartChannelId: 0})
if resp.GetEnd() {
t.Fatalf("resp.GetEnd() want false, got %v", resp.GetEnd())
}
}
func TestGetServers(t *testing.T) {
ss := []*dummyServer{
{
callsStarted: 6,
callsSucceeded: 2,
callsFailed: 3,
lastCallStartedTimestamp: time.Now().UTC(),
},
{
callsStarted: 1,
callsSucceeded: 2,
callsFailed: 3,
lastCallStartedTimestamp: time.Now().UTC(),
},
{
callsStarted: 1,
callsSucceeded: 0,
callsFailed: 0,
lastCallStartedTimestamp: time.Now().UTC(),
},
}
channelz.NewChannelzStorage()
for _, s := range ss {
channelz.RegisterServer(s, "")
}
svr := newCZServer()
resp, _ := svr.GetServers(context.Background(), &channelzpb.GetServersRequest{StartServerId: 0})
if !resp.GetEnd() {
t.Fatalf("resp.GetEnd() want true, got %v", resp.GetEnd())
}
for i, s := range resp.GetServer() {
if !reflect.DeepEqual(serverProtoToStruct(s), ss[i]) {
t.Fatalf("dummyServer: %d, want: %#v, got: %#v", i, ss[i], serverProtoToStruct(s))
}
}
for i := 0; i < 50; i++ {
channelz.RegisterServer(ss[0], "")
}
resp, _ = svr.GetServers(context.Background(), &channelzpb.GetServersRequest{StartServerId: 0})
if resp.GetEnd() {
t.Fatalf("resp.GetEnd() want false, got %v", resp.GetEnd())
}
}
func TestGetServerSockets(t *testing.T) {
channelz.NewChannelzStorage()
svrID := channelz.RegisterServer(&dummyServer{}, "")
refNames := []string{"listen socket 1", "normal socket 1", "normal socket 2"}
ids := make([]int64, 3)
ids[0] = channelz.RegisterListenSocket(&dummySocket{}, svrID, refNames[0])
ids[1] = channelz.RegisterNormalSocket(&dummySocket{}, svrID, refNames[1])
ids[2] = channelz.RegisterNormalSocket(&dummySocket{}, svrID, refNames[2])
svr := newCZServer()
resp, _ := svr.GetServerSockets(context.Background(), &channelzpb.GetServerSocketsRequest{ServerId: svrID, StartSocketId: 0})
if !resp.GetEnd() {
t.Fatalf("resp.GetEnd() want: true, got: %v", resp.GetEnd())
}
// GetServerSockets only returns normal sockets.
want := map[int64]string{
ids[1]: refNames[1],
ids[2]: refNames[2],
}
if !reflect.DeepEqual(convertSocketRefSliceToMap(resp.GetSocketRef()), want) {
t.Fatalf("GetServerSockets want: %#v, got: %#v", want, resp.GetSocketRef())
}
for i := 0; i < 50; i++ {
channelz.RegisterNormalSocket(&dummySocket{}, svrID, "")
}
resp, _ = svr.GetServerSockets(context.Background(), &channelzpb.GetServerSocketsRequest{ServerId: svrID, StartSocketId: 0})
if resp.GetEnd() {
t.Fatalf("resp.GetEnd() want false, got %v", resp.GetEnd())
}
}
// This test makes a GetServerSockets call with a non-zero start ID, and expects only
// sockets with ID >= the given start ID.
func TestGetServerSocketsNonZeroStartID(t *testing.T) {
channelz.NewChannelzStorage()
svrID := channelz.RegisterServer(&dummyServer{}, "")
refNames := []string{"listen socket 1", "normal socket 1", "normal socket 2"}
ids := make([]int64, 3)
ids[0] = channelz.RegisterListenSocket(&dummySocket{}, svrID, refNames[0])
ids[1] = channelz.RegisterNormalSocket(&dummySocket{}, svrID, refNames[1])
ids[2] = channelz.RegisterNormalSocket(&dummySocket{}, svrID, refNames[2])
svr := newCZServer()
// Make GetServerSockets with startID = ids[1]+1, so socket-1 won't be
// included in the response.
resp, _ := svr.GetServerSockets(context.Background(), &channelzpb.GetServerSocketsRequest{ServerId: svrID, StartSocketId: ids[1] + 1})
if !resp.GetEnd() {
t.Fatalf("resp.GetEnd() want: true, got: %v", resp.GetEnd())
}
// GetServerSockets should only return normal socket-2; socket-1 should be
// filtered out by the start ID.
want := map[int64]string{
ids[2]: refNames[2],
}
if !reflect.DeepEqual(convertSocketRefSliceToMap(resp.GetSocketRef()), want) {
t.Fatalf("GetServerSockets want: %#v, got: %#v", want, resp.GetSocketRef())
}
}
func TestGetChannel(t *testing.T) {
channelz.NewChannelzStorage()
refNames := []string{"top channel 1", "nested channel 1", "sub channel 2", "nested channel 3"}
ids := make([]int64, 4)
ids[0] = channelz.RegisterChannel(&dummyChannel{}, 0, refNames[0])
channelz.AddTraceEvent(ids[0], &channelz.TraceEventDesc{
Desc: "Channel Created",
Severity: channelz.CtINFO,
})
ids[1] = channelz.RegisterChannel(&dummyChannel{}, ids[0], refNames[1])
channelz.AddTraceEvent(ids[1], &channelz.TraceEventDesc{
Desc: "Channel Created",
Severity: channelz.CtINFO,
Parent: &channelz.TraceEventDesc{
Desc: fmt.Sprintf("Nested Channel(id:%d) created", ids[1]),
Severity: channelz.CtINFO,
},
})
ids[2] = channelz.RegisterSubChannel(&dummyChannel{}, ids[0], refNames[2])
channelz.AddTraceEvent(ids[2], &channelz.TraceEventDesc{
Desc: "SubChannel Created",
Severity: channelz.CtINFO,
Parent: &channelz.TraceEventDesc{
Desc: fmt.Sprintf("SubChannel(id:%d) created", ids[2]),
Severity: channelz.CtINFO,
},
})
ids[3] = channelz.RegisterChannel(&dummyChannel{}, ids[1], refNames[3])
channelz.AddTraceEvent(ids[3], &channelz.TraceEventDesc{
Desc: "Channel Created",
Severity: channelz.CtINFO,
Parent: &channelz.TraceEventDesc{
Desc: fmt.Sprintf("Nested Channel(id:%d) created", ids[3]),
Severity: channelz.CtINFO,
},
})
channelz.AddTraceEvent(ids[0], &channelz.TraceEventDesc{
Desc: fmt.Sprintf("Channel Connectivity change to %v", connectivity.Ready),
Severity: channelz.CtINFO,
})
channelz.AddTraceEvent(ids[0], &channelz.TraceEventDesc{
Desc: "Resolver returns an empty address list",
Severity: channelz.CtWarning,
})
svr := newCZServer()
resp, _ := svr.GetChannel(context.Background(), &channelzpb.GetChannelRequest{ChannelId: ids[0]})
metrics := resp.GetChannel()
subChans := metrics.GetSubchannelRef()
if len(subChans) != 1 || subChans[0].GetName() != refNames[2] || subChans[0].GetSubchannelId() != ids[2] {
t.Fatalf("metrics.GetSubChannelRef() want %#v, got %#v", []*channelzpb.SubchannelRef{{SubchannelId: ids[2], Name: refNames[2]}}, subChans)
}
nestedChans := metrics.GetChannelRef()
if len(nestedChans) != 1 || nestedChans[0].GetName() != refNames[1] || nestedChans[0].GetChannelId() != ids[1] {
t.Fatalf("metrics.GetChannelRef() want %#v, got %#v", []*channelzpb.ChannelRef{{ChannelId: ids[1], Name: refNames[1]}}, nestedChans)
}
trace := metrics.GetData().GetTrace()
want := []struct {
desc string
severity channelzpb.ChannelTraceEvent_Severity
childID int64
childRef string
}{
{desc: "Channel Created", severity: channelzpb.ChannelTraceEvent_CT_INFO},
{desc: fmt.Sprintf("Nested Channel(id:%d) created", ids[1]), severity: channelzpb.ChannelTraceEvent_CT_INFO, childID: ids[1], childRef: refNames[1]},
{desc: fmt.Sprintf("SubChannel(id:%d) created", ids[2]), severity: channelzpb.ChannelTraceEvent_CT_INFO, childID: ids[2], childRef: refNames[2]},
{desc: fmt.Sprintf("Channel Connectivity change to %v", connectivity.Ready), severity: channelzpb.ChannelTraceEvent_CT_INFO},
{desc: "Resolver returns an empty address list", severity: channelzpb.ChannelTraceEvent_CT_WARNING},
}
for i, e := range trace.Events {
if e.GetDescription() != want[i].desc {
t.Fatalf("trace: GetDescription want %#v, got %#v", want[i].desc, e.GetDescription())
}
if e.GetSeverity() != want[i].severity {
t.Fatalf("trace: GetSeverity want %#v, got %#v", want[i].severity, e.GetSeverity())
}
if want[i].childID == 0 && (e.GetChannelRef() != nil || e.GetSubchannelRef() != nil) {
t.Fatalf("trace: GetChannelRef() should return nil, as there is no reference")
}
if e.GetChannelRef().GetChannelId() != want[i].childID || e.GetChannelRef().GetName() != want[i].childRef {
if e.GetSubchannelRef().GetSubchannelId() != want[i].childID || e.GetSubchannelRef().GetName() != want[i].childRef {
t.Fatalf("trace: GetChannelRef/GetSubchannelRef want (child ID: %d, child name: %q), got %#v and %#v", want[i].childID, want[i].childRef, e.GetChannelRef(), e.GetSubchannelRef())
}
}
}
resp, _ = svr.GetChannel(context.Background(), &channelzpb.GetChannelRequest{ChannelId: ids[1]})
metrics = resp.GetChannel()
nestedChans = metrics.GetChannelRef()
if len(nestedChans) != 1 || nestedChans[0].GetName() != refNames[3] || nestedChans[0].GetChannelId() != ids[3] {
t.Fatalf("metrics.GetChannelRef() want %#v, got %#v", []*channelzpb.ChannelRef{{ChannelId: ids[3], Name: refNames[3]}}, nestedChans)
}
}
func TestGetSubChannel(t *testing.T) {
var (
subchanCreated = "SubChannel Created"
subchanConnectivityChange = fmt.Sprintf("Subchannel Connectivity change to %v", connectivity.Ready)
subChanPickNewAddress = fmt.Sprintf("Subchannel picks a new address %q to connect", "0.0.0.0")
)
channelz.NewChannelzStorage()
refNames := []string{"top channel 1", "sub channel 1", "socket 1", "socket 2"}
ids := make([]int64, 4)
ids[0] = channelz.RegisterChannel(&dummyChannel{}, 0, refNames[0])
channelz.AddTraceEvent(ids[0], &channelz.TraceEventDesc{
Desc: "Channel Created",
Severity: channelz.CtINFO,
})
ids[1] = channelz.RegisterSubChannel(&dummyChannel{}, ids[0], refNames[1])
channelz.AddTraceEvent(ids[1], &channelz.TraceEventDesc{
Desc: subchanCreated,
Severity: channelz.CtINFO,
Parent: &channelz.TraceEventDesc{
Desc: fmt.Sprintf("Nested Channel(id:%d) created", ids[0]),
Severity: channelz.CtINFO,
},
})
ids[2] = channelz.RegisterNormalSocket(&dummySocket{}, ids[1], refNames[2])
ids[3] = channelz.RegisterNormalSocket(&dummySocket{}, ids[1], refNames[3])
channelz.AddTraceEvent(ids[1], &channelz.TraceEventDesc{
Desc: subchanConnectivityChange,
Severity: channelz.CtINFO,
})
channelz.AddTraceEvent(ids[1], &channelz.TraceEventDesc{
Desc: subChanPickNewAddress,
Severity: channelz.CtINFO,
})
svr := newCZServer()
resp, _ := svr.GetSubchannel(context.Background(), &channelzpb.GetSubchannelRequest{SubchannelId: ids[1]})
metrics := resp.GetSubchannel()
want := map[int64]string{
ids[2]: refNames[2],
ids[3]: refNames[3],
}
if !reflect.DeepEqual(convertSocketRefSliceToMap(metrics.GetSocketRef()), want) {
t.Fatalf("metrics.GetSocketRef() want %#v: got: %#v", want, metrics.GetSocketRef())
}
trace := metrics.GetData().GetTrace()
wantTrace := []struct {
desc string
severity channelzpb.ChannelTraceEvent_Severity
childID int64
childRef string
}{
{desc: subchanCreated, severity: channelzpb.ChannelTraceEvent_CT_INFO},
{desc: subchanConnectivityChange, severity: channelzpb.ChannelTraceEvent_CT_INFO},
{desc: subChanPickNewAddress, severity: channelzpb.ChannelTraceEvent_CT_INFO},
}
for i, e := range trace.Events {
if e.GetDescription() != wantTrace[i].desc {
t.Fatalf("trace: GetDescription want %#v, got %#v", wantTrace[i].desc, e.GetDescription())
}
if e.GetSeverity() != wantTrace[i].severity {
t.Fatalf("trace: GetSeverity want %#v, got %#v", wantTrace[i].severity, e.GetSeverity())
}
if wantTrace[i].childID == 0 && (e.GetChannelRef() != nil || e.GetSubchannelRef() != nil) {
t.Fatalf("trace: GetChannelRef() should return nil, as there is no reference")
}
if e.GetChannelRef().GetChannelId() != wantTrace[i].childID || e.GetChannelRef().GetName() != wantTrace[i].childRef {
if e.GetSubchannelRef().GetSubchannelId() != wantTrace[i].childID || e.GetSubchannelRef().GetName() != wantTrace[i].childRef {
t.Fatalf("trace: GetChannelRef/GetSubchannelRef want (child ID: %d, child name: %q), got %#v and %#v", wantTrace[i].childID, wantTrace[i].childRef, e.GetChannelRef(), e.GetSubchannelRef())
}
}
}
}
func TestGetSocket(t *testing.T) {
channelz.NewChannelzStorage()
ss := []*dummySocket{
{
streamsStarted: 10,
streamsSucceeded: 2,
streamsFailed: 3,
messagesSent: 20,
messagesReceived: 10,
keepAlivesSent: 2,
lastLocalStreamCreatedTimestamp: time.Now().UTC(),
lastRemoteStreamCreatedTimestamp: time.Now().UTC(),
lastMessageSentTimestamp: time.Now().UTC(),
lastMessageReceivedTimestamp: time.Now().UTC(),
localFlowControlWindow: 65536,
remoteFlowControlWindow: 1024,
localAddr: &net.TCPAddr{IP: net.ParseIP("1.0.0.1"), Port: 10001},
remoteAddr: &net.TCPAddr{IP: net.ParseIP("12.0.0.1"), Port: 10002},
remoteName: "remote.remote",
},
{
streamsStarted: 10,
streamsSucceeded: 2,
streamsFailed: 3,
messagesSent: 20,
messagesReceived: 10,
keepAlivesSent: 2,
lastRemoteStreamCreatedTimestamp: time.Now().UTC(),
lastMessageSentTimestamp: time.Now().UTC(),
lastMessageReceivedTimestamp: time.Now().UTC(),
localFlowControlWindow: 65536,
remoteFlowControlWindow: 1024,
localAddr: &net.UnixAddr{Name: "file.path", Net: "unix"},
remoteAddr: &net.UnixAddr{Name: "another.path", Net: "unix"},
remoteName: "remote.remote",
},
{
streamsStarted: 5,
streamsSucceeded: 2,
streamsFailed: 3,
messagesSent: 20,
messagesReceived: 10,
keepAlivesSent: 2,
lastLocalStreamCreatedTimestamp: time.Now().UTC(),
lastMessageSentTimestamp: time.Now().UTC(),
lastMessageReceivedTimestamp: time.Now().UTC(),
localFlowControlWindow: 65536,
remoteFlowControlWindow: 10240,
localAddr: &net.IPAddr{IP: net.ParseIP("1.0.0.1")},
remoteAddr: &net.IPAddr{IP: net.ParseIP("9.0.0.1")},
remoteName: "",
},
{
localAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 10001},
},
{
security: &credentials.TLSChannelzSecurityValue{
StandardName: "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
RemoteCertificate: []byte{48, 130, 2, 156, 48, 130, 2, 5, 160},
},
},
{
security: &credentials.OtherChannelzSecurityValue{
Name: "XXXX",
},
},
{
security: &credentials.OtherChannelzSecurityValue{
Name: "YYYY",
Value: &OtherSecurityValue{LocalCertificate: []byte{1, 2, 3}, RemoteCertificate: []byte{4, 5, 6}},
},
},
}
svr := newCZServer()
ids := make([]int64, len(ss))
svrID := channelz.RegisterServer(&dummyServer{}, "")
for i, s := range ss {
ids[i] = channelz.RegisterNormalSocket(s, svrID, strconv.Itoa(i))
}
for i, s := range ss {
resp, _ := svr.GetSocket(context.Background(), &channelzpb.GetSocketRequest{SocketId: ids[i]})
metrics := resp.GetSocket()
if !reflect.DeepEqual(metrics.GetRef(), &channelzpb.SocketRef{SocketId: ids[i], Name: strconv.Itoa(i)}) || !reflect.DeepEqual(socketProtoToStruct(metrics), s) {
t.Fatalf("resp.GetSocket() want: metrics.GetRef() = %#v and %#v, got: metrics.GetRef() = %#v and %#v", &channelzpb.SocketRef{SocketId: ids[i], Name: strconv.Itoa(i)}, s, metrics.GetRef(), socketProtoToStruct(metrics))
}
}
}


@ -1,33 +0,0 @@
// +build 386,linux,!appengine
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package service
import (
"golang.org/x/sys/unix"
channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1"
)
func protoToTime(protoTime *channelzpb.SocketOptionTimeout) *unix.Timeval {
timeout := &unix.Timeval{}
sec, usec := convertToDuration(protoTime.GetDuration())
timeout.Sec, timeout.Usec = int32(sec), int32(usec)
return timeout
}


@ -1,32 +0,0 @@
// +build amd64,linux,!appengine
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package service
import (
"golang.org/x/sys/unix"
channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1"
)
func protoToTime(protoTime *channelzpb.SocketOptionTimeout) *unix.Timeval {
timeout := &unix.Timeval{}
timeout.Sec, timeout.Usec = convertToDuration(protoTime.GetDuration())
return timeout
}


@ -1,515 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"context"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"golang.org/x/net/http2"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal/leakcheck"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
)
const stateRecordingBalancerName = "state_recording_balancer"
var testBalancer = &stateRecordingBalancer{}
func init() {
balancer.Register(testBalancer)
}
// These tests use a pipeListener. This listener is similar to net.Listener except that it is unbuffered, so each read
// and write will wait for the other side's corresponding write or read.
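// The unbuffered behavior is analogous to the standard library's net.Pipe, where a Write
// blocks until the peer performs the matching Read. A minimal illustration of the idea:
//
// c1, c2 := net.Pipe()
// go c1.Write([]byte("ping")) // blocks until the Read below runs
// buf := make([]byte, 4)
// c2.Read(buf)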
func TestStateTransitions_SingleAddress(t *testing.T) {
defer leakcheck.Check(t)
mctBkp := getMinConnectTimeout()
defer func() {
atomic.StoreInt64((*int64)(&mutableMinConnectTimeout), int64(mctBkp))
}()
atomic.StoreInt64((*int64)(&mutableMinConnectTimeout), int64(time.Millisecond)*100)
for _, test := range []struct {
desc string
want []connectivity.State
server func(net.Listener) net.Conn
}{
{
desc: "When the server returns server preface, the client enters READY.",
want: []connectivity.State{
connectivity.Connecting,
connectivity.Ready,
},
server: func(lis net.Listener) net.Conn {
conn, err := lis.Accept()
if err != nil {
t.Error(err)
return nil
}
go keepReading(conn)
framer := http2.NewFramer(conn, conn)
if err := framer.WriteSettings(http2.Setting{}); err != nil {
t.Errorf("Error while writing settings frame. %v", err)
return nil
}
return conn
},
},
{
desc: "When the connection is closed, the client enters TRANSIENT FAILURE.",
want: []connectivity.State{
connectivity.Connecting,
connectivity.TransientFailure,
},
server: func(lis net.Listener) net.Conn {
conn, err := lis.Accept()
if err != nil {
t.Error(err)
return nil
}
conn.Close()
return nil
},
},
{
desc: `When the server sends its connection preface, but the connection dies before the client can write its
connection preface, the client enters TRANSIENT FAILURE.`,
want: []connectivity.State{
connectivity.Connecting,
connectivity.TransientFailure,
},
server: func(lis net.Listener) net.Conn {
conn, err := lis.Accept()
if err != nil {
t.Error(err)
return nil
}
framer := http2.NewFramer(conn, conn)
if err := framer.WriteSettings(http2.Setting{}); err != nil {
t.Errorf("Error while writing settings frame. %v", err)
return nil
}
conn.Close()
return nil
},
},
{
desc: `When the server reads the client connection preface but does not send its connection preface, the
client enters TRANSIENT FAILURE.`,
want: []connectivity.State{
connectivity.Connecting,
connectivity.TransientFailure,
},
server: func(lis net.Listener) net.Conn {
conn, err := lis.Accept()
if err != nil {
t.Error(err)
return nil
}
go keepReading(conn)
return conn
},
},
} {
t.Log(test.desc)
testStateTransitionSingleAddress(t, test.want, test.server)
}
}
func testStateTransitionSingleAddress(t *testing.T, want []connectivity.State, server func(net.Listener) net.Conn) {
defer leakcheck.Check(t)
stateNotifications := make(chan connectivity.State, len(want))
testBalancer.ResetNotifier(stateNotifications)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
pl := testutils.NewPipeListener()
defer pl.Close()
// Launch the server.
var conn net.Conn
var connMu sync.Mutex
go func() {
connMu.Lock()
conn = server(pl)
connMu.Unlock()
}()
client, err := DialContext(ctx, "", WithWaitForHandshake(), WithInsecure(),
WithBalancerName(stateRecordingBalancerName), WithDialer(pl.Dialer()), withBackoff(noBackoff{}))
if err != nil {
t.Fatal(err)
}
defer client.Close()
timeout := time.After(5 * time.Second)
for i := 0; i < len(want); i++ {
select {
case <-timeout:
t.Fatalf("timed out waiting for state %d (%v) in flow %v", i, want[i], want)
case seen := <-stateNotifications:
if seen != want[i] {
t.Fatalf("expected to see %v at position %d in flow %v, got %v", want[i], i, want, seen)
}
}
}
connMu.Lock()
defer connMu.Unlock()
if conn != nil {
err = conn.Close()
if err != nil {
t.Fatal(err)
}
}
}
// When a READY connection is closed, the client enters TRANSIENT FAILURE before CONNECTING.
func TestStateTransition_ReadyToTransientFailure(t *testing.T) {
defer leakcheck.Check(t)
want := []connectivity.State{
connectivity.Connecting,
connectivity.Ready,
connectivity.TransientFailure,
connectivity.Connecting,
}
stateNotifications := make(chan connectivity.State, len(want))
testBalancer.ResetNotifier(stateNotifications)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Error while listening. Err: %v", err)
}
defer lis.Close()
sawReady := make(chan struct{})
// Launch the server.
go func() {
conn, err := lis.Accept()
if err != nil {
t.Error(err)
return
}
go keepReading(conn)
framer := http2.NewFramer(conn, conn)
if err := framer.WriteSettings(http2.Setting{}); err != nil {
t.Errorf("Error while writing settings frame. %v", err)
return
}
// Prevents race between onPrefaceReceipt and onClose.
<-sawReady
conn.Close()
}()
client, err := DialContext(ctx, lis.Addr().String(), WithWaitForHandshake(), WithInsecure(), WithBalancerName(stateRecordingBalancerName))
if err != nil {
t.Fatal(err)
}
defer client.Close()
timeout := time.After(5 * time.Second)
for i := 0; i < len(want); i++ {
select {
case <-timeout:
t.Fatalf("timed out waiting for state %d (%v) in flow %v", i, want[i], want)
case seen := <-stateNotifications:
if seen == connectivity.Ready {
close(sawReady)
}
if seen != want[i] {
t.Fatalf("expected to see %v at position %d in flow %v, got %v", want[i], i, want, seen)
}
}
}
}
// When the first connection is closed, the client stays in CONNECTING until it tries the second
// address (which succeeds, after which it enters READY).
func TestStateTransitions_TriesAllAddrsBeforeTransientFailure(t *testing.T) {
defer leakcheck.Check(t)
want := []connectivity.State{
connectivity.Connecting,
connectivity.Ready,
}
stateNotifications := make(chan connectivity.State, len(want))
testBalancer.ResetNotifier(stateNotifications)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
lis1, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Error while listening. Err: %v", err)
}
defer lis1.Close()
lis2, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Error while listening. Err: %v", err)
}
defer lis2.Close()
server1Done := make(chan struct{})
server2Done := make(chan struct{})
// Launch server 1.
go func() {
conn, err := lis1.Accept()
if err != nil {
t.Error(err)
return
}
conn.Close()
close(server1Done)
}()
// Launch server 2.
go func() {
conn, err := lis2.Accept()
if err != nil {
t.Error(err)
return
}
go keepReading(conn)
framer := http2.NewFramer(conn, conn)
if err := framer.WriteSettings(http2.Setting{}); err != nil {
t.Errorf("Error while writing settings frame. %v", err)
return
}
close(server2Done)
}()
rb := manual.NewBuilderWithScheme("whatever")
rb.InitialAddrs([]resolver.Address{
{Addr: lis1.Addr().String()},
{Addr: lis2.Addr().String()},
})
client, err := DialContext(ctx, "this-gets-overwritten", WithInsecure(), WithWaitForHandshake(), WithBalancerName(stateRecordingBalancerName), withResolverBuilder(rb))
if err != nil {
t.Fatal(err)
}
defer client.Close()
timeout := time.After(5 * time.Second)
for i := 0; i < len(want); i++ {
select {
case <-timeout:
t.Fatalf("timed out waiting for state %d (%v) in flow %v", i, want[i], want)
case seen := <-stateNotifications:
if seen != want[i] {
t.Fatalf("expected to see %v at position %d in flow %v, got %v", want[i], i, want, seen)
}
}
}
select {
case <-timeout:
t.Fatal("saw the correct state transitions, but timed out waiting for client to finish interactions with server 1")
case <-server1Done:
}
select {
case <-timeout:
t.Fatal("saw the correct state transitions, but timed out waiting for client to finish interactions with server 2")
case <-server2Done:
}
}
// When there are multiple addresses, and we enter READY on one of them, a later closure should cause
// the client to enter TRANSIENT FAILURE before it re-enters CONNECTING.
func TestStateTransitions_MultipleAddrsEntersReady(t *testing.T) {
defer leakcheck.Check(t)
want := []connectivity.State{
connectivity.Connecting,
connectivity.Ready,
connectivity.TransientFailure,
connectivity.Connecting,
}
stateNotifications := make(chan connectivity.State, len(want))
testBalancer.ResetNotifier(stateNotifications)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
lis1, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Error while listening. Err: %v", err)
}
defer lis1.Close()
// Never actually gets used; we just want it to be alive so that the resolver has two addresses to target.
lis2, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Error while listening. Err: %v", err)
}
defer lis2.Close()
server1Done := make(chan struct{})
sawReady := make(chan struct{})
// Launch server 1.
go func() {
conn, err := lis1.Accept()
if err != nil {
t.Error(err)
return
}
go keepReading(conn)
framer := http2.NewFramer(conn, conn)
if err := framer.WriteSettings(http2.Setting{}); err != nil {
t.Errorf("Error while writing settings frame. %v", err)
return
}
<-sawReady
conn.Close()
_, err = lis1.Accept()
if err != nil {
t.Error(err)
return
}
close(server1Done)
}()
rb := manual.NewBuilderWithScheme("whatever")
rb.InitialAddrs([]resolver.Address{
{Addr: lis1.Addr().String()},
{Addr: lis2.Addr().String()},
})
client, err := DialContext(ctx, "this-gets-overwritten", WithInsecure(), WithWaitForHandshake(), WithBalancerName(stateRecordingBalancerName), withResolverBuilder(rb))
if err != nil {
t.Fatal(err)
}
defer client.Close()
timeout := time.After(2 * time.Second)
for i := 0; i < len(want); i++ {
select {
case <-timeout:
t.Fatalf("timed out waiting for state %d (%v) in flow %v", i, want[i], want)
case seen := <-stateNotifications:
if seen == connectivity.Ready {
close(sawReady)
}
if seen != want[i] {
t.Fatalf("expected to see %v at position %d in flow %v, got %v", want[i], i, want, seen)
}
}
}
select {
case <-timeout:
t.Fatal("saw the correct state transitions, but timed out waiting for client to finish interactions with server 1")
case <-server1Done:
}
}
type stateRecordingBalancer struct {
mu sync.Mutex
notifier chan<- connectivity.State
balancer.Balancer
}
func (b *stateRecordingBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) {
b.mu.Lock()
b.notifier <- s
b.mu.Unlock()
b.Balancer.HandleSubConnStateChange(sc, s)
}
func (b *stateRecordingBalancer) ResetNotifier(r chan<- connectivity.State) {
b.mu.Lock()
defer b.mu.Unlock()
b.notifier = r
}
func (b *stateRecordingBalancer) Close() {
b.mu.Lock()
u := b.Balancer
b.mu.Unlock()
u.Close()
}
func (b *stateRecordingBalancer) Name() string {
return stateRecordingBalancerName
}
func (b *stateRecordingBalancer) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
b.mu.Lock()
b.Balancer = balancer.Get(PickFirstBalancerName).Build(cc, opts)
b.mu.Unlock()
return b
}
type noBackoff struct{}
func (b noBackoff) Backoff(int) time.Duration { return time.Duration(0) }
// keepReading reads until something causes the connection to die (EOF, server closed, etc.).
// It is useful for mindlessly keeping the connection healthy, since the client will error
// if things like the client preface are not accepted in a timely fashion.
func keepReading(conn net.Conn) {
buf := make([]byte, 1024)
for _, err := conn.Read(buf); err == nil; _, err = conn.Read(buf) {
}
}
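The registration that makes this builder selectable by name is not part of this hunk. Below is a minimal sketch, assuming the surrounding file's package and imports and the standard balancer.Register API; the original file may wire this up differently, for example through a package-level testBalancer instance.

func init() {
	// Register the builder under the name returned by Name(), so DialContext
	// can select it via WithBalancerName(stateRecordingBalancerName).
	balancer.Register(&stateRecordingBalancer{})
}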

File diff suppressed because it is too large

View File

@ -1,32 +0,0 @@
/*
*
* Copyright 2014 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"testing"
"google.golang.org/grpc/encoding"
"google.golang.org/grpc/encoding/proto"
)
func TestGetCodecForProtoIsNotNil(t *testing.T) {
if encoding.GetCodec(proto.Name) == nil {
t.Fatalf("encoding.GetCodec(%q) must not be nil by default", proto.Name)
}
}
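The test above only checks that the proto codec is registered by default. For context, here is a hedged sketch of how an additional codec could be registered through the same encoding registry; the jsonCodec type, its package, and its name are illustrative assumptions, not part of this change.

package example

import (
	"encoding/json"

	"google.golang.org/grpc/encoding"
)

// jsonCodec is a hypothetical codec used only to illustrate registration.
type jsonCodec struct{}

func (jsonCodec) Marshal(v interface{}) ([]byte, error)      { return json.Marshal(v) }
func (jsonCodec) Unmarshal(data []byte, v interface{}) error { return json.Unmarshal(data, v) }
func (jsonCodec) Name() string                                { return "json" }

func init() {
	// After this, encoding.GetCodec("json") returns the codec, just as
	// encoding.GetCodec(proto.Name) returns the built-in proto codec above.
	encoding.RegisterCodec(jsonCodec{})
}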

View File

@ -1,17 +0,0 @@
#!/usr/bin/env bash
# This script serves as an example to demonstrate how to generate the gRPC-Go
# interface and the related messages from a .proto file.
#
# It assumes the installation of i) the Google protocol buffer compiler at
# https://github.com/google/protobuf (after v2.6.1) and ii) the Go codegen
# plugin at https://github.com/golang/protobuf (after 2015-02-20). If you have
# not installed them yet, please do so first.
#
# We recommend running this script at $GOPATH/src.
#
# If this is not what you need, feel free to make your own scripts. Again, this
# script is for demonstration purposes.
#
proto=$1
protoc --go_out=plugins=grpc:. $proto

View File

@ -1,84 +0,0 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package codes
import (
"encoding/json"
"reflect"
"testing"
cpb "google.golang.org/genproto/googleapis/rpc/code"
)
func TestUnmarshalJSON(t *testing.T) {
for s, v := range cpb.Code_value {
want := Code(v)
var got Code
if err := got.UnmarshalJSON([]byte(`"` + s + `"`)); err != nil || got != want {
t.Errorf("got.UnmarshalJSON(%q) = %v; want <nil>. got=%v; want %v", s, err, got, want)
}
}
}
func TestJSONUnmarshal(t *testing.T) {
var got []Code
want := []Code{OK, NotFound, Internal, Canceled}
in := `["OK", "NOT_FOUND", "INTERNAL", "CANCELLED"]`
err := json.Unmarshal([]byte(in), &got)
if err != nil || !reflect.DeepEqual(got, want) {
t.Fatalf("json.Unmarshal(%q, &got) = %v; want <nil>. got=%v; want %v", in, err, got, want)
}
}
func TestUnmarshalJSON_NilReceiver(t *testing.T) {
var got *Code
in := OK.String()
if err := got.UnmarshalJSON([]byte(in)); err == nil {
t.Errorf("got.UnmarshalJSON(%q) = nil; want <non-nil>. got=%v", in, got)
}
}
func TestUnmarshalJSON_UnknownInput(t *testing.T) {
var got Code
for _, in := range [][]byte{[]byte(""), []byte("xxx"), []byte("Code(17)"), nil} {
if err := got.UnmarshalJSON(in); err == nil {
t.Errorf("got.UnmarshalJSON(%q) = nil; want <non-nil>. got=%v", in, got)
}
}
}
func TestUnmarshalJSON_MarshalUnmarshal(t *testing.T) {
for i := 0; i < _maxCode; i++ {
var cUnMarshaled Code
c := Code(i)
cJSON, err := json.Marshal(c)
if err != nil {
t.Errorf("marshalling %q failed: %v", c, err)
}
if err := json.Unmarshal(cJSON, &cUnMarshaled); err != nil {
t.Errorf("unmarshalling code failed: %s", err)
}
if c != cUnMarshaled {
t.Errorf("code is %q after marshalling/unmarshalling, expected %q", cUnMarshaled, c)
}
}
}

View File

@ -1,330 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package alts implements the ALTS credential support provided by the gRPC library, which
// encapsulates all the state needed by a client to authenticate with a server
// using ALTS and make various assertions, e.g., about the client's identity,
// role, or whether it is authorized to make a particular call.
// This package is experimental.
package alts
import (
"context"
"errors"
"fmt"
"net"
"sync"
"time"
"google.golang.org/grpc/credentials"
core "google.golang.org/grpc/credentials/alts/internal"
"google.golang.org/grpc/credentials/alts/internal/handshaker"
"google.golang.org/grpc/credentials/alts/internal/handshaker/service"
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
"google.golang.org/grpc/grpclog"
)
const (
// hypervisorHandshakerServiceAddress represents the default ALTS gRPC
// handshaker service address in the hypervisor.
hypervisorHandshakerServiceAddress = "metadata.google.internal:8080"
// defaultTimeout specifies the server handshake timeout.
defaultTimeout = 30.0 * time.Second
// The following constants specify the minimum and maximum acceptable
// protocol versions.
protocolVersionMaxMajor = 2
protocolVersionMaxMinor = 1
protocolVersionMinMajor = 2
protocolVersionMinMinor = 1
)
var (
once sync.Once
maxRPCVersion = &altspb.RpcProtocolVersions_Version{
Major: protocolVersionMaxMajor,
Minor: protocolVersionMaxMinor,
}
minRPCVersion = &altspb.RpcProtocolVersions_Version{
Major: protocolVersionMinMajor,
Minor: protocolVersionMinMinor,
}
// ErrUntrustedPlatform is returned from ClientHandshake and
// ServerHandshake when running on a platform where the trustworthiness of
// the handshaker service is not guaranteed.
ErrUntrustedPlatform = errors.New("ALTS: untrusted platform. ALTS is only supported on GCP")
)
// AuthInfo exposes security information from the ALTS handshake to the
// application. This interface is to be implemented by ALTS. Users should not
// need a brand new implementation of this interface. For situations like
// testing, any new implementation should embed this interface. This allows
// ALTS to add new methods to this interface.
type AuthInfo interface {
// ApplicationProtocol returns application protocol negotiated for the
// ALTS connection.
ApplicationProtocol() string
// RecordProtocol returns the record protocol negotiated for the ALTS
// connection.
RecordProtocol() string
// SecurityLevel returns the security level of the created ALTS secure
// channel.
SecurityLevel() altspb.SecurityLevel
// PeerServiceAccount returns the peer service account.
PeerServiceAccount() string
// LocalServiceAccount returns the local service account.
LocalServiceAccount() string
// PeerRPCVersions returns the RPC version supported by the peer.
PeerRPCVersions() *altspb.RpcProtocolVersions
}
// ClientOptions contains the client-side options of an ALTS channel. These
// options will be passed to the underlying ALTS handshaker.
type ClientOptions struct {
// TargetServiceAccounts contains a list of expected target service
// accounts.
TargetServiceAccounts []string
// HandshakerServiceAddress represents the ALTS handshaker gRPC service
// address to connect to.
HandshakerServiceAddress string
}
// DefaultClientOptions creates a new ClientOptions object with the default
// values.
func DefaultClientOptions() *ClientOptions {
return &ClientOptions{
HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
}
}
// ServerOptions contains the server-side options of an ALTS channel. These
// options will be passed to the underlying ALTS handshaker.
type ServerOptions struct {
// HandshakerServiceAddress represents the ALTS handshaker gRPC service
// address to connect to.
HandshakerServiceAddress string
}
// DefaultServerOptions creates a new ServerOptions object with the default
// values.
func DefaultServerOptions() *ServerOptions {
return &ServerOptions{
HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
}
}
// altsTC is the credentials required for authenticating a connection using ALTS.
// It implements credentials.TransportCredentials interface.
type altsTC struct {
info *credentials.ProtocolInfo
side core.Side
accounts []string
hsAddress string
}
// NewClientCreds constructs a client-side ALTS TransportCredentials object.
func NewClientCreds(opts *ClientOptions) credentials.TransportCredentials {
return newALTS(core.ClientSide, opts.TargetServiceAccounts, opts.HandshakerServiceAddress)
}
// NewServerCreds constructs a server-side ALTS TransportCredentials object.
func NewServerCreds(opts *ServerOptions) credentials.TransportCredentials {
return newALTS(core.ServerSide, nil, opts.HandshakerServiceAddress)
}
func newALTS(side core.Side, accounts []string, hsAddress string) credentials.TransportCredentials {
once.Do(func() {
vmOnGCP = isRunningOnGCP()
})
if hsAddress == "" {
hsAddress = hypervisorHandshakerServiceAddress
}
return &altsTC{
info: &credentials.ProtocolInfo{
SecurityProtocol: "alts",
SecurityVersion: "1.0",
},
side: side,
accounts: accounts,
hsAddress: hsAddress,
}
}
// ClientHandshake implements the client side handshake protocol.
func (g *altsTC) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
if !vmOnGCP {
return nil, nil, ErrUntrustedPlatform
}
// Connecting to ALTS handshaker service.
hsConn, err := service.Dial(g.hsAddress)
if err != nil {
return nil, nil, err
}
// Do not close hsConn since it is shared with other handshakes.
// Possible context leak:
// The cancel function for the child context we create will only be
// called if a non-nil error is returned.
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
defer func() {
if err != nil {
cancel()
}
}()
opts := handshaker.DefaultClientHandshakerOptions()
opts.TargetName = addr
opts.TargetServiceAccounts = g.accounts
opts.RPCVersions = &altspb.RpcProtocolVersions{
MaxRpcVersion: maxRPCVersion,
MinRpcVersion: minRPCVersion,
}
chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, opts)
defer func() {
if err != nil {
chs.Close()
}
}()
if err != nil {
return nil, nil, err
}
secConn, authInfo, err := chs.ClientHandshake(ctx)
if err != nil {
return nil, nil, err
}
altsAuthInfo, ok := authInfo.(AuthInfo)
if !ok {
return nil, nil, errors.New("client-side auth info is not of type alts.AuthInfo")
}
match, _ := checkRPCVersions(opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
if !match {
return nil, nil, fmt.Errorf("server-side RPC versions are not compatible with this client, local versions: %v, peer versions: %v", opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
}
return secConn, authInfo, nil
}
// ServerHandshake implements the server side ALTS handshaker.
func (g *altsTC) ServerHandshake(rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
if !vmOnGCP {
return nil, nil, ErrUntrustedPlatform
}
// Connecting to ALTS handshaker service.
hsConn, err := service.Dial(g.hsAddress)
if err != nil {
return nil, nil, err
}
// Do not close hsConn since it's shared with other handshakes.
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
defer cancel()
opts := handshaker.DefaultServerHandshakerOptions()
opts.RPCVersions = &altspb.RpcProtocolVersions{
MaxRpcVersion: maxRPCVersion,
MinRpcVersion: minRPCVersion,
}
shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, opts)
defer func() {
if err != nil {
shs.Close()
}
}()
if err != nil {
return nil, nil, err
}
secConn, authInfo, err := shs.ServerHandshake(ctx)
if err != nil {
return nil, nil, err
}
altsAuthInfo, ok := authInfo.(AuthInfo)
if !ok {
return nil, nil, errors.New("server-side auth info is not of type alts.AuthInfo")
}
match, _ := checkRPCVersions(opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
if !match {
return nil, nil, fmt.Errorf("client-side RPC versions is not compatible with this server, local versions: %v, peer versions: %v", opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
}
return secConn, authInfo, nil
}
func (g *altsTC) Info() credentials.ProtocolInfo {
return *g.info
}
func (g *altsTC) Clone() credentials.TransportCredentials {
info := *g.info
var accounts []string
if g.accounts != nil {
accounts = make([]string, len(g.accounts))
copy(accounts, g.accounts)
}
return &altsTC{
info: &info,
side: g.side,
hsAddress: g.hsAddress,
accounts: accounts,
}
}
func (g *altsTC) OverrideServerName(serverNameOverride string) error {
g.info.ServerName = serverNameOverride
return nil
}
// compareRPCVersion returns 0 if v1 == v2, 1 if v1 > v2 and -1 if v1 < v2.
func compareRPCVersions(v1, v2 *altspb.RpcProtocolVersions_Version) int {
switch {
case v1.GetMajor() > v2.GetMajor(),
v1.GetMajor() == v2.GetMajor() && v1.GetMinor() > v2.GetMinor():
return 1
case v1.GetMajor() < v2.GetMajor(),
v1.GetMajor() == v2.GetMajor() && v1.GetMinor() < v2.GetMinor():
return -1
}
return 0
}
// checkRPCVersions performs a version check between local and peer RPC protocol
// versions. This function returns true if the check passes, which means both
// parties agreed on a common RPC protocol to use, and false otherwise. The
// function also returns the highest common RPC protocol version both parties
// agreed on.
func checkRPCVersions(local, peer *altspb.RpcProtocolVersions) (bool, *altspb.RpcProtocolVersions_Version) {
if local == nil || peer == nil {
grpclog.Error("invalid checkRPCVersions argument, either local or peer is nil.")
return false, nil
}
// maxCommonVersion is MIN(local.max, peer.max).
maxCommonVersion := local.GetMaxRpcVersion()
if compareRPCVersions(local.GetMaxRpcVersion(), peer.GetMaxRpcVersion()) > 0 {
maxCommonVersion = peer.GetMaxRpcVersion()
}
// minCommonVersion is MAX(local.min, peer.min).
minCommonVersion := peer.GetMinRpcVersion()
if compareRPCVersions(local.GetMinRpcVersion(), peer.GetMinRpcVersion()) > 0 {
minCommonVersion = local.GetMinRpcVersion()
}
if compareRPCVersions(maxCommonVersion, minCommonVersion) < 0 {
return false, nil
}
return true, maxCommonVersion
}
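For orientation, here is a minimal client-side usage sketch of the credentials defined above, assuming the standard grpc.Dial and WithTransportCredentials APIs; the target address and error handling are placeholders, not part of this change.

package main

import (
	"log"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/alts"
)

func main() {
	// NewClientCreds is defined above; DefaultClientOptions points the
	// handshaker at the hypervisor address unless overridden.
	creds := alts.NewClientCreds(alts.DefaultClientOptions())
	conn, err := grpc.Dial("server.example.com:443", grpc.WithTransportCredentials(creds))
	if err != nil {
		log.Fatalf("grpc.Dial failed: %v", err)
	}
	defer conn.Close()
	// Use conn to create service stubs as usual.
}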

View File

@ -1,290 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package alts
import (
"reflect"
"testing"
"github.com/golang/protobuf/proto"
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
)
func TestInfoServerName(t *testing.T) {
// This is not testing any handshaker functionality, so it's fine to only
// use NewServerCreds and not NewClientCreds.
alts := NewServerCreds(DefaultServerOptions())
if got, want := alts.Info().ServerName, ""; got != want {
t.Fatalf("%v.Info().ServerName = %v, want %v", alts, got, want)
}
}
func TestOverrideServerName(t *testing.T) {
wantServerName := "server.name"
// This is not testing any handshaker functionality, so it's fine to only
// use NewServerCreds and not NewClientCreds.
c := NewServerCreds(DefaultServerOptions())
c.OverrideServerName(wantServerName)
if got, want := c.Info().ServerName, wantServerName; got != want {
t.Fatalf("c.Info().ServerName = %v, want %v", got, want)
}
}
func TestCloneClient(t *testing.T) {
wantServerName := "server.name"
opt := DefaultClientOptions()
opt.TargetServiceAccounts = []string{"not", "empty"}
c := NewClientCreds(opt)
c.OverrideServerName(wantServerName)
cc := c.Clone()
if got, want := cc.Info().ServerName, wantServerName; got != want {
t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
}
cc.OverrideServerName("")
if got, want := c.Info().ServerName, wantServerName; got != want {
t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", got, want)
}
if got, want := cc.Info().ServerName, ""; got != want {
t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
}
ct := c.(*altsTC)
cct := cc.(*altsTC)
if ct.side != cct.side {
t.Errorf("cc.side = %q, want %q", cct.side, ct.side)
}
if ct.hsAddress != cct.hsAddress {
t.Errorf("cc.hsAddress = %q, want %q", cct.hsAddress, ct.hsAddress)
}
if !reflect.DeepEqual(ct.accounts, cct.accounts) {
t.Errorf("cc.accounts = %q, want %q", cct.accounts, ct.accounts)
}
}
func TestCloneServer(t *testing.T) {
wantServerName := "server.name"
c := NewServerCreds(DefaultServerOptions())
c.OverrideServerName(wantServerName)
cc := c.Clone()
if got, want := cc.Info().ServerName, wantServerName; got != want {
t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
}
cc.OverrideServerName("")
if got, want := c.Info().ServerName, wantServerName; got != want {
t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", got, want)
}
if got, want := cc.Info().ServerName, ""; got != want {
t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
}
ct := c.(*altsTC)
cct := cc.(*altsTC)
if ct.side != cct.side {
t.Errorf("cc.side = %q, want %q", cct.side, ct.side)
}
if ct.hsAddress != cct.hsAddress {
t.Errorf("cc.hsAddress = %q, want %q", cct.hsAddress, ct.hsAddress)
}
if !reflect.DeepEqual(ct.accounts, cct.accounts) {
t.Errorf("cc.accounts = %q, want %q", cct.accounts, ct.accounts)
}
}
func TestInfo(t *testing.T) {
// This is not testing any handshaker functionality, so it's fine to only
// use NewServerCreds and not NewClientCreds.
c := NewServerCreds(DefaultServerOptions())
info := c.Info()
if got, want := info.ProtocolVersion, ""; got != want {
t.Errorf("info.ProtocolVersion=%v, want %v", got, want)
}
if got, want := info.SecurityProtocol, "alts"; got != want {
t.Errorf("info.SecurityProtocol=%v, want %v", got, want)
}
if got, want := info.SecurityVersion, "1.0"; got != want {
t.Errorf("info.SecurityVersion=%v, want %v", got, want)
}
if got, want := info.ServerName, ""; got != want {
t.Errorf("info.ServerName=%v, want %v", got, want)
}
}
func TestCompareRPCVersions(t *testing.T) {
for _, tc := range []struct {
v1 *altspb.RpcProtocolVersions_Version
v2 *altspb.RpcProtocolVersions_Version
output int
}{
{
version(3, 2),
version(2, 1),
1,
},
{
version(3, 2),
version(3, 1),
1,
},
{
version(2, 1),
version(3, 2),
-1,
},
{
version(3, 1),
version(3, 2),
-1,
},
{
version(3, 2),
version(3, 2),
0,
},
} {
if got, want := compareRPCVersions(tc.v1, tc.v2), tc.output; got != want {
t.Errorf("compareRPCVersions(%v, %v)=%v, want %v", tc.v1, tc.v2, got, want)
}
}
}
func TestCheckRPCVersions(t *testing.T) {
for _, tc := range []struct {
desc string
local *altspb.RpcProtocolVersions
peer *altspb.RpcProtocolVersions
output bool
maxCommonVersion *altspb.RpcProtocolVersions_Version
}{
{
"local.max > peer.max and local.min > peer.min",
versions(2, 1, 3, 2),
versions(1, 2, 2, 1),
true,
version(2, 1),
},
{
"local.max > peer.max and local.min < peer.min",
versions(1, 2, 3, 2),
versions(2, 1, 2, 1),
true,
version(2, 1),
},
{
"local.max > peer.max and local.min = peer.min",
versions(2, 1, 3, 2),
versions(2, 1, 2, 1),
true,
version(2, 1),
},
{
"local.max < peer.max and local.min > peer.min",
versions(2, 1, 2, 1),
versions(1, 2, 3, 2),
true,
version(2, 1),
},
{
"local.max = peer.max and local.min > peer.min",
versions(2, 1, 2, 1),
versions(1, 2, 2, 1),
true,
version(2, 1),
},
{
"local.max < peer.max and local.min < peer.min",
versions(1, 2, 2, 1),
versions(2, 1, 3, 2),
true,
version(2, 1),
},
{
"local.max < peer.max and local.min = peer.min",
versions(1, 2, 2, 1),
versions(1, 2, 3, 2),
true,
version(2, 1),
},
{
"local.max = peer.max and local.min < peer.min",
versions(1, 2, 2, 1),
versions(2, 1, 2, 1),
true,
version(2, 1),
},
{
"all equal",
versions(2, 1, 2, 1),
versions(2, 1, 2, 1),
true,
version(2, 1),
},
{
"max is smaller than min",
versions(2, 1, 1, 2),
versions(2, 1, 1, 2),
false,
nil,
},
{
"no overlap, local > peer",
versions(4, 3, 6, 5),
versions(1, 0, 2, 1),
false,
nil,
},
{
"no overlap, local < peer",
versions(1, 0, 2, 1),
versions(4, 3, 6, 5),
false,
nil,
},
{
"no overlap, max < min",
versions(6, 5, 4, 3),
versions(2, 1, 1, 0),
false,
nil,
},
} {
output, maxCommonVersion := checkRPCVersions(tc.local, tc.peer)
if got, want := output, tc.output; got != want {
t.Errorf("%v: checkRPCVersions(%v, %v)=(%v, _), want (%v, _)", tc.desc, tc.local, tc.peer, got, want)
}
if got, want := maxCommonVersion, tc.maxCommonVersion; !proto.Equal(got, want) {
t.Errorf("%v: checkRPCVersions(%v, %v)=(_, %v), want (_, %v)", tc.desc, tc.local, tc.peer, got, want)
}
}
}
func version(major, minor uint32) *altspb.RpcProtocolVersions_Version {
return &altspb.RpcProtocolVersions_Version{
Major: major,
Minor: minor,
}
}
func versions(minMajor, minMinor, maxMajor, maxMinor uint32) *altspb.RpcProtocolVersions {
return &altspb.RpcProtocolVersions{
MinRpcVersion: version(minMajor, minMinor),
MaxRpcVersion: version(maxMajor, maxMinor),
}
}

View File

@ -1,87 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package authinfo provides authentication information returned by handshakers.
package authinfo
import (
"google.golang.org/grpc/credentials"
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
)
var _ credentials.AuthInfo = (*altsAuthInfo)(nil)
// altsAuthInfo exposes security information from the ALTS handshake to the
// application. altsAuthInfo is immutable and implements credentials.AuthInfo.
type altsAuthInfo struct {
p *altspb.AltsContext
}
// New returns a new altsAuthInfo object given handshaker results.
func New(result *altspb.HandshakerResult) credentials.AuthInfo {
return newAuthInfo(result)
}
func newAuthInfo(result *altspb.HandshakerResult) *altsAuthInfo {
return &altsAuthInfo{
p: &altspb.AltsContext{
ApplicationProtocol: result.GetApplicationProtocol(),
RecordProtocol: result.GetRecordProtocol(),
// TODO: assign security level from result.
SecurityLevel: altspb.SecurityLevel_INTEGRITY_AND_PRIVACY,
PeerServiceAccount: result.GetPeerIdentity().GetServiceAccount(),
LocalServiceAccount: result.GetLocalIdentity().GetServiceAccount(),
PeerRpcVersions: result.GetPeerRpcVersions(),
},
}
}
// AuthType identifies the context as providing ALTS authentication information.
func (s *altsAuthInfo) AuthType() string {
return "alts"
}
// ApplicationProtocol returns the context's application protocol.
func (s *altsAuthInfo) ApplicationProtocol() string {
return s.p.GetApplicationProtocol()
}
// RecordProtocol returns the context's record protocol.
func (s *altsAuthInfo) RecordProtocol() string {
return s.p.GetRecordProtocol()
}
// SecurityLevel returns the context's security level.
func (s *altsAuthInfo) SecurityLevel() altspb.SecurityLevel {
return s.p.GetSecurityLevel()
}
// PeerServiceAccount returns the context's peer service account.
func (s *altsAuthInfo) PeerServiceAccount() string {
return s.p.GetPeerServiceAccount()
}
// LocalServiceAccount returns the context's local service account.
func (s *altsAuthInfo) LocalServiceAccount() string {
return s.p.GetLocalServiceAccount()
}
// PeerRPCVersions returns the context's peer RPC versions.
func (s *altsAuthInfo) PeerRPCVersions() *altspb.RpcProtocolVersions {
return s.p.GetPeerRpcVersions()
}
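A hedged sketch of how a server handler could inspect the ALTS auth info exposed by this package, assuming the standard peer package; the package name and the peerServiceAccount helper are illustrative assumptions.

package example

import (
	"context"
	"errors"

	"google.golang.org/grpc/credentials/alts"
	"google.golang.org/grpc/peer"
)

// peerServiceAccount is a hypothetical helper that pulls the peer's service
// account out of the ALTS auth info attached to the connection.
func peerServiceAccount(ctx context.Context) (string, error) {
	p, ok := peer.FromContext(ctx)
	if !ok {
		return "", errors.New("no peer information in context")
	}
	info, ok := p.AuthInfo.(alts.AuthInfo)
	if !ok {
		return "", errors.New("connection was not established with ALTS")
	}
	return info.PeerServiceAccount(), nil
}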

View File

@ -1,134 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package authinfo
import (
"reflect"
"testing"
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
)
const (
testAppProtocol = "my_app"
testRecordProtocol = "very_secure_protocol"
testPeerAccount = "peer_service_account"
testLocalAccount = "local_service_account"
testPeerHostname = "peer_hostname"
testLocalHostname = "local_hostname"
)
func TestALTSAuthInfo(t *testing.T) {
for _, tc := range []struct {
result *altspb.HandshakerResult
outAppProtocol string
outRecordProtocol string
outSecurityLevel altspb.SecurityLevel
outPeerAccount string
outLocalAccount string
outPeerRPCVersions *altspb.RpcProtocolVersions
}{
{
&altspb.HandshakerResult{
ApplicationProtocol: testAppProtocol,
RecordProtocol: testRecordProtocol,
PeerIdentity: &altspb.Identity{
IdentityOneof: &altspb.Identity_ServiceAccount{
ServiceAccount: testPeerAccount,
},
},
LocalIdentity: &altspb.Identity{
IdentityOneof: &altspb.Identity_ServiceAccount{
ServiceAccount: testLocalAccount,
},
},
},
testAppProtocol,
testRecordProtocol,
altspb.SecurityLevel_INTEGRITY_AND_PRIVACY,
testPeerAccount,
testLocalAccount,
nil,
},
{
&altspb.HandshakerResult{
ApplicationProtocol: testAppProtocol,
RecordProtocol: testRecordProtocol,
PeerIdentity: &altspb.Identity{
IdentityOneof: &altspb.Identity_Hostname{
Hostname: testPeerHostname,
},
},
LocalIdentity: &altspb.Identity{
IdentityOneof: &altspb.Identity_Hostname{
Hostname: testLocalHostname,
},
},
PeerRpcVersions: &altspb.RpcProtocolVersions{
MaxRpcVersion: &altspb.RpcProtocolVersions_Version{
Major: 20,
Minor: 21,
},
MinRpcVersion: &altspb.RpcProtocolVersions_Version{
Major: 10,
Minor: 11,
},
},
},
testAppProtocol,
testRecordProtocol,
altspb.SecurityLevel_INTEGRITY_AND_PRIVACY,
"",
"",
&altspb.RpcProtocolVersions{
MaxRpcVersion: &altspb.RpcProtocolVersions_Version{
Major: 20,
Minor: 21,
},
MinRpcVersion: &altspb.RpcProtocolVersions_Version{
Major: 10,
Minor: 11,
},
},
},
} {
authInfo := newAuthInfo(tc.result)
if got, want := authInfo.AuthType(), "alts"; got != want {
t.Errorf("authInfo.AuthType()=%v, want %v", got, want)
}
if got, want := authInfo.ApplicationProtocol(), tc.outAppProtocol; got != want {
t.Errorf("authInfo.ApplicationProtocol()=%v, want %v", got, want)
}
if got, want := authInfo.RecordProtocol(), tc.outRecordProtocol; got != want {
t.Errorf("authInfo.RecordProtocol()=%v, want %v", got, want)
}
if got, want := authInfo.SecurityLevel(), tc.outSecurityLevel; got != want {
t.Errorf("authInfo.SecurityLevel()=%v, want %v", got, want)
}
if got, want := authInfo.PeerServiceAccount(), tc.outPeerAccount; got != want {
t.Errorf("authInfo.PeerServiceAccount()=%v, want %v", got, want)
}
if got, want := authInfo.LocalServiceAccount(), tc.outLocalAccount; got != want {
t.Errorf("authInfo.LocalServiceAccount()=%v, want %v", got, want)
}
if got, want := authInfo.PeerRPCVersions(), tc.outPeerRPCVersions; !reflect.DeepEqual(got, want) {
t.Errorf("authinfo.PeerRpcVersions()=%v, want %v", got, want)
}
}
}

View File

@ -1,69 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
//go:generate ./regenerate.sh
// Package internal contains common core functionality for ALTS.
package internal
import (
"context"
"net"
"google.golang.org/grpc/credentials"
)
const (
// ClientSide identifies the client in this communication.
ClientSide Side = iota
// ServerSide identifies the server in this communication.
ServerSide
)
// PeerNotRespondingError is returned when a peer server is not responding
// after a channel has been established. It is treated as a temporary connection
// error and re-connection to the server should be attempted.
var PeerNotRespondingError = &peerNotRespondingError{}
// Side identifies the party's role: client or server.
type Side int
type peerNotRespondingError struct{}
// Error returns an error message for logging purposes.
func (e *peerNotRespondingError) Error() string {
return "peer server is not responding and re-connection should be attempted."
}
// Temporary indicates if this connection error is temporary or fatal.
func (e *peerNotRespondingError) Temporary() bool {
return true
}
// Handshaker defines an ALTS handshaker interface.
type Handshaker interface {
// ClientHandshake starts and completes a client-side handshaking and
// returns a secure connection and corresponding auth information.
ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error)
// ServerHandshake starts and completes a server-side handshaking and
// returns a secure connection and corresponding auth information.
ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error)
// Close terminates the Handshaker. It should be called when the caller
// obtains the secure connection.
Close()
}

View File

@ -1,131 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/sha256"
"encoding/binary"
"fmt"
"strconv"
)
// rekeyAEAD holds the necessary information for an AEAD based on
// AES-GCM that performs nonce-based key derivation and XORs the
// nonce with a random mask.
type rekeyAEAD struct {
kdfKey []byte
kdfCounter []byte
nonceMask []byte
nonceBuf []byte
gcmAEAD cipher.AEAD
}
// KeySizeError signals that the given key does not have the correct size.
type KeySizeError int
func (k KeySizeError) Error() string {
return "alts/conn: invalid key size " + strconv.Itoa(int(k))
}
// newRekeyAEAD creates a new instance of aes128gcm with rekeying.
// The key argument should be 44 bytes; the first 32 bytes are used as a key
// for HKDF-expand and the remaining 12 bytes are used as a random mask for
// the counter.
func newRekeyAEAD(key []byte) (*rekeyAEAD, error) {
k := len(key)
if k != kdfKeyLen+nonceLen {
return nil, KeySizeError(k)
}
return &rekeyAEAD{
kdfKey: key[:kdfKeyLen],
kdfCounter: make([]byte, kdfCounterLen),
nonceMask: key[kdfKeyLen:],
nonceBuf: make([]byte, nonceLen),
gcmAEAD: nil,
}, nil
}
// Seal rekeys if nonce[2:8] is different than in the last call, masks the nonce,
// and calls Seal for aes128gcm.
func (s *rekeyAEAD) Seal(dst, nonce, plaintext, additionalData []byte) []byte {
if err := s.rekeyIfRequired(nonce); err != nil {
panic(fmt.Sprintf("Rekeying failed with: %s", err.Error()))
}
maskNonce(s.nonceBuf, nonce, s.nonceMask)
return s.gcmAEAD.Seal(dst, s.nonceBuf, plaintext, additionalData)
}
// Open rekeys if nonce[2:8] is different than in the last call, masks the nonce,
// and calls Open for aes128gcm.
func (s *rekeyAEAD) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
if err := s.rekeyIfRequired(nonce); err != nil {
return nil, err
}
maskNonce(s.nonceBuf, nonce, s.nonceMask)
return s.gcmAEAD.Open(dst, s.nonceBuf, ciphertext, additionalData)
}
// rekeyIfRequired creates a new aes128gcm AEAD if the existing AEAD is nil
// or cannot be used with the given nonce.
func (s *rekeyAEAD) rekeyIfRequired(nonce []byte) error {
newKdfCounter := nonce[kdfCounterOffset : kdfCounterOffset+kdfCounterLen]
if s.gcmAEAD != nil && bytes.Equal(newKdfCounter, s.kdfCounter) {
return nil
}
copy(s.kdfCounter, newKdfCounter)
a, err := aes.NewCipher(hkdfExpand(s.kdfKey, s.kdfCounter))
if err != nil {
return err
}
s.gcmAEAD, err = cipher.NewGCM(a)
return err
}
// maskNonce XORs the given nonce with the mask and stores the result in dst.
func maskNonce(dst, nonce, mask []byte) {
nonce1 := binary.LittleEndian.Uint64(nonce[:sizeUint64])
nonce2 := binary.LittleEndian.Uint32(nonce[sizeUint64:])
mask1 := binary.LittleEndian.Uint64(mask[:sizeUint64])
mask2 := binary.LittleEndian.Uint32(mask[sizeUint64:])
binary.LittleEndian.PutUint64(dst[:sizeUint64], nonce1^mask1)
binary.LittleEndian.PutUint32(dst[sizeUint64:], nonce2^mask2)
}
// NonceSize returns the required nonce size.
func (s *rekeyAEAD) NonceSize() int {
return s.gcmAEAD.NonceSize()
}
// Overhead returns the ciphertext overhead.
func (s *rekeyAEAD) Overhead() int {
return s.gcmAEAD.Overhead()
}
// hkdfExpand computes the first 16 bytes of the HKDF-expand function
// defined in RFC5869.
func hkdfExpand(key, info []byte) []byte {
mac := hmac.New(sha256.New, key)
mac.Write(info)
mac.Write([]byte{0x01}[:])
return mac.Sum(nil)[:aeadKeyLen]
}

View File

@ -1,263 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"bytes"
"encoding/hex"
"testing"
)
// rekeyAEADTestVector is a struct for a rekey test vector
type rekeyAEADTestVector struct {
desc string
key, nonce, plaintext, aad, ciphertext []byte
}
// Test encrypt and decrypt using (adapted) test vectors for AES-GCM.
func TestAES128GCMRekeyEncrypt(t *testing.T) {
for _, test := range []rekeyAEADTestVector{
// NIST vectors from:
// http://csrc.nist.gov/groups/ST/toolkit/BCM/documents/proposedmodes/gcm/gcm-revised-spec.pdf
//
// IEEE vectors from:
// http://www.ieee802.org/1/files/public/docs2011/bn-randall-test-vectors-0511-v1.pdf
//
// Key expanded by setting
// expandedKey = (key ||
// key ^ {0x01,..,0x01} ||
// key ^ {0x02,..,0x02})[0:44].
{
desc: "Derived from NIST test vector 1",
key: dehex("0000000000000000000000000000000001010101010101010101010101010101020202020202020202020202"),
nonce: dehex("000000000000000000000000"),
aad: dehex(""),
plaintext: dehex(""),
ciphertext: dehex("85e873e002f6ebdc4060954eb8675508"),
},
{
desc: "Derived from NIST test vector 2",
key: dehex("0000000000000000000000000000000001010101010101010101010101010101020202020202020202020202"),
nonce: dehex("000000000000000000000000"),
aad: dehex(""),
plaintext: dehex("00000000000000000000000000000000"),
ciphertext: dehex("51e9a8cb23ca2512c8256afff8e72d681aca19a1148ac115e83df4888cc00d11"),
},
{
desc: "Derived from NIST test vector 3",
key: dehex("feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f688d96"),
nonce: dehex("cafebabefacedbaddecaf888"),
aad: dehex(""),
plaintext: dehex("d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b391aafd255"),
ciphertext: dehex("1018ed5a1402a86516d6576d70b2ffccca261b94df88b58f53b64dfba435d18b2f6e3b7869f9353d4ac8cf09afb1663daa7b4017e6fc2c177c0c087c0df1162129952213cee1bc6e9c8495dd705e1f3d"),
},
{
desc: "Derived from NIST test vector 4",
key: dehex("feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f688d96"),
nonce: dehex("cafebabefacedbaddecaf888"),
aad: dehex("feedfacedeadbeeffeedfacedeadbeefabaddad2"),
plaintext: dehex("d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39"),
ciphertext: dehex("1018ed5a1402a86516d6576d70b2ffccca261b94df88b58f53b64dfba435d18b2f6e3b7869f9353d4ac8cf09afb1663daa7b4017e6fc2c177c0c087c4764565d077e9124001ddb27fc0848c5"),
},
{
desc: "Derived from adapted NIST test vector 4 for KDF counter boundary (flip nonce bit 15)",
key: dehex("feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f688d96"),
nonce: dehex("ca7ebabefacedbaddecaf888"),
aad: dehex("feedfacedeadbeeffeedfacedeadbeefabaddad2"),
plaintext: dehex("d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39"),
ciphertext: dehex("e650d3c0fb879327f2d03287fa93cd07342b136215adbca00c3bd5099ec41832b1d18e0423ed26bb12c6cd09debb29230a94c0cee15903656f85edb6fc509b1b28216382172ecbcc31e1e9b1"),
},
{
desc: "Derived from adapted NIST test vector 4 for KDF counter boundary (flip nonce bit 16)",
key: dehex("feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f688d96"),
nonce: dehex("cafebbbefacedbaddecaf888"),
aad: dehex("feedfacedeadbeeffeedfacedeadbeefabaddad2"),
plaintext: dehex("d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39"),
ciphertext: dehex("c0121e6c954d0767f96630c33450999791b2da2ad05c4190169ccad9ac86ff1c721e3d82f2ad22ab463bab4a0754b7dd68ca4de7ea2531b625eda01f89312b2ab957d5c7f8568dd95fcdcd1f"),
},
{
desc: "Derived from adapted NIST test vector 4 for KDF counter boundary (flip nonce bit 63)",
key: dehex("feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f688d96"),
nonce: dehex("cafebabefacedb2ddecaf888"),
aad: dehex("feedfacedeadbeeffeedfacedeadbeefabaddad2"),
plaintext: dehex("d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39"),
ciphertext: dehex("8af37ea5684a4d81d4fd817261fd9743099e7e6a025eaacf8e54b124fb5743149e05cb89f4a49467fe2e5e5965f29a19f99416b0016b54585d12553783ba59e9f782e82e097c336bf7989f08"),
},
{
desc: "Derived from adapted NIST test vector 4 for KDF counter boundary (flip nonce bit 64)",
key: dehex("feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f688d96"),
nonce: dehex("cafebabefacedbaddfcaf888"),
aad: dehex("feedfacedeadbeeffeedfacedeadbeefabaddad2"),
plaintext: dehex("d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39"),
ciphertext: dehex("fbd528448d0346bfa878634864d407a35a039de9db2f1feb8e965b3ae9356ce6289441d77f8f0df294891f37ea438b223e3bf2bdc53d4c5a74fb680bb312a8dec6f7252cbcd7f5799750ad78"),
},
{
desc: "Derived from IEEE 2.1.1 54-byte auth",
key: dehex("ad7a2bd03eac835a6f620fdcb506b345ac7b2ad13fad825b6e630eddb407b244af7829d23cae81586d600dde"),
nonce: dehex("12153524c0895e81b2c28465"),
aad: dehex("d609b1f056637a0d46df998d88e5222ab2c2846512153524c0895e8108000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f30313233340001"),
plaintext: dehex(""),
ciphertext: dehex("3ea0b584f3c85e93f9320ea591699efb"),
},
{
desc: "Derived from IEEE 2.1.2 54-byte auth",
key: dehex("e3c08a8f06c6e3ad95a70557b23f75483ce33021a9c72b7025666204c69c0b72e1c2888d04c4e1af97a50755"),
nonce: dehex("12153524c0895e81b2c28465"),
aad: dehex("d609b1f056637a0d46df998d88e5222ab2c2846512153524c0895e8108000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f30313233340001"),
plaintext: dehex(""),
ciphertext: dehex("294e028bf1fe6f14c4e8f7305c933eb5"),
},
{
desc: "Derived from IEEE 2.2.1 60-byte crypt",
key: dehex("ad7a2bd03eac835a6f620fdcb506b345ac7b2ad13fad825b6e630eddb407b244af7829d23cae81586d600dde"),
nonce: dehex("12153524c0895e81b2c28465"),
aad: dehex("d609b1f056637a0d46df998d88e52e00b2c2846512153524c0895e81"),
plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a0002"),
ciphertext: dehex("db3d25719c6b0a3ca6145c159d5c6ed9aff9c6e0b79f17019ea923b8665ddf52137ad611f0d1bf417a7ca85e45afe106ff9c7569d335d086ae6c03f00987ccd6"),
},
{
desc: "Derived from IEEE 2.2.2 60-byte crypt",
key: dehex("e3c08a8f06c6e3ad95a70557b23f75483ce33021a9c72b7025666204c69c0b72e1c2888d04c4e1af97a50755"),
nonce: dehex("12153524c0895e81b2c28465"),
aad: dehex("d609b1f056637a0d46df998d88e52e00b2c2846512153524c0895e81"),
plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a0002"),
ciphertext: dehex("1641f28ec13afcc8f7903389787201051644914933e9202bb9d06aa020c2a67ef51dfe7bc00a856c55b8f8133e77f659132502bad63f5713d57d0c11e0f871ed"),
},
{
desc: "Derived from IEEE 2.3.1 60-byte auth",
key: dehex("071b113b0ca743fecccf3d051f737382061a103a0da642ffcdce3c041e727283051913390ea541fccecd3f07"),
nonce: dehex("f0761e8dcd3d000176d457ed"),
aad: dehex("e20106d7cd0df0761e8dcd3d88e5400076d457ed08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a0003"),
plaintext: dehex(""),
ciphertext: dehex("58837a10562b0f1f8edbe58ca55811d3"),
},
{
desc: "Derived from IEEE 2.3.2 60-byte auth",
key: dehex("691d3ee909d7f54167fd1ca0b5d769081f2bde1aee655fdbab80bd5295ae6be76b1f3ceb0bd5f74365ff1ea2"),
nonce: dehex("f0761e8dcd3d000176d457ed"),
aad: dehex("e20106d7cd0df0761e8dcd3d88e5400076d457ed08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a0003"),
plaintext: dehex(""),
ciphertext: dehex("c2722ff6ca29a257718a529d1f0c6a3b"),
},
{
desc: "Derived from IEEE 2.4.1 54-byte crypt",
key: dehex("071b113b0ca743fecccf3d051f737382061a103a0da642ffcdce3c041e727283051913390ea541fccecd3f07"),
nonce: dehex("f0761e8dcd3d000176d457ed"),
aad: dehex("e20106d7cd0df0761e8dcd3d88e54c2a76d457ed"),
plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f30313233340004"),
ciphertext: dehex("fd96b715b93a13346af51e8acdf792cdc7b2686f8574c70e6b0cbf16291ded427ad73fec48cd298e0528a1f4c644a949fc31dc9279706ddba33f"),
},
{
desc: "Derived from IEEE 2.4.2 54-byte crypt",
key: dehex("691d3ee909d7f54167fd1ca0b5d769081f2bde1aee655fdbab80bd5295ae6be76b1f3ceb0bd5f74365ff1ea2"),
nonce: dehex("f0761e8dcd3d000176d457ed"),
aad: dehex("e20106d7cd0df0761e8dcd3d88e54c2a76d457ed"),
plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f30313233340004"),
ciphertext: dehex("b68f6300c2e9ae833bdc070e24021a3477118e78ccf84e11a485d861476c300f175353d5cdf92008a4f878e6cc3577768085c50a0e98fda6cbb8"),
},
{
desc: "Derived from IEEE 2.5.1 65-byte auth",
key: dehex("013fe00b5f11be7f866d0cbbc55a7a90003ee10a5e10bf7e876c0dbac45b7b91033de2095d13bc7d846f0eb9"),
nonce: dehex("7cfde9f9e33724c68932d612"),
aad: dehex("84c5d513d2aaf6e5bbd2727788e523008932d6127cfde9f9e33724c608000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f0005"),
plaintext: dehex(""),
ciphertext: dehex("cca20eecda6283f09bb3543dd99edb9b"),
},
{
desc: "Derived from IEEE 2.5.2 65-byte auth",
key: dehex("83c093b58de7ffe1c0da926ac43fb3609ac1c80fee1b624497ef942e2f79a82381c291b78fe5fde3c2d89068"),
nonce: dehex("7cfde9f9e33724c68932d612"),
aad: dehex("84c5d513d2aaf6e5bbd2727788e523008932d6127cfde9f9e33724c608000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f0005"),
plaintext: dehex(""),
ciphertext: dehex("b232cc1da5117bf15003734fa599d271"),
},
{
desc: "Derived from IEEE 2.6.1 61-byte crypt",
key: dehex("013fe00b5f11be7f866d0cbbc55a7a90003ee10a5e10bf7e876c0dbac45b7b91033de2095d13bc7d846f0eb9"),
nonce: dehex("7cfde9f9e33724c68932d612"),
aad: dehex("84c5d513d2aaf6e5bbd2727788e52f008932d6127cfde9f9e33724c6"),
plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b0006"),
ciphertext: dehex("ff1910d35ad7e5657890c7c560146fd038707f204b66edbc3d161f8ace244b985921023c436e3a1c3532ecd5d09a056d70be583f0d10829d9387d07d33d872e490"),
},
{
desc: "Derived from IEEE 2.6.2 61-byte crypt",
key: dehex("83c093b58de7ffe1c0da926ac43fb3609ac1c80fee1b624497ef942e2f79a82381c291b78fe5fde3c2d89068"),
nonce: dehex("7cfde9f9e33724c68932d612"),
aad: dehex("84c5d513d2aaf6e5bbd2727788e52f008932d6127cfde9f9e33724c6"),
plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b0006"),
ciphertext: dehex("0db4cf956b5f97eca4eab82a6955307f9ae02a32dd7d93f83d66ad04e1cfdc5182ad12abdea5bbb619a1bd5fb9a573590fba908e9c7a46c1f7ba0905d1b55ffda4"),
},
{
desc: "Derived from IEEE 2.7.1 79-byte crypt",
key: dehex("88ee087fd95da9fbf6725aa9d757b0cd89ef097ed85ca8faf7735ba8d656b1cc8aec0a7ddb5fabf9f47058ab"),
nonce: dehex("7ae8e2ca4ec500012e58495c"),
aad: dehex("68f2e77696ce7ae8e2ca4ec588e541002e58495c08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d0007"),
plaintext: dehex(""),
ciphertext: dehex("813f0e630f96fb2d030f58d83f5cdfd0"),
},
{
desc: "Derived from IEEE 2.7.2 79-byte crypt",
key: dehex("4c973dbc7364621674f8b5b89e5c15511fced9216490fb1c1a2caa0ffe0407e54e953fbe7166601476fab7ba"),
nonce: dehex("7ae8e2ca4ec500012e58495c"),
aad: dehex("68f2e77696ce7ae8e2ca4ec588e541002e58495c08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d0007"),
plaintext: dehex(""),
ciphertext: dehex("77e5a44c21eb07188aacbd74d1980e97"),
},
{
desc: "Derived from IEEE 2.8.1 61-byte crypt",
key: dehex("88ee087fd95da9fbf6725aa9d757b0cd89ef097ed85ca8faf7735ba8d656b1cc8aec0a7ddb5fabf9f47058ab"),
nonce: dehex("7ae8e2ca4ec500012e58495c"),
aad: dehex("68f2e77696ce7ae8e2ca4ec588e54d002e58495c"),
plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748490008"),
ciphertext: dehex("958ec3f6d60afeda99efd888f175e5fcd4c87b9bcc5c2f5426253a8b506296c8c43309ab2adb5939462541d95e80811e04e706b1498f2c407c7fb234f8cc01a647550ee6b557b35a7e3945381821f4"),
},
{
desc: "Derived from IEEE 2.8.2 61-byte crypt",
key: dehex("4c973dbc7364621674f8b5b89e5c15511fced9216490fb1c1a2caa0ffe0407e54e953fbe7166601476fab7ba"),
nonce: dehex("7ae8e2ca4ec500012e58495c"),
aad: dehex("68f2e77696ce7ae8e2ca4ec588e54d002e58495c"),
plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748490008"),
ciphertext: dehex("b44d072011cd36d272a9b7a98db9aa90cbc5c67b93ddce67c854503214e2e896ec7e9db649ed4bcf6f850aac0223d0cf92c83db80795c3a17ecc1248bb00591712b1ae71e268164196252162810b00"),
}} {
aead, err := newRekeyAEAD(test.key)
if err != nil {
t.Fatal("unexpected failure in newRekeyAEAD: ", err.Error())
}
if got := aead.Seal(nil, test.nonce, test.plaintext, test.aad); !bytes.Equal(got, test.ciphertext) {
t.Errorf("Unexpected ciphertext for test vector '%s':\nciphertext=%s\nwant= %s",
test.desc, hex.EncodeToString(got), hex.EncodeToString(test.ciphertext))
}
if got, err := aead.Open(nil, test.nonce, test.ciphertext, test.aad); err != nil || !bytes.Equal(got, test.plaintext) {
t.Errorf("Unexpected plaintext for test vector '%s':\nplaintext=%s (err=%v)\nwant= %s",
test.desc, hex.EncodeToString(got), err, hex.EncodeToString(test.plaintext))
}
}
}
func dehex(s string) []byte {
if len(s) == 0 {
return make([]byte, 0)
}
b, err := hex.DecodeString(s)
if err != nil {
panic(err)
}
return b
}

View File

@ -1,105 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"crypto/aes"
"crypto/cipher"
core "google.golang.org/grpc/credentials/alts/internal"
)
const (
// Overflow length n in bytes, never encrypt more than 2^(n*8) frames (in
// each direction).
overflowLenAES128GCM = 5
)
// aes128gcm is the struct that holds necessary information for ALTS record.
// The counter value is NOT included in the payload during the encryption and
// decryption operations.
type aes128gcm struct {
// inCounter is used in ALTS record to check that incoming counters are
// as expected, since ALTS record guarantees that messages are unwrapped
// in the same order that the peer wrapped them.
inCounter Counter
outCounter Counter
aead cipher.AEAD
}
// NewAES128GCM creates an instance that uses aes128gcm for ALTS record.
func NewAES128GCM(side core.Side, key []byte) (ALTSRecordCrypto, error) {
c, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
a, err := cipher.NewGCM(c)
if err != nil {
return nil, err
}
return &aes128gcm{
inCounter: NewInCounter(side, overflowLenAES128GCM),
outCounter: NewOutCounter(side, overflowLenAES128GCM),
aead: a,
}, nil
}
// Encrypt is the encryption function. dst can contain bytes at the beginning of
// the ciphertext that will not be encrypted but will be authenticated. If dst
// has enough capacity to hold these bytes, the ciphertext, and the tag, no
// allocation or copy operations will be performed. dst and plaintext do not
// overlap.
func (s *aes128gcm) Encrypt(dst, plaintext []byte) ([]byte, error) {
// If we need to allocate an output buffer, we want to include space for
// GCM tag to avoid forcing ALTS record to reallocate as well.
dlen := len(dst)
dst, out := SliceForAppend(dst, len(plaintext)+GcmTagSize)
seq, err := s.outCounter.Value()
if err != nil {
return nil, err
}
data := out[:len(plaintext)]
copy(data, plaintext) // data may alias plaintext
// Seal appends the ciphertext and the tag to its first argument and
// returns the updated slice. However, SliceForAppend above ensures that
// dst has enough capacity to avoid a reallocation and copy due to the
// append.
dst = s.aead.Seal(dst[:dlen], seq, data, nil)
s.outCounter.Inc()
return dst, nil
}
func (s *aes128gcm) EncryptionOverhead() int {
return GcmTagSize
}
func (s *aes128gcm) Decrypt(dst, ciphertext []byte) ([]byte, error) {
seq, err := s.inCounter.Value()
if err != nil {
return nil, err
}
// If dst is equal to ciphertext[:0], ciphertext storage is reused.
plaintext, err := s.aead.Open(dst, seq, ciphertext, nil)
if err != nil {
return nil, ErrAuth
}
s.inCounter.Inc()
return plaintext, nil
}
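As the Encrypt comment above describes, a caller can hand in the plaintext's own backing array as dst. The following is an illustrative in-package sketch of that buffer-reuse pattern, mirroring the roundtrip test later in this package; the roundtripSketch name, the all-zero key, and the error handling are assumptions for demonstration only.

package conn

import (
	"fmt"

	core "google.golang.org/grpc/credentials/alts/internal"
)

// roundtripSketch shows one encrypt/decrypt roundtrip with buffer reuse.
func roundtripSketch() error {
	key := make([]byte, 16) // AES-128 key; all zeros purely for illustration.
	client, err := NewAES128GCM(core.ClientSide, key)
	if err != nil {
		return err
	}
	server, err := NewAES128GCM(core.ServerSide, key)
	if err != nil {
		return err
	}
	buf := []byte("This is plaintext.")
	// Passing buf[:0] as dst lets Encrypt reuse the plaintext's backing array.
	ct, err := client.Encrypt(buf[:0], buf)
	if err != nil {
		return err
	}
	pt, err := server.Decrypt(nil, ct)
	if err != nil {
		return err
	}
	fmt.Printf("decrypted: %s\n", pt)
	return nil
}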

View File

@ -1,223 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"bytes"
"testing"
core "google.golang.org/grpc/credentials/alts/internal"
)
// cryptoTestVector is a struct for a GCM test vector
type cryptoTestVector struct {
key, counter, plaintext, ciphertext, tag []byte
allocateDst bool
}
// getGCMCryptoPair outputs a client/server pair on aes128gcm.
func getGCMCryptoPair(key []byte, counter []byte, t *testing.T) (ALTSRecordCrypto, ALTSRecordCrypto) {
client, err := NewAES128GCM(core.ClientSide, key)
if err != nil {
t.Fatalf("NewAES128GCM(ClientSide, key) = %v", err)
}
server, err := NewAES128GCM(core.ServerSide, key)
if err != nil {
t.Fatalf("NewAES128GCM(ServerSide, key) = %v", err)
}
// set counter if provided.
if counter != nil {
if CounterSide(counter) == core.ClientSide {
client.(*aes128gcm).outCounter = CounterFromValue(counter, overflowLenAES128GCM)
server.(*aes128gcm).inCounter = CounterFromValue(counter, overflowLenAES128GCM)
} else {
server.(*aes128gcm).outCounter = CounterFromValue(counter, overflowLenAES128GCM)
client.(*aes128gcm).inCounter = CounterFromValue(counter, overflowLenAES128GCM)
}
}
return client, server
}
func testGCMEncryptionDecryption(sender ALTSRecordCrypto, receiver ALTSRecordCrypto, test *cryptoTestVector, withCounter bool, t *testing.T) {
// Ciphertext is: counter + encrypted text + tag.
ciphertext := []byte(nil)
if withCounter {
ciphertext = append(ciphertext, test.counter...)
}
ciphertext = append(ciphertext, test.ciphertext...)
ciphertext = append(ciphertext, test.tag...)
// Decrypt.
if got, err := receiver.Decrypt(nil, ciphertext); err != nil || !bytes.Equal(got, test.plaintext) {
t.Errorf("key=%v\ncounter=%v\ntag=%v\nciphertext=%v\nDecrypt = %v, %v\nwant: %v",
test.key, test.counter, test.tag, test.ciphertext, got, err, test.plaintext)
}
// Encrypt.
var dst []byte
if test.allocateDst {
dst = make([]byte, len(test.plaintext)+sender.EncryptionOverhead())
}
if got, err := sender.Encrypt(dst[:0], test.plaintext); err != nil || !bytes.Equal(got, ciphertext) {
t.Errorf("key=%v\ncounter=%v\nplaintext=%v\nEncrypt = %v, %v\nwant: %v",
test.key, test.counter, test.plaintext, got, err, ciphertext)
}
}
// Test encrypt and decrypt using test vectors for aes128gcm.
func TestAES128GCMEncrypt(t *testing.T) {
for _, test := range []cryptoTestVector{
{
key: dehex("11754cd72aec309bf52f7687212e8957"),
counter: dehex("3c819d9a9bed087615030b65"),
plaintext: nil,
ciphertext: nil,
tag: dehex("250327c674aaf477aef2675748cf6971"),
allocateDst: false,
},
{
key: dehex("ca47248ac0b6f8372a97ac43508308ed"),
counter: dehex("ffd2b598feabc9019262d2be"),
plaintext: nil,
ciphertext: nil,
tag: dehex("60d20404af527d248d893ae495707d1a"),
allocateDst: false,
},
{
key: dehex("7fddb57453c241d03efbed3ac44e371c"),
counter: dehex("ee283a3fc75575e33efd4887"),
plaintext: dehex("d5de42b461646c255c87bd2962d3b9a2"),
ciphertext: dehex("2ccda4a5415cb91e135c2a0f78c9b2fd"),
tag: dehex("b36d1df9b9d5e596f83e8b7f52971cb3"),
allocateDst: false,
},
{
key: dehex("ab72c77b97cb5fe9a382d9fe81ffdbed"),
counter: dehex("54cc7dc2c37ec006bcc6d1da"),
plaintext: dehex("007c5e5b3e59df24a7c355584fc1518d"),
ciphertext: dehex("0e1bde206a07a9c2c1b65300f8c64997"),
tag: dehex("2b4401346697138c7a4891ee59867d0c"),
allocateDst: false,
},
{
key: dehex("11754cd72aec309bf52f7687212e8957"),
counter: dehex("3c819d9a9bed087615030b65"),
plaintext: nil,
ciphertext: nil,
tag: dehex("250327c674aaf477aef2675748cf6971"),
allocateDst: true,
},
{
key: dehex("ca47248ac0b6f8372a97ac43508308ed"),
counter: dehex("ffd2b598feabc9019262d2be"),
plaintext: nil,
ciphertext: nil,
tag: dehex("60d20404af527d248d893ae495707d1a"),
allocateDst: true,
},
{
key: dehex("7fddb57453c241d03efbed3ac44e371c"),
counter: dehex("ee283a3fc75575e33efd4887"),
plaintext: dehex("d5de42b461646c255c87bd2962d3b9a2"),
ciphertext: dehex("2ccda4a5415cb91e135c2a0f78c9b2fd"),
tag: dehex("b36d1df9b9d5e596f83e8b7f52971cb3"),
allocateDst: true,
},
{
key: dehex("ab72c77b97cb5fe9a382d9fe81ffdbed"),
counter: dehex("54cc7dc2c37ec006bcc6d1da"),
plaintext: dehex("007c5e5b3e59df24a7c355584fc1518d"),
ciphertext: dehex("0e1bde206a07a9c2c1b65300f8c64997"),
tag: dehex("2b4401346697138c7a4891ee59867d0c"),
allocateDst: true,
},
} {
// Test encryption and decryption for aes128gcm.
client, server := getGCMCryptoPair(test.key, test.counter, t)
if CounterSide(test.counter) == core.ClientSide {
testGCMEncryptionDecryption(client, server, &test, false, t)
} else {
testGCMEncryptionDecryption(server, client, &test, false, t)
}
}
}
func testGCMEncryptRoundtrip(client ALTSRecordCrypto, server ALTSRecordCrypto, t *testing.T) {
// Encrypt.
const plaintext = "This is plaintext."
var err error
buf := []byte(plaintext)
buf, err = client.Encrypt(buf[:0], buf)
if err != nil {
t.Fatal("Encrypting with client-side context: unexpected error", err, "\n",
"Plaintext:", []byte(plaintext))
}
// Encrypt a second message.
const plaintext2 = "This is a second plaintext."
buf2 := []byte(plaintext2)
buf2, err = client.Encrypt(buf2[:0], buf2)
if err != nil {
t.Fatal("Encrypting with client-side context: unexpected error", err, "\n",
"Plaintext:", []byte(plaintext2))
}
// Decryption fails: cannot decrypt second message before first.
if got, err := server.Decrypt(nil, buf2); err == nil {
t.Error("Decrypting client-side ciphertext with a server-side context unexpectedly succeeded; want unexpected counter error:\n",
" Original plaintext:", []byte(plaintext2), "\n",
" Ciphertext:", buf2, "\n",
" Decrypted plaintext:", got)
}
// Decryption fails: wrong counter space.
if got, err := client.Decrypt(nil, buf); err == nil {
t.Error("Decrypting client-side ciphertext with a client-side context unexpectedly succeeded; want counter space error:\n",
" Original plaintext:", []byte(plaintext), "\n",
" Ciphertext:", buf, "\n",
" Decrypted plaintext:", got)
}
// Decrypt first message.
ciphertext := append([]byte(nil), buf...)
buf, err = server.Decrypt(buf[:0], buf)
if err != nil || string(buf) != plaintext {
t.Fatal("Decrypting client-side ciphertext with a server-side context did not produce original content:\n",
" Original plaintext:", []byte(plaintext), "\n",
" Ciphertext:", ciphertext, "\n",
" Decryption error:", err, "\n",
" Decrypted plaintext:", buf)
}
// Decryption fails: replay attack.
if got, err := server.Decrypt(nil, buf); err == nil {
t.Error("Decrypting client-side ciphertext with a server-side context unexpectedly succeeded; want unexpected counter error:\n",
" Original plaintext:", []byte(plaintext), "\n",
" Ciphertext:", buf, "\n",
" Decrypted plaintext:", got)
}
}
// Test encrypt and decrypt on roundtrip messages for aes128gcm.
func TestAES128GCMEncryptRoundtrip(t *testing.T) {
// Test for aes128gcm.
key := make([]byte, 16)
client, server := getGCMCryptoPair(key, nil, t)
testGCMEncryptRoundtrip(client, server, t)
}

View File

@ -1,116 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"crypto/cipher"
core "google.golang.org/grpc/credentials/alts/internal"
)
const (
// Overflow length n in bytes, never encrypt more than 2^(n*8) frames (in
// each direction).
overflowLenAES128GCMRekey = 8
nonceLen = 12
aeadKeyLen = 16
kdfKeyLen = 32
kdfCounterOffset = 2
kdfCounterLen = 6
sizeUint64 = 8
)
// aes128gcmRekey is the struct that holds necessary information for ALTS record.
// The counter value is NOT included in the payload during the encryption and
// decryption operations.
type aes128gcmRekey struct {
// inCounter is used in ALTS record to check that incoming counters are
// as expected, since ALTS record guarantees that messages are unwrapped
// in the same order that the peer wrapped them.
inCounter Counter
outCounter Counter
inAEAD cipher.AEAD
outAEAD cipher.AEAD
}
// NewAES128GCMRekey creates an instance that uses aes128gcm with rekeying
// for the ALTS record protocol. The key argument should be 44 bytes: the
// first 32 bytes are used as a key for HKDF-expand and the remaining 12
// bytes are used as a random mask for the counter.
func NewAES128GCMRekey(side core.Side, key []byte) (ALTSRecordCrypto, error) {
inCounter := NewInCounter(side, overflowLenAES128GCMRekey)
outCounter := NewOutCounter(side, overflowLenAES128GCMRekey)
inAEAD, err := newRekeyAEAD(key)
if err != nil {
return nil, err
}
outAEAD, err := newRekeyAEAD(key)
if err != nil {
return nil, err
}
return &aes128gcmRekey{
inCounter,
outCounter,
inAEAD,
outAEAD,
}, nil
}
// Encrypt is the encryption function. dst can contain bytes at the beginning of
// the ciphertext that will not be encrypted but will be authenticated. If dst
// has enough capacity to hold these bytes, the ciphertext and the tag, no
// allocation and copy operations will be performed. dst and plaintext do not
// overlap.
func (s *aes128gcmRekey) Encrypt(dst, plaintext []byte) ([]byte, error) {
// If we need to allocate an output buffer, we want to include space for
// GCM tag to avoid forcing ALTS record to reallocate as well.
dlen := len(dst)
dst, out := SliceForAppend(dst, len(plaintext)+GcmTagSize)
seq, err := s.outCounter.Value()
if err != nil {
return nil, err
}
data := out[:len(plaintext)]
copy(data, plaintext) // data may alias plaintext
// Seal appends the ciphertext and the tag to its first argument and
// returns the updated slice. However, SliceForAppend above ensures that
// dst has enough capacity to avoid a reallocation and copy due to the
// append.
dst = s.outAEAD.Seal(dst[:dlen], seq, data, nil)
s.outCounter.Inc()
return dst, nil
}
func (s *aes128gcmRekey) EncryptionOverhead() int {
return GcmTagSize
}
func (s *aes128gcmRekey) Decrypt(dst, ciphertext []byte) ([]byte, error) {
seq, err := s.inCounter.Value()
if err != nil {
return nil, err
}
plaintext, err := s.inAEAD.Open(dst, seq, ciphertext, nil)
if err != nil {
return nil, ErrAuth
}
s.inCounter.Inc()
return plaintext, nil
}
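
A small sketch of constructing a matching rekey pair (the helper name and the all-zero key are illustrative; the 44-byte length comes from the constructor's comment above):

package conn

import core "google.golang.org/grpc/credentials/alts/internal"

// newRekeyPairSketch builds a matching client/server pair; the key must be
// 44 bytes (32 bytes of KDF input plus a 12-byte counter mask).
func newRekeyPairSketch() (client, server ALTSRecordCrypto, err error) {
    key := make([]byte, 44) // all-zero key, illustration only
    if client, err = NewAES128GCMRekey(core.ClientSide, key); err != nil {
        return nil, nil, err
    }
    if server, err = NewAES128GCMRekey(core.ServerSide, key); err != nil {
        return nil, nil, err
    }
    return client, server, nil
}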

View File

@ -1,117 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"testing"
core "google.golang.org/grpc/credentials/alts/internal"
)
// rekeyTestVector is a struct for a rekey test vector
type rekeyTestVector struct {
key, nonce, plaintext, ciphertext []byte
}
// getRekeyCryptoPair outputs a client/server pair on aes128gcmRekey.
func getRekeyCryptoPair(key []byte, counter []byte, t *testing.T) (ALTSRecordCrypto, ALTSRecordCrypto) {
client, err := NewAES128GCMRekey(core.ClientSide, key)
if err != nil {
t.Fatalf("NewAES128GCMRekey(ClientSide, key) = %v", err)
}
server, err := NewAES128GCMRekey(core.ServerSide, key)
if err != nil {
t.Fatalf("NewAES128GCMRekey(ServerSide, key) = %v", err)
}
// set counter if provided.
if counter != nil {
if CounterSide(counter) == core.ClientSide {
client.(*aes128gcmRekey).outCounter = CounterFromValue(counter, overflowLenAES128GCMRekey)
server.(*aes128gcmRekey).inCounter = CounterFromValue(counter, overflowLenAES128GCMRekey)
} else {
server.(*aes128gcmRekey).outCounter = CounterFromValue(counter, overflowLenAES128GCMRekey)
client.(*aes128gcmRekey).inCounter = CounterFromValue(counter, overflowLenAES128GCMRekey)
}
}
return client, server
}
func testRekeyEncryptRoundtrip(client ALTSRecordCrypto, server ALTSRecordCrypto, t *testing.T) {
// Encrypt.
const plaintext = "This is plaintext."
var err error
buf := []byte(plaintext)
buf, err = client.Encrypt(buf[:0], buf)
if err != nil {
t.Fatal("Encrypting with client-side context: unexpected error", err, "\n",
"Plaintext:", []byte(plaintext))
}
// Encrypt a second message.
const plaintext2 = "This is a second plaintext."
buf2 := []byte(plaintext2)
buf2, err = client.Encrypt(buf2[:0], buf2)
if err != nil {
t.Fatal("Encrypting with client-side context: unexpected error", err, "\n",
"Plaintext:", []byte(plaintext2))
}
// Decryption fails: cannot decrypt second message before first.
if got, err := server.Decrypt(nil, buf2); err == nil {
t.Error("Decrypting client-side ciphertext with a server-side context unexpectedly succeeded; want unexpected counter error:\n",
" Original plaintext:", []byte(plaintext2), "\n",
" Ciphertext:", buf2, "\n",
" Decrypted plaintext:", got)
}
// Decryption fails: wrong counter space.
if got, err := client.Decrypt(nil, buf); err == nil {
t.Error("Decrypting client-side ciphertext with a client-side context unexpectedly succeeded; want counter space error:\n",
" Original plaintext:", []byte(plaintext), "\n",
" Ciphertext:", buf, "\n",
" Decrypted plaintext:", got)
}
// Decrypt first message.
ciphertext := append([]byte(nil), buf...)
buf, err = server.Decrypt(buf[:0], buf)
if err != nil || string(buf) != plaintext {
t.Fatal("Decrypting client-side ciphertext with a server-side context did not produce original content:\n",
" Original plaintext:", []byte(plaintext), "\n",
" Ciphertext:", ciphertext, "\n",
" Decryption error:", err, "\n",
" Decrypted plaintext:", buf)
}
// Decryption fails: replay attack.
if got, err := server.Decrypt(nil, buf); err == nil {
t.Error("Decrypting client-side ciphertext with a server-side context unexpectedly succeeded; want unexpected counter error:\n",
" Original plaintext:", []byte(plaintext), "\n",
" Ciphertext:", buf, "\n",
" Decrypted plaintext:", got)
}
}
// Test encrypt and decrypt on roundtrip messages for aes128gcmRekey.
func TestAES128GCMRekeyEncryptRoundtrip(t *testing.T) {
// Test for aes128gcmRekey.
key := make([]byte, 44)
client, server := getRekeyCryptoPair(key, nil, t)
testRekeyEncryptRoundtrip(client, server, t)
}

View File

@ -1,70 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"encoding/binary"
"errors"
"fmt"
)
const (
// GcmTagSize is the GCM tag size, which is the difference in length between
// plaintext and ciphertext. From crypto/cipher/gcm.go in the Go crypto
// library.
GcmTagSize = 16
)
// ErrAuth occurs on authentication failure.
var ErrAuth = errors.New("message authentication failed")
// SliceForAppend takes a slice and a requested number of bytes. It returns a
// slice with the contents of the given slice followed by that many bytes and a
// second slice that aliases into it and contains only the extra bytes. If the
// original slice has sufficient capacity then no allocation is performed.
func SliceForAppend(in []byte, n int) (head, tail []byte) {
if total := len(in) + n; cap(in) >= total {
head = in[:total]
} else {
head = make([]byte, total)
copy(head, in)
}
tail = head[len(in):]
return head, tail
}
// ParseFramedMsg parses the provided buffer and returns a frame of the format
// msgLength+msg and any remaining bytes in that buffer.
func ParseFramedMsg(b []byte, maxLen uint32) ([]byte, []byte, error) {
// If the size field is not complete, return the provided buffer as
// remaining buffer.
if len(b) < MsgLenFieldSize {
return nil, b, nil
}
msgLenField := b[:MsgLenFieldSize]
length := binary.LittleEndian.Uint32(msgLenField)
if length > maxLen {
return nil, nil, fmt.Errorf("received the frame length %d larger than the limit %d", length, maxLen)
}
if len(b) < int(length)+4 { // account for the first 4 msg length bytes.
// Frame is not complete yet.
return nil, b, nil
}
return b[:MsgLenFieldSize+length], b[MsgLenFieldSize+length:], nil
}
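
A sketch of the framing helper in isolation (the Example name and payload are illustrative): build one length-prefixed frame followed by the first byte of the next, incomplete frame, and let ParseFramedMsg split them:

package conn

import (
    "encoding/binary"
    "fmt"
)

// ExampleParseFramedMsg frames a 3-byte payload and shows that the trailing
// incomplete byte is returned as the remaining buffer.
func ExampleParseFramedMsg() {
    payload := []byte("abc")
    buf := make([]byte, MsgLenFieldSize)
    binary.LittleEndian.PutUint32(buf, uint32(len(payload)))
    buf = append(buf, payload...)
    buf = append(buf, 0x03) // first byte of a following, incomplete frame
    frame, remaining, err := ParseFramedMsg(buf, altsRecordLengthLimit)
    if err != nil {
        panic(err)
    }
    fmt.Println(string(frame[MsgLenFieldSize:]), len(remaining))
    // Output: abc 1
}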

View File

@ -1,62 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"errors"
)
const counterLen = 12
var (
errInvalidCounter = errors.New("invalid counter")
)
// Counter is a 96-bit, little-endian counter.
type Counter struct {
value [counterLen]byte
invalid bool
overflowLen int
}
// Value returns the current value of the counter as a byte slice.
func (c *Counter) Value() ([]byte, error) {
if c.invalid {
return nil, errInvalidCounter
}
return c.value[:], nil
}
// Inc increments the counter and checks for overflow.
func (c *Counter) Inc() {
// If the counter is already invalid, there is no need to increase it.
if c.invalid {
return
}
i := 0
for ; i < c.overflowLen; i++ {
c.value[i]++
if c.value[i] != 0 {
break
}
}
if i == c.overflowLen {
c.invalid = true
}
}
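
A sketch of the overflow guard described above (the overflow length of 1 and the Example name are chosen only to keep the loop short; they are not values used by the record protocol): with overflowLen 1 the counter becomes invalid after 256 increments and Value() starts failing:

package conn

import "fmt"

// ExampleCounterOverflow increments a 1-byte-overflow counter past its limit
// and observes the invalid-counter error.
func ExampleCounterOverflow() {
    c := CounterFromValue(make([]byte, counterLen), 1)
    for i := 0; i < 256; i++ {
        c.Inc()
    }
    if _, err := c.Value(); err != nil {
        fmt.Println(err)
    }
    // Output: invalid counter
}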

View File

@ -1,141 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"bytes"
"testing"
core "google.golang.org/grpc/credentials/alts/internal"
)
const (
testOverflowLen = 5
)
func TestCounterSides(t *testing.T) {
for _, side := range []core.Side{core.ClientSide, core.ServerSide} {
outCounter := NewOutCounter(side, testOverflowLen)
inCounter := NewInCounter(side, testOverflowLen)
for i := 0; i < 1024; i++ {
value, _ := outCounter.Value()
if g, w := CounterSide(value), side; g != w {
t.Errorf("after %d iterations, CounterSide(outCounter.Value()) = %v, want %v", i, g, w)
break
}
value, _ = inCounter.Value()
if g, w := CounterSide(value), side; g == w {
t.Errorf("after %d iterations, CounterSide(inCounter.Value()) = %v, want %v", i, g, w)
break
}
outCounter.Inc()
inCounter.Inc()
}
}
}
func TestCounterInc(t *testing.T) {
for _, test := range []struct {
counter []byte
want []byte
}{
{
counter: []byte{0x00, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
want: []byte{0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
},
{
counter: []byte{0x00, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x80},
want: []byte{0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x80},
},
{
counter: []byte{0xff, 0x00, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
want: []byte{0x00, 0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
},
{
counter: []byte{0x42, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
want: []byte{0x43, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
},
{
counter: []byte{0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
want: []byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
},
{
counter: []byte{0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80},
want: []byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80},
},
} {
c := CounterFromValue(test.counter, overflowLenAES128GCM)
c.Inc()
value, _ := c.Value()
if g, w := value, test.want; !bytes.Equal(g, w) || c.invalid {
t.Errorf("counter(%v).Inc() =\n%v, want\n%v", test.counter, g, w)
}
}
}
func TestRolloverCounter(t *testing.T) {
for _, test := range []struct {
desc string
value []byte
overflowLen int
}{
{
desc: "testing overflow without rekeying 1",
value: []byte{0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80},
overflowLen: 5,
},
{
desc: "testing overflow without rekeying 2",
value: []byte{0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
overflowLen: 5,
},
{
desc: "testing overflow for rekeying mode 1",
value: []byte{0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x80},
overflowLen: 8,
},
{
desc: "testing overflow for rekeying mode 2",
value: []byte{0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00},
overflowLen: 8,
},
} {
c := CounterFromValue(test.value, test.overflowLen)
// First Inc() + Value() should work.
c.Inc()
_, err := c.Value()
if err != nil {
t.Errorf("%v: first Inc() + Value() unexpectedly failed: %v, want <nil> error", test.desc, err)
}
// Second Inc() + Value() should fail.
c.Inc()
_, err = c.Value()
if err != errInvalidCounter {
t.Errorf("%v: second Inc() + Value() unexpectedly succeeded: want %v", test.desc, errInvalidCounter)
}
// Third Inc() + Value() should also fail because the counter is
// already in an invalid state.
c.Inc()
_, err = c.Value()
if err != errInvalidCounter {
t.Errorf("%v: Third Inc() + Value() unexpectedly succeeded: want %v", test.desc, errInvalidCounter)
}
}
}

View File

@ -1,271 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package conn contains an implementation of a secure channel created by gRPC
// handshakers.
package conn
import (
"encoding/binary"
"fmt"
"math"
"net"
core "google.golang.org/grpc/credentials/alts/internal"
)
// ALTSRecordCrypto is the interface for gRPC ALTS record protocol.
type ALTSRecordCrypto interface {
// Encrypt encrypts the plaintext and computes the tag (if any) of dst
// and plaintext, dst and plaintext do not overlap.
Encrypt(dst, plaintext []byte) ([]byte, error)
// EncryptionOverhead returns the tag size (if any) in bytes.
EncryptionOverhead() int
// Decrypt decrypts ciphertext and verifies the tag (if any). dst and
// ciphertext may alias exactly or not at all. To reuse ciphertext's
// storage for the decrypted output, use ciphertext[:0] as dst.
Decrypt(dst, ciphertext []byte) ([]byte, error)
}
// ALTSRecordFunc is a function type for factory functions that create
// ALTSRecordCrypto instances.
type ALTSRecordFunc func(s core.Side, keyData []byte) (ALTSRecordCrypto, error)
const (
// MsgLenFieldSize is the byte size of the frame length field of a
// framed message.
MsgLenFieldSize = 4
// The byte size of the message type field of a framed message.
msgTypeFieldSize = 4
// The byte size limit for an ALTS record message.
altsRecordLengthLimit = 1024 * 1024 // 1 MiB
// The default byte size of an ALTS record message.
altsRecordDefaultLength = 4 * 1024 // 4KiB
// Message type value included in ALTS record framing.
altsRecordMsgType = uint32(0x06)
// The initial write buffer size.
altsWriteBufferInitialSize = 32 * 1024 // 32KiB
// The maximum write buffer size. This *must* be a multiple of
// altsRecordDefaultLength.
altsWriteBufferMaxSize = 512 * 1024 // 512KiB
)
var (
protocols = make(map[string]ALTSRecordFunc)
)
// RegisterProtocol registers an ALTS record encryption protocol.
func RegisterProtocol(protocol string, f ALTSRecordFunc) error {
if _, ok := protocols[protocol]; ok {
return fmt.Errorf("protocol %v is already registered", protocol)
}
protocols[protocol] = f
return nil
}
// conn represents a secured connection. It implements the net.Conn interface.
type conn struct {
net.Conn
crypto ALTSRecordCrypto
// buf holds data that has been read from the connection and decrypted,
// but has not yet been returned by Read.
buf []byte
payloadLengthLimit int
// protected holds data read from the network that has not yet been
// decrypted. This data might not form a complete frame.
protected []byte
// writeBuf is a buffer used to contain encrypted frames before being
// written to the network.
writeBuf []byte
// nextFrame stores the next frame (in protected buffer) info.
nextFrame []byte
// overhead is the calculated overhead of each frame.
overhead int
}
// NewConn creates a new secure channel instance given the other party role and
// handshaking result.
func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, protected []byte) (net.Conn, error) {
newCrypto := protocols[recordProtocol]
if newCrypto == nil {
return nil, fmt.Errorf("negotiated unknown next_protocol %q", recordProtocol)
}
crypto, err := newCrypto(side, key)
if err != nil {
return nil, fmt.Errorf("protocol %q: %v", recordProtocol, err)
}
overhead := MsgLenFieldSize + msgTypeFieldSize + crypto.EncryptionOverhead()
payloadLengthLimit := altsRecordDefaultLength - overhead
if protected == nil {
// We pre-allocate protected to be of size
// 2*altsRecordDefaultLength-1 during initialization. We only
// read from the network into protected when protected does not
// contain a complete frame, which is at most
// altsRecordDefaultLength-1 (bytes). And we read at most
// altsRecordDefaultLength (bytes) data into protected at one
// time. Therefore, 2*altsRecordDefaultLength-1 is large enough
// to buffer data read from the network.
protected = make([]byte, 0, 2*altsRecordDefaultLength-1)
}
altsConn := &conn{
Conn: c,
crypto: crypto,
payloadLengthLimit: payloadLengthLimit,
protected: protected,
writeBuf: make([]byte, altsWriteBufferInitialSize),
nextFrame: protected,
overhead: overhead,
}
return altsConn, nil
}
// Read reads and decrypts a frame from the underlying connection, and copies the
// decrypted payload into b. If the size of the payload is greater than len(b),
// Read retains the remaining bytes in an internal buffer, and subsequent calls
// to Read will read from this buffer until it is exhausted.
func (p *conn) Read(b []byte) (n int, err error) {
if len(p.buf) == 0 {
var framedMsg []byte
framedMsg, p.nextFrame, err = ParseFramedMsg(p.nextFrame, altsRecordLengthLimit)
if err != nil {
return n, err
}
// Check whether the next frame to be decrypted has been
// completely received yet.
if len(framedMsg) == 0 {
copy(p.protected, p.nextFrame)
p.protected = p.protected[:len(p.nextFrame)]
// Always copy next incomplete frame to the beginning of
// the protected buffer and reset nextFrame to it.
p.nextFrame = p.protected
}
// Check whether a complete frame has been received yet.
for len(framedMsg) == 0 {
if len(p.protected) == cap(p.protected) {
tmp := make([]byte, len(p.protected), cap(p.protected)+altsRecordDefaultLength)
copy(tmp, p.protected)
p.protected = tmp
}
n, err = p.Conn.Read(p.protected[len(p.protected):min(cap(p.protected), len(p.protected)+altsRecordDefaultLength)])
if err != nil {
return 0, err
}
p.protected = p.protected[:len(p.protected)+n]
framedMsg, p.nextFrame, err = ParseFramedMsg(p.protected, altsRecordLengthLimit)
if err != nil {
return 0, err
}
}
// Now we have a complete frame; decrypt it.
msg := framedMsg[MsgLenFieldSize:]
msgType := binary.LittleEndian.Uint32(msg[:msgTypeFieldSize])
if msgType&0xff != altsRecordMsgType {
return 0, fmt.Errorf("received frame with incorrect message type %v, expected lower byte %v",
msgType, altsRecordMsgType)
}
ciphertext := msg[msgTypeFieldSize:]
// Decrypt requires that if the dst and ciphertext alias, they
// must alias exactly. Code here used to use msg[:0], but msg
// starts MsgLenFieldSize+msgTypeFieldSize bytes earlier than
// ciphertext, so they alias inexactly. Using ciphertext[:0]
// arranges the appropriate aliasing without needing to copy
// ciphertext or use a separate destination buffer. For more info
// check: https://golang.org/pkg/crypto/cipher/#AEAD.
p.buf, err = p.crypto.Decrypt(ciphertext[:0], ciphertext)
if err != nil {
return 0, err
}
}
n = copy(b, p.buf)
p.buf = p.buf[n:]
return n, nil
}
// Write encrypts, frames, and writes bytes from b to the underlying connection.
func (p *conn) Write(b []byte) (n int, err error) {
n = len(b)
// Calculate the output buffer size with framing and encryption overhead.
numOfFrames := int(math.Ceil(float64(len(b)) / float64(p.payloadLengthLimit)))
size := len(b) + numOfFrames*p.overhead
// If writeBuf is too small, increase its size up to the maximum size.
partialBSize := len(b)
if size > altsWriteBufferMaxSize {
size = altsWriteBufferMaxSize
const numOfFramesInMaxWriteBuf = altsWriteBufferMaxSize / altsRecordDefaultLength
partialBSize = numOfFramesInMaxWriteBuf * p.payloadLengthLimit
}
if len(p.writeBuf) < size {
p.writeBuf = make([]byte, size)
}
for partialBStart := 0; partialBStart < len(b); partialBStart += partialBSize {
partialBEnd := partialBStart + partialBSize
if partialBEnd > len(b) {
partialBEnd = len(b)
}
partialB := b[partialBStart:partialBEnd]
writeBufIndex := 0
for len(partialB) > 0 {
payloadLen := len(partialB)
if payloadLen > p.payloadLengthLimit {
payloadLen = p.payloadLengthLimit
}
buf := partialB[:payloadLen]
partialB = partialB[payloadLen:]
// Write buffer contains: length, type, payload, and tag
// if any.
// 1. Fill in type field.
msg := p.writeBuf[writeBufIndex+MsgLenFieldSize:]
binary.LittleEndian.PutUint32(msg, altsRecordMsgType)
// 2. Encrypt the payload and create a tag if any.
msg, err = p.crypto.Encrypt(msg[:msgTypeFieldSize], buf)
if err != nil {
return n, err
}
// 3. Fill in the size field.
binary.LittleEndian.PutUint32(p.writeBuf[writeBufIndex:], uint32(len(msg)))
// 4. Increase writeBufIndex.
writeBufIndex += len(buf) + p.overhead
}
nn, err := p.Conn.Write(p.writeBuf[:writeBufIndex])
if err != nil {
// We need to calculate the actual data size that was
// written. This means we need to remove header,
// encryption overheads, and any partially-written
// frame data.
numOfWrittenFrames := int(math.Floor(float64(nn) / float64(altsRecordDefaultLength)))
return partialBStart + numOfWrittenFrames*p.payloadLengthLimit, err
}
}
return n, nil
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
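
A sketch wiring the record protocol registry and NewConn over an in-memory pipe (the protocol name "example_gcm_aes128", the all-zero key, and the Example name are illustrative and are not names used elsewhere in this package):

package conn

import (
    "fmt"
    "net"

    core "google.golang.org/grpc/credentials/alts/internal"
)

// ExampleNewConn registers an aes128gcm-backed record protocol, wraps both
// ends of a net.Pipe, and sends one framed, encrypted message across.
func ExampleNewConn() {
    if err := RegisterProtocol("example_gcm_aes128", NewAES128GCM); err != nil {
        panic(err)
    }
    key := make([]byte, 16) // all-zero key, illustration only
    cRaw, sRaw := net.Pipe()
    client, err := NewConn(cRaw, core.ClientSide, "example_gcm_aes128", key, nil)
    if err != nil {
        panic(err)
    }
    server, err := NewConn(sRaw, core.ServerSide, "example_gcm_aes128", key, nil)
    if err != nil {
        panic(err)
    }
    go client.Write([]byte("ping")) // frames and encrypts, then writes to the pipe
    buf := make([]byte, 4)
    n, err := server.Read(buf) // reads the frame, decrypts, copies the payload
    if err != nil {
        panic(err)
    }
    fmt.Println(string(buf[:n]))
    // Output: ping
}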

View File

@ -1,274 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"math"
"net"
"reflect"
"testing"
core "google.golang.org/grpc/credentials/alts/internal"
)
var (
nextProtocols = []string{"ALTSRP_GCM_AES128"}
altsRecordFuncs = map[string]ALTSRecordFunc{
// ALTS handshaker protocols.
"ALTSRP_GCM_AES128": func(s core.Side, keyData []byte) (ALTSRecordCrypto, error) {
return NewAES128GCM(s, keyData)
},
}
)
func init() {
for protocol, f := range altsRecordFuncs {
if err := RegisterProtocol(protocol, f); err != nil {
panic(err)
}
}
}
// testConn mimics a net.Conn to the peer.
type testConn struct {
net.Conn
in *bytes.Buffer
out *bytes.Buffer
}
func (c *testConn) Read(b []byte) (n int, err error) {
return c.in.Read(b)
}
func (c *testConn) Write(b []byte) (n int, err error) {
return c.out.Write(b)
}
func (c *testConn) Close() error {
return nil
}
func newTestALTSRecordConn(in, out *bytes.Buffer, side core.Side, np string) *conn {
key := []byte{
// 16 arbitrary bytes.
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49}
tc := testConn{
in: in,
out: out,
}
c, err := NewConn(&tc, side, np, key, nil)
if err != nil {
panic(fmt.Sprintf("Unexpected error creating test ALTS record connection: %v", err))
}
return c.(*conn)
}
func newConnPair(np string) (client, server *conn) {
clientBuf := new(bytes.Buffer)
serverBuf := new(bytes.Buffer)
clientConn := newTestALTSRecordConn(clientBuf, serverBuf, core.ClientSide, np)
serverConn := newTestALTSRecordConn(serverBuf, clientBuf, core.ServerSide, np)
return clientConn, serverConn
}
func testPingPong(t *testing.T, np string) {
clientConn, serverConn := newConnPair(np)
clientMsg := []byte("Client Message")
if n, err := clientConn.Write(clientMsg); n != len(clientMsg) || err != nil {
t.Fatalf("Client Write() = %v, %v; want %v, <nil>", n, err, len(clientMsg))
}
rcvClientMsg := make([]byte, len(clientMsg))
if n, err := serverConn.Read(rcvClientMsg); n != len(rcvClientMsg) || err != nil {
t.Fatalf("Server Read() = %v, %v; want %v, <nil>", n, err, len(rcvClientMsg))
}
if !reflect.DeepEqual(clientMsg, rcvClientMsg) {
t.Fatalf("Client Write()/Server Read() = %v, want %v", rcvClientMsg, clientMsg)
}
serverMsg := []byte("Server Message")
if n, err := serverConn.Write(serverMsg); n != len(serverMsg) || err != nil {
t.Fatalf("Server Write() = %v, %v; want %v, <nil>", n, err, len(serverMsg))
}
rcvServerMsg := make([]byte, len(serverMsg))
if n, err := clientConn.Read(rcvServerMsg); n != len(rcvServerMsg) || err != nil {
t.Fatalf("Client Read() = %v, %v; want %v, <nil>", n, err, len(rcvServerMsg))
}
if !reflect.DeepEqual(serverMsg, rcvServerMsg) {
t.Fatalf("Server Write()/Client Read() = %v, want %v", rcvServerMsg, serverMsg)
}
}
func TestPingPong(t *testing.T) {
for _, np := range nextProtocols {
testPingPong(t, np)
}
}
func testSmallReadBuffer(t *testing.T, np string) {
clientConn, serverConn := newConnPair(np)
msg := []byte("Very Important Message")
if n, err := clientConn.Write(msg); err != nil {
t.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
}
rcvMsg := make([]byte, len(msg))
n := 2 // Arbitrary index to break rcvMsg in two.
rcvMsg1 := rcvMsg[:n]
rcvMsg2 := rcvMsg[n:]
if n, err := serverConn.Read(rcvMsg1); n != len(rcvMsg1) || err != nil {
t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg1))
}
if n, err := serverConn.Read(rcvMsg2); n != len(rcvMsg2) || err != nil {
t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg2))
}
if !reflect.DeepEqual(msg, rcvMsg) {
t.Fatalf("Write()/Read() = %v, want %v", rcvMsg, msg)
}
}
func TestSmallReadBuffer(t *testing.T) {
for _, np := range nextProtocols {
testSmallReadBuffer(t, np)
}
}
func testLargeMsg(t *testing.T, np string) {
clientConn, serverConn := newConnPair(np)
// msgLen is such that the length in the framing is larger than the
// default size of one frame.
msgLen := altsRecordDefaultLength - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1
msg := make([]byte, msgLen)
if n, err := clientConn.Write(msg); n != len(msg) || err != nil {
t.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
}
rcvMsg := make([]byte, len(msg))
if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil {
t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg))
}
if !reflect.DeepEqual(msg, rcvMsg) {
t.Fatalf("Write()/Server Read() = %v, want %v", rcvMsg, msg)
}
}
func TestLargeMsg(t *testing.T) {
for _, np := range nextProtocols {
testLargeMsg(t, np)
}
}
func testIncorrectMsgType(t *testing.T, np string) {
// framedMsg is an empty ciphertext with correct framing but wrong
// message type.
framedMsg := make([]byte, MsgLenFieldSize+msgTypeFieldSize)
binary.LittleEndian.PutUint32(framedMsg[:MsgLenFieldSize], msgTypeFieldSize)
wrongMsgType := uint32(0x22)
binary.LittleEndian.PutUint32(framedMsg[MsgLenFieldSize:], wrongMsgType)
in := bytes.NewBuffer(framedMsg)
c := newTestALTSRecordConn(in, nil, core.ClientSide, np)
b := make([]byte, 1)
if n, err := c.Read(b); n != 0 || err == nil {
t.Fatalf("Read() = <nil>, want %v", fmt.Errorf("received frame with incorrect message type %v", wrongMsgType))
}
}
func TestIncorrectMsgType(t *testing.T) {
for _, np := range nextProtocols {
testIncorrectMsgType(t, np)
}
}
func testFrameTooLarge(t *testing.T, np string) {
buf := new(bytes.Buffer)
clientConn := newTestALTSRecordConn(nil, buf, core.ClientSide, np)
serverConn := newTestALTSRecordConn(buf, nil, core.ServerSide, np)
// payloadLen is such that the length in the framing is larger than
// allowed in one frame.
payloadLen := altsRecordLengthLimit - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1
payload := make([]byte, payloadLen)
c, err := clientConn.crypto.Encrypt(nil, payload)
if err != nil {
t.Fatalf("Error encrypting message: %v", err)
}
msgLen := msgTypeFieldSize + len(c)
framedMsg := make([]byte, MsgLenFieldSize+msgLen)
binary.LittleEndian.PutUint32(framedMsg[:MsgLenFieldSize], uint32(msgTypeFieldSize+len(c)))
msg := framedMsg[MsgLenFieldSize:]
binary.LittleEndian.PutUint32(msg[:msgTypeFieldSize], altsRecordMsgType)
copy(msg[msgTypeFieldSize:], c)
if _, err = buf.Write(framedMsg); err != nil {
t.Fatalf("Unexpected error writing to buffer: %v", err)
}
b := make([]byte, 1)
if n, err := serverConn.Read(b); n != 0 || err == nil {
t.Fatalf("Read() = <nil>, want %v", fmt.Errorf("received the frame length %d larger than the limit %d", altsRecordLengthLimit+1, altsRecordLengthLimit))
}
}
func TestFrameTooLarge(t *testing.T) {
for _, np := range nextProtocols {
testFrameTooLarge(t, np)
}
}
func testWriteLargeData(t *testing.T, np string) {
// Test sending and receiving messages larger than the maximum write
// buffer size.
clientConn, serverConn := newConnPair(np)
// Message size is intentionally chosen not to be a multiple of
// payloadLengthLimit.
msgSize := altsWriteBufferMaxSize + (100 * 1024)
clientMsg := make([]byte, msgSize)
for i := 0; i < msgSize; i++ {
clientMsg[i] = 0xAA
}
if n, err := clientConn.Write(clientMsg); n != len(clientMsg) || err != nil {
t.Fatalf("Client Write() = %v, %v; want %v, <nil>", n, err, len(clientMsg))
}
// We need to keep reading until the entire message is received. The
// reason we set all bytes of the message to a value other than zero is
// to avoid ambiguity between the zero-initialized rcvClientMsg buffer and
// the actual received data.
rcvClientMsg := make([]byte, 0, msgSize)
numberOfExpectedFrames := int(math.Ceil(float64(msgSize) / float64(serverConn.payloadLengthLimit)))
for i := 0; i < numberOfExpectedFrames; i++ {
expectedRcvSize := serverConn.payloadLengthLimit
if i == numberOfExpectedFrames-1 {
// Last frame might be smaller.
expectedRcvSize = msgSize % serverConn.payloadLengthLimit
}
tmpBuf := make([]byte, expectedRcvSize)
if n, err := serverConn.Read(tmpBuf); n != len(tmpBuf) || err != nil {
t.Fatalf("Server Read() = %v, %v; want %v, <nil>", n, err, len(tmpBuf))
}
rcvClientMsg = append(rcvClientMsg, tmpBuf...)
}
if !reflect.DeepEqual(clientMsg, rcvClientMsg) {
t.Fatalf("Client Write()/Server Read() = %v, want %v", rcvClientMsg, clientMsg)
}
}
func TestWriteLargeData(t *testing.T) {
for _, np := range nextProtocols {
testWriteLargeData(t, np)
}
}

View File

@ -1,63 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import core "google.golang.org/grpc/credentials/alts/internal"
// NewOutCounter returns an outgoing counter initialized to the starting sequence
// number for the client/server side of a connection.
func NewOutCounter(s core.Side, overflowLen int) (c Counter) {
c.overflowLen = overflowLen
if s == core.ServerSide {
// Server counters in ALTS record have the little-endian high bit
// set.
c.value[counterLen-1] = 0x80
}
return
}
// NewInCounter returns an incoming counter initialized to the starting sequence
// number for the client/server side of a connection. This is used in ALTS record
// to check that incoming counters are as expected, since ALTS record guarantees
// that messages are unwrapped in the same order that the peer wrapped them.
func NewInCounter(s core.Side, overflowLen int) (c Counter) {
c.overflowLen = overflowLen
if s == core.ClientSide {
// The client's incoming counter tracks server-side counters, which have
// the little-endian high bit set.
c.value[counterLen-1] = 0x80
}
return
}
// CounterFromValue creates a new counter given an initial value.
func CounterFromValue(value []byte, overflowLen int) (c Counter) {
c.overflowLen = overflowLen
copy(c.value[:], value)
return
}
// CounterSide returns the connection side (client/server) a sequence counter is
// associated with.
func CounterSide(c []byte) core.Side {
if c[counterLen-1]&0x80 == 0x80 {
return core.ServerSide
}
return core.ClientSide
}
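
A sketch of the counter/side convention described above (the Example name is illustrative): only the server's outgoing counter carries the little-endian high bit, which is exactly what CounterSide inspects:

package conn

import (
    "fmt"

    core "google.golang.org/grpc/credentials/alts/internal"
)

// ExampleCounterSide derives the side of a connection from a counter value.
func ExampleCounterSide() {
    serverOut := NewOutCounter(core.ServerSide, overflowLenAES128GCM)
    clientOut := NewOutCounter(core.ClientSide, overflowLenAES128GCM)
    sv, _ := serverOut.Value()
    cv, _ := clientOut.Value()
    fmt.Println(CounterSide(sv) == core.ServerSide, CounterSide(cv) == core.ClientSide)
    // Output: true true
}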

View File

@ -1,365 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package handshaker provides ALTS handshaking functionality for GCP.
package handshaker
import (
"context"
"errors"
"fmt"
"io"
"net"
"sync"
grpc "google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
core "google.golang.org/grpc/credentials/alts/internal"
"google.golang.org/grpc/credentials/alts/internal/authinfo"
"google.golang.org/grpc/credentials/alts/internal/conn"
altsgrpc "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
)
const (
// The maximum byte size of receive frames.
frameLimit = 64 * 1024 // 64 KB
rekeyRecordProtocolName = "ALTSRP_GCM_AES128_REKEY"
// maxPendingHandshakes represents the maximum number of concurrent
// handshakes.
maxPendingHandshakes = 100
)
var (
hsProtocol = altspb.HandshakeProtocol_ALTS
appProtocols = []string{"grpc"}
recordProtocols = []string{rekeyRecordProtocolName}
keyLength = map[string]int{
rekeyRecordProtocolName: 44,
}
altsRecordFuncs = map[string]conn.ALTSRecordFunc{
// ALTS handshaker protocols.
rekeyRecordProtocolName: func(s core.Side, keyData []byte) (conn.ALTSRecordCrypto, error) {
return conn.NewAES128GCMRekey(s, keyData)
},
}
// Control the number of concurrently created (but not closed) handshakers.
mu sync.Mutex
concurrentHandshakes = int64(0)
// errDropped occurs when maxPendingHandshakes is reached.
errDropped = errors.New("maximum number of concurrent ALTS handshakes is reached")
)
func init() {
for protocol, f := range altsRecordFuncs {
if err := conn.RegisterProtocol(protocol, f); err != nil {
panic(err)
}
}
}
func acquire(n int64) bool {
mu.Lock()
success := maxPendingHandshakes-concurrentHandshakes >= n
if success {
concurrentHandshakes += n
}
mu.Unlock()
return success
}
func release(n int64) {
mu.Lock()
concurrentHandshakes -= n
if concurrentHandshakes < 0 {
mu.Unlock()
panic("bad release")
}
mu.Unlock()
}
// ClientHandshakerOptions contains the client handshaker options that can
// be provided by the caller.
type ClientHandshakerOptions struct {
// ClientIdentity is the handshaker client local identity.
ClientIdentity *altspb.Identity
// TargetName is the server service account name for secure name
// checking.
TargetName string
// TargetServiceAccounts contains a list of expected target service
// accounts. One of these accounts should match one of the accounts in
// the handshaker results. Otherwise, the handshake fails.
TargetServiceAccounts []string
// RPCVersions specifies the gRPC versions accepted by the client.
RPCVersions *altspb.RpcProtocolVersions
}
// ServerHandshakerOptions contains the server handshaker options that can
// be provided by the caller.
type ServerHandshakerOptions struct {
// RPCVersions specifies the gRPC versions accepted by the server.
RPCVersions *altspb.RpcProtocolVersions
}
// DefaultClientHandshakerOptions returns the default client handshaker options.
func DefaultClientHandshakerOptions() *ClientHandshakerOptions {
return &ClientHandshakerOptions{}
}
// DefaultServerHandshakerOptions returns the default client handshaker options.
func DefaultServerHandshakerOptions() *ServerHandshakerOptions {
return &ServerHandshakerOptions{}
}
// TODO: add support for future local and remote endpoints in both client
// options and server options (the corresponding endpoint options do not exist
// yet; they should be created once callers can provide endpoints).
// altsHandshaker is used to complete an ALTS handshake between client and
// server. This handshaker talks to the ALTS handshaker service in the metadata
// server.
type altsHandshaker struct {
// RPC stream used to access the ALTS Handshaker service.
stream altsgrpc.HandshakerService_DoHandshakeClient
// the connection to the peer.
conn net.Conn
// client handshake options.
clientOpts *ClientHandshakerOptions
// server handshake options.
serverOpts *ServerHandshakerOptions
// defines the side doing the handshake, client or server.
side core.Side
}
// NewClientHandshaker creates an ALTS handshaker for GCP which contains an RPC
// stub created using the passed conn and used to talk to the ALTS Handshaker
// service in the metadata server.
func NewClientHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ClientHandshakerOptions) (core.Handshaker, error) {
stream, err := altsgrpc.NewHandshakerServiceClient(conn).DoHandshake(ctx, grpc.FailFast(false))
if err != nil {
return nil, err
}
return &altsHandshaker{
stream: stream,
conn: c,
clientOpts: opts,
side: core.ClientSide,
}, nil
}
// NewServerHandshaker creates an ALTS handshaker for GCP which contains an RPC
// stub created using the passed conn and used to talk to the ALTS Handshaker
// service in the metadata server.
func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ServerHandshakerOptions) (core.Handshaker, error) {
stream, err := altsgrpc.NewHandshakerServiceClient(conn).DoHandshake(ctx, grpc.FailFast(false))
if err != nil {
return nil, err
}
return &altsHandshaker{
stream: stream,
conn: c,
serverOpts: opts,
side: core.ServerSide,
}, nil
}
// ClientHandshake starts and completes a client ALTS handshaking for GCP. Once
// done, ClientHandshake returns a secure connection.
func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
if !acquire(1) {
return nil, nil, errDropped
}
defer release(1)
if h.side != core.ClientSide {
return nil, nil, errors.New("only handshakers created using NewClientHandshaker can perform a client handshake")
}
// Create target identities from service account list.
targetIdentities := make([]*altspb.Identity, 0, len(h.clientOpts.TargetServiceAccounts))
for _, account := range h.clientOpts.TargetServiceAccounts {
targetIdentities = append(targetIdentities, &altspb.Identity{
IdentityOneof: &altspb.Identity_ServiceAccount{
ServiceAccount: account,
},
})
}
req := &altspb.HandshakerReq{
ReqOneof: &altspb.HandshakerReq_ClientStart{
ClientStart: &altspb.StartClientHandshakeReq{
HandshakeSecurityProtocol: hsProtocol,
ApplicationProtocols: appProtocols,
RecordProtocols: recordProtocols,
TargetIdentities: targetIdentities,
LocalIdentity: h.clientOpts.ClientIdentity,
TargetName: h.clientOpts.TargetName,
RpcVersions: h.clientOpts.RPCVersions,
},
},
}
conn, result, err := h.doHandshake(req)
if err != nil {
return nil, nil, err
}
authInfo := authinfo.New(result)
return conn, authInfo, nil
}
// ServerHandshake starts and completes a server ALTS handshaking for GCP. Once
// done, ServerHandshake returns a secure connection.
func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
if !acquire(1) {
return nil, nil, errDropped
}
defer release(1)
if h.side != core.ServerSide {
return nil, nil, errors.New("only handshakers created using NewServerHandshaker can perform a server handshake")
}
p := make([]byte, frameLimit)
n, err := h.conn.Read(p)
if err != nil {
return nil, nil, err
}
// Prepare server parameters.
// TODO: currently only ALTS parameters are provided. Might need to use
// more options in the future.
params := make(map[int32]*altspb.ServerHandshakeParameters)
params[int32(altspb.HandshakeProtocol_ALTS)] = &altspb.ServerHandshakeParameters{
RecordProtocols: recordProtocols,
}
req := &altspb.HandshakerReq{
ReqOneof: &altspb.HandshakerReq_ServerStart{
ServerStart: &altspb.StartServerHandshakeReq{
ApplicationProtocols: appProtocols,
HandshakeParameters: params,
InBytes: p[:n],
RpcVersions: h.serverOpts.RPCVersions,
},
},
}
conn, result, err := h.doHandshake(req)
if err != nil {
return nil, nil, err
}
authInfo := authinfo.New(result)
return conn, authInfo, nil
}
func (h *altsHandshaker) doHandshake(req *altspb.HandshakerReq) (net.Conn, *altspb.HandshakerResult, error) {
resp, err := h.accessHandshakerService(req)
if err != nil {
return nil, nil, err
}
// Check if the returned status is an error.
if resp.GetStatus() != nil {
if got, want := resp.GetStatus().Code, uint32(codes.OK); got != want {
return nil, nil, fmt.Errorf("%v", resp.GetStatus().Details)
}
}
var extra []byte
if req.GetServerStart() != nil {
extra = req.GetServerStart().GetInBytes()[resp.GetBytesConsumed():]
}
result, extra, err := h.processUntilDone(resp, extra)
if err != nil {
return nil, nil, err
}
// The handshaker returns a 128-byte key. It should be truncated based
// on the returned record protocol.
keyLen, ok := keyLength[result.RecordProtocol]
if !ok {
return nil, nil, fmt.Errorf("unknown resulting record protocol %v", result.RecordProtocol)
}
sc, err := conn.NewConn(h.conn, h.side, result.GetRecordProtocol(), result.KeyData[:keyLen], extra)
if err != nil {
return nil, nil, err
}
return sc, result, nil
}
func (h *altsHandshaker) accessHandshakerService(req *altspb.HandshakerReq) (*altspb.HandshakerResp, error) {
if err := h.stream.Send(req); err != nil {
return nil, err
}
resp, err := h.stream.Recv()
if err != nil {
return nil, err
}
return resp, nil
}
// processUntilDone processes the handshake until the handshaker service returns
// the results. The handshaker service takes care of frame parsing, so we read
// whatever is received from the network and send it to the handshaker service.
func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []byte) (*altspb.HandshakerResult, []byte, error) {
for {
if len(resp.OutFrames) > 0 {
if _, err := h.conn.Write(resp.OutFrames); err != nil {
return nil, nil, err
}
}
if resp.Result != nil {
return resp.Result, extra, nil
}
buf := make([]byte, frameLimit)
n, err := h.conn.Read(buf)
if err != nil && err != io.EOF {
return nil, nil, err
}
// If there is nothing to send to the handshaker service, and
// nothing is received from the peer, then we are stuck.
// This covers the case when the peer is not responding. Note
// that handshaker service connection issues are caught in
// accessHandshakerService before we even get here.
if len(resp.OutFrames) == 0 && n == 0 {
return nil, nil, core.PeerNotRespondingError
}
// Append extra bytes from the previous interaction with the
// handshaker service with the current buffer read from conn.
p := append(extra, buf[:n]...)
resp, err = h.accessHandshakerService(&altspb.HandshakerReq{
ReqOneof: &altspb.HandshakerReq_Next{
Next: &altspb.NextHandshakeMessageReq{
InBytes: p,
},
},
})
if err != nil {
return nil, nil, err
}
// Set extra based on handshaker service response.
if n == 0 {
extra = nil
} else {
extra = buf[resp.GetBytesConsumed():n]
}
}
}
// Close terminates the Handshaker. It should be called when the caller obtains
// the secure connection.
func (h *altsHandshaker) Close() {
h.stream.CloseSend()
}
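
A sketch of how a caller might drive this handshaker (the helper name, hsAddr, and rawConn are assumptions for illustration; error handling is kept minimal): dial the handshaker service, create a client handshaker, and upgrade an already-dialed connection to the peer:

package handshaker

import (
    "context"
    "net"

    grpc "google.golang.org/grpc"
    "google.golang.org/grpc/credentials"
)

// clientHandshakeSketch upgrades rawConn to a secure connection by driving the
// handshaker service reachable at hsAddr.
func clientHandshakeSketch(ctx context.Context, hsAddr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
    hsConn, err := grpc.Dial(hsAddr, grpc.WithInsecure())
    if err != nil {
        return nil, nil, err
    }
    defer hsConn.Close()
    hs, err := NewClientHandshaker(ctx, hsConn, rawConn, DefaultClientHandshakerOptions())
    if err != nil {
        return nil, nil, err
    }
    // Close runs after ClientHandshake returns, once the secure conn exists.
    defer hs.Close()
    return hs.ClientHandshake(ctx)
}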

View File

@ -1,261 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package handshaker
import (
"bytes"
"context"
"testing"
"time"
grpc "google.golang.org/grpc"
core "google.golang.org/grpc/credentials/alts/internal"
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
"google.golang.org/grpc/credentials/alts/internal/testutil"
)
var (
testAppProtocols = []string{"grpc"}
testRecordProtocol = rekeyRecordProtocolName
testKey = []byte{
// 44 arbitrary bytes.
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49,
0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49, 0x1f, 0x8b,
0xd2, 0x4c, 0xce, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2,
}
testServiceAccount = "test_service_account"
testTargetServiceAccounts = []string{testServiceAccount}
testClientIdentity = &altspb.Identity{
IdentityOneof: &altspb.Identity_Hostname{
Hostname: "i_am_a_client",
},
}
)
// testRPCStream mimics an altspb.HandshakerService_DoHandshakeClient object.
type testRPCStream struct {
grpc.ClientStream
t *testing.T
isClient bool
// The resp expected to be returned by Recv(). Make sure this is set to
// the content the test requires before Recv() is invoked.
recvBuf *altspb.HandshakerResp
// first is false until the first access to the handshaker service has occurred.
first bool
// useful for testing concurrent calls.
delay time.Duration
}
func (t *testRPCStream) Recv() (*altspb.HandshakerResp, error) {
resp := t.recvBuf
t.recvBuf = nil
return resp, nil
}
func (t *testRPCStream) Send(req *altspb.HandshakerReq) error {
var resp *altspb.HandshakerResp
if !t.first {
// Generate the bytes to be returned by Recv() for the initial
// handshaking.
t.first = true
if t.isClient {
resp = &altspb.HandshakerResp{
OutFrames: testutil.MakeFrame("ClientInit"),
// Simulate consuming ServerInit.
BytesConsumed: 14,
}
} else {
resp = &altspb.HandshakerResp{
OutFrames: testutil.MakeFrame("ServerInit"),
// Simulate consuming ClientInit.
BytesConsumed: 14,
}
}
} else {
// Add delay to test concurrent calls.
cleanup := stat.Update()
defer cleanup()
time.Sleep(t.delay)
// Generate the response to be returned by Recv() for the
// follow-up handshaking.
result := &altspb.HandshakerResult{
RecordProtocol: testRecordProtocol,
KeyData: testKey,
}
resp = &altspb.HandshakerResp{
Result: result,
// Simulate consuming ClientFinished or ServerFinished.
BytesConsumed: 18,
}
}
t.recvBuf = resp
return nil
}
func (t *testRPCStream) CloseSend() error {
return nil
}
var stat testutil.Stats
func TestClientHandshake(t *testing.T) {
for _, testCase := range []struct {
delay time.Duration
numberOfHandshakes int
}{
{0 * time.Millisecond, 1},
{100 * time.Millisecond, 10 * maxPendingHandshakes},
} {
errc := make(chan error)
stat.Reset()
for i := 0; i < testCase.numberOfHandshakes; i++ {
stream := &testRPCStream{
t: t,
isClient: true,
}
// Preload the inbound frames.
f1 := testutil.MakeFrame("ServerInit")
f2 := testutil.MakeFrame("ServerFinished")
in := bytes.NewBuffer(f1)
in.Write(f2)
out := new(bytes.Buffer)
tc := testutil.NewTestConn(in, out)
chs := &altsHandshaker{
stream: stream,
conn: tc,
clientOpts: &ClientHandshakerOptions{
TargetServiceAccounts: testTargetServiceAccounts,
ClientIdentity: testClientIdentity,
},
side: core.ClientSide,
}
go func() {
_, context, err := chs.ClientHandshake(context.Background())
if err == nil && context == nil {
panic("expected non-nil ALTS context")
}
errc <- err
chs.Close()
}()
}
// Ensure all errors are expected.
for i := 0; i < testCase.numberOfHandshakes; i++ {
if err := <-errc; err != nil && err != errDropped {
t.Errorf("ClientHandshake() = _, %v, want _, <nil> or %v", err, errDropped)
}
}
// Ensure that the number of concurrent calls does not exceed the limit.
if stat.MaxConcurrentCalls > maxPendingHandshakes {
t.Errorf("Observed %d concurrent handshakes; want <= %d", stat.MaxConcurrentCalls, maxPendingHandshakes)
}
}
}
func TestServerHandshake(t *testing.T) {
for _, testCase := range []struct {
delay time.Duration
numberOfHandshakes int
}{
{0 * time.Millisecond, 1},
{100 * time.Millisecond, 10 * maxPendingHandshakes},
} {
errc := make(chan error)
stat.Reset()
for i := 0; i < testCase.numberOfHandshakes; i++ {
stream := &testRPCStream{
t: t,
isClient: false,
}
// Preload the inbound frames.
f1 := testutil.MakeFrame("ClientInit")
f2 := testutil.MakeFrame("ClientFinished")
in := bytes.NewBuffer(f1)
in.Write(f2)
out := new(bytes.Buffer)
tc := testutil.NewTestConn(in, out)
shs := &altsHandshaker{
stream: stream,
conn: tc,
serverOpts: DefaultServerHandshakerOptions(),
side: core.ServerSide,
}
go func() {
_, context, err := shs.ServerHandshake(context.Background())
if err == nil && context == nil {
panic("expected non-nil ALTS context")
}
errc <- err
shs.Close()
}()
}
// Ensure all errors are expected.
for i := 0; i < testCase.numberOfHandshakes; i++ {
if err := <-errc; err != nil && err != errDropped {
t.Errorf("ServerHandshake() = _, %v, want _, <nil> or %v", err, errDropped)
}
}
// Ensure that the number of concurrent calls never exceeds the limit.
if stat.MaxConcurrentCalls > maxPendingHandshakes {
t.Errorf("Observed %d concurrent handshakes; want <= %d", stat.MaxConcurrentCalls, maxPendingHandshakes)
}
}
}
// testUnresponsiveRPCStream is used for testing the PeerNotResponding case.
type testUnresponsiveRPCStream struct {
grpc.ClientStream
}
func (t *testUnresponsiveRPCStream) Recv() (*altspb.HandshakerResp, error) {
return &altspb.HandshakerResp{}, nil
}
func (t *testUnresponsiveRPCStream) Send(req *altspb.HandshakerReq) error {
return nil
}
func (t *testUnresponsiveRPCStream) CloseSend() error {
return nil
}
func TestPeerNotResponding(t *testing.T) {
stream := &testUnresponsiveRPCStream{}
chs := &altsHandshaker{
stream: stream,
conn: testutil.NewUnresponsiveTestConn(),
clientOpts: &ClientHandshakerOptions{
TargetServiceAccounts: testTargetServiceAccounts,
ClientIdentity: testClientIdentity,
},
side: core.ClientSide,
}
_, context, err := chs.ClientHandshake(context.Background())
chs.Close()
if context != nil {
t.Error("expected non-nil ALTS context")
}
if got, want := err, core.PeerNotRespondingError; got != want {
t.Errorf("ClientHandshake() = %v, want %v", got, want)
}
}

View File

@ -1,56 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package service manages connections between the VM application and the ALTS
// handshaker service.
package service
import (
"sync"
grpc "google.golang.org/grpc"
)
var (
// hsConn represents a connection to the hypervisor handshaker service.
hsConn *grpc.ClientConn
mu sync.Mutex
// hsDialer will be reassigned in tests.
hsDialer = grpc.Dial
)
type dialer func(target string, opts ...grpc.DialOption) (*grpc.ClientConn, error)
// Dial dials the handshake service in the hypervisor. If a connection has
// already been established, this function returns it. Otherwise, a new
// connection is created.
func Dial(hsAddress string) (*grpc.ClientConn, error) {
mu.Lock()
defer mu.Unlock()
if hsConn == nil {
// Create a new connection to the handshaker service. Note that
// this connection stays open until the application is closed.
var err error
hsConn, err = hsDialer(hsAddress, grpc.WithInsecure())
if err != nil {
return nil, err
}
}
return hsConn, nil
}
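
Dial caches the single ClientConn in hsConn, so every caller in the process shares one connection to the handshaker service. Below is a minimal sketch of a caller; the import path and the handshaker address are assumptions for illustration, not values defined in this file.

package main

import (
	"log"

	"google.golang.org/grpc/credentials/alts/internal/handshaker/service" // assumed import path
)

func main() {
	// The hypervisor handshaker service address is illustrative only.
	conn, err := service.Dial("metadata.google.internal:8080")
	if err != nil {
		log.Fatalf("failed to dial handshaker service: %v", err)
	}
	// A second Dial returns the same cached connection.
	conn2, err := service.Dial("metadata.google.internal:8080")
	if err != nil {
		log.Fatalf("second Dial failed: %v", err)
	}
	log.Println(conn == conn2) // true: hsConn is reused
}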

View File

@ -1,69 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package service
import (
"testing"
grpc "google.golang.org/grpc"
)
const (
// The address is irrelevant in this test.
testAddress = "some_address"
)
func TestDial(t *testing.T) {
defer func() func() {
temp := hsDialer
hsDialer = func(target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
return &grpc.ClientConn{}, nil
}
return func() {
hsDialer = temp
}
}()()
// Ensure that hsConn is nil at first.
hsConn = nil
// First call to Dial; it should create and set hsConn.
conn1, err := Dial(testAddress)
if err != nil {
t.Fatalf("first call to Dial failed: %v", err)
}
if conn1 == nil {
t.Fatal("first call to Dial(_)=(nil, _), want not nil")
}
if got, want := hsConn, conn1; got != want {
t.Fatalf("hsConn=%v, want %v", got, want)
}
// Second call to Dial should return conn1 above.
conn2, err := Dial(testAddress)
if err != nil {
t.Fatalf("second call to Dial(_) failed: %v", err)
}
if got, want := conn2, conn1; got != want {
t.Fatalf("second call to Dial(_)=(%v, _), want (%v,. _)", got, want)
}
if got, want := hsConn, conn1; got != want {
t.Fatalf("hsConn=%v, want %v", got, want)
}
}

View File

@ -1,151 +0,0 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: grpc/gcp/altscontext.proto
package grpc_gcp // import "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type AltsContext struct {
// The application protocol negotiated for this connection.
ApplicationProtocol string `protobuf:"bytes,1,opt,name=application_protocol,json=applicationProtocol,proto3" json:"application_protocol,omitempty"`
// The record protocol negotiated for this connection.
RecordProtocol string `protobuf:"bytes,2,opt,name=record_protocol,json=recordProtocol,proto3" json:"record_protocol,omitempty"`
// The security level of the created secure channel.
SecurityLevel SecurityLevel `protobuf:"varint,3,opt,name=security_level,json=securityLevel,proto3,enum=grpc.gcp.SecurityLevel" json:"security_level,omitempty"`
// The peer service account.
PeerServiceAccount string `protobuf:"bytes,4,opt,name=peer_service_account,json=peerServiceAccount,proto3" json:"peer_service_account,omitempty"`
// The local service account.
LocalServiceAccount string `protobuf:"bytes,5,opt,name=local_service_account,json=localServiceAccount,proto3" json:"local_service_account,omitempty"`
// The RPC protocol versions supported by the peer.
PeerRpcVersions *RpcProtocolVersions `protobuf:"bytes,6,opt,name=peer_rpc_versions,json=peerRpcVersions,proto3" json:"peer_rpc_versions,omitempty"`
// Additional attributes of the peer.
PeerAttributes map[string]string `protobuf:"bytes,7,rep,name=peer_attributes,json=peerAttributes,proto3" json:"peer_attributes,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *AltsContext) Reset() { *m = AltsContext{} }
func (m *AltsContext) String() string { return proto.CompactTextString(m) }
func (*AltsContext) ProtoMessage() {}
func (*AltsContext) Descriptor() ([]byte, []int) {
return fileDescriptor_altscontext_f6b7868f9a30497f, []int{0}
}
func (m *AltsContext) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_AltsContext.Unmarshal(m, b)
}
func (m *AltsContext) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_AltsContext.Marshal(b, m, deterministic)
}
func (dst *AltsContext) XXX_Merge(src proto.Message) {
xxx_messageInfo_AltsContext.Merge(dst, src)
}
func (m *AltsContext) XXX_Size() int {
return xxx_messageInfo_AltsContext.Size(m)
}
func (m *AltsContext) XXX_DiscardUnknown() {
xxx_messageInfo_AltsContext.DiscardUnknown(m)
}
var xxx_messageInfo_AltsContext proto.InternalMessageInfo
func (m *AltsContext) GetApplicationProtocol() string {
if m != nil {
return m.ApplicationProtocol
}
return ""
}
func (m *AltsContext) GetRecordProtocol() string {
if m != nil {
return m.RecordProtocol
}
return ""
}
func (m *AltsContext) GetSecurityLevel() SecurityLevel {
if m != nil {
return m.SecurityLevel
}
return SecurityLevel_SECURITY_NONE
}
func (m *AltsContext) GetPeerServiceAccount() string {
if m != nil {
return m.PeerServiceAccount
}
return ""
}
func (m *AltsContext) GetLocalServiceAccount() string {
if m != nil {
return m.LocalServiceAccount
}
return ""
}
func (m *AltsContext) GetPeerRpcVersions() *RpcProtocolVersions {
if m != nil {
return m.PeerRpcVersions
}
return nil
}
func (m *AltsContext) GetPeerAttributes() map[string]string {
if m != nil {
return m.PeerAttributes
}
return nil
}
func init() {
proto.RegisterType((*AltsContext)(nil), "grpc.gcp.AltsContext")
proto.RegisterMapType((map[string]string)(nil), "grpc.gcp.AltsContext.PeerAttributesEntry")
}
func init() {
proto.RegisterFile("grpc/gcp/altscontext.proto", fileDescriptor_altscontext_f6b7868f9a30497f)
}
var fileDescriptor_altscontext_f6b7868f9a30497f = []byte{
// 411 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x6c, 0x92, 0x4d, 0x6f, 0x13, 0x31,
0x10, 0x86, 0xb5, 0x0d, 0x2d, 0xe0, 0x88, 0xb4, 0xb8, 0xa9, 0x58, 0x45, 0x42, 0x8a, 0xb8, 0xb0,
0x5c, 0x76, 0x21, 0x5c, 0x10, 0x07, 0x50, 0x8a, 0x38, 0x20, 0x71, 0x88, 0xb6, 0x12, 0x07, 0x2e,
0x2b, 0x77, 0x3a, 0xb2, 0x2c, 0x5c, 0x8f, 0x35, 0x76, 0x22, 0xf2, 0xb3, 0xf9, 0x07, 0x68, 0xed,
0xcd, 0x07, 0x1f, 0xb7, 0x9d, 0x79, 0x9f, 0x19, 0xbf, 0xb3, 0x33, 0x62, 0xa6, 0xd9, 0x43, 0xa3,
0xc1, 0x37, 0xca, 0xc6, 0x00, 0xe4, 0x22, 0xfe, 0x8c, 0xb5, 0x67, 0x8a, 0x24, 0x1f, 0xf5, 0x5a,
0xad, 0xc1, 0xcf, 0xaa, 0x3d, 0x15, 0x59, 0xb9, 0xe0, 0x89, 0x63, 0x17, 0x10, 0xd6, 0x6c, 0xe2,
0xb6, 0x03, 0xba, 0xbf, 0x27, 0x97, 0x6b, 0x5e, 0xfc, 0x1a, 0x89, 0xf1, 0xd2, 0xc6, 0xf0, 0x29,
0x77, 0x92, 0x6f, 0xc4, 0x54, 0x79, 0x6f, 0x0d, 0xa8, 0x68, 0xc8, 0x75, 0x09, 0x02, 0xb2, 0x65,
0x31, 0x2f, 0xaa, 0xc7, 0xed, 0xe5, 0x91, 0xb6, 0x1a, 0x24, 0xf9, 0x52, 0x9c, 0x33, 0x02, 0xf1,
0xdd, 0x81, 0x3e, 0x49, 0xf4, 0x24, 0xa7, 0xf7, 0xe0, 0x07, 0x31, 0xd9, 0x9b, 0xb0, 0xb8, 0x41,
0x5b, 0x8e, 0xe6, 0x45, 0x35, 0x59, 0x3c, 0xab, 0x77, 0xc6, 0xeb, 0x9b, 0x41, 0xff, 0xda, 0xcb,
0xed, 0x93, 0x70, 0x1c, 0xca, 0xd7, 0x62, 0xea, 0x11, 0xb9, 0x0b, 0xc8, 0x1b, 0x03, 0xd8, 0x29,
0x00, 0x5a, 0xbb, 0x58, 0x3e, 0x48, 0xaf, 0xc9, 0x5e, 0xbb, 0xc9, 0xd2, 0x32, 0x2b, 0x72, 0x21,
0xae, 0x2c, 0x81, 0xb2, 0xff, 0x94, 0x9c, 0xe6, 0x71, 0x92, 0xf8, 0x57, 0xcd, 0x17, 0xf1, 0x34,
0xbd, 0xc2, 0x1e, 0xba, 0x0d, 0x72, 0x30, 0xe4, 0x42, 0x79, 0x36, 0x2f, 0xaa, 0xf1, 0xe2, 0xf9,
0xc1, 0x68, 0xeb, 0x61, 0x37, 0xd7, 0xb7, 0x01, 0x6a, 0xcf, 0xfb, 0xba, 0xd6, 0xc3, 0x2e, 0x21,
0x5b, 0x91, 0x52, 0x9d, 0x8a, 0x91, 0xcd, 0xed, 0x3a, 0x62, 0x28, 0x1f, 0xce, 0x47, 0xd5, 0x78,
0xf1, 0xea, 0xd0, 0xe8, 0xe8, 0xe7, 0xd7, 0x2b, 0x44, 0x5e, 0xee, 0xd9, 0xcf, 0x2e, 0xf2, 0xb6,
0x9d, 0xf8, 0x3f, 0x92, 0xb3, 0xa5, 0xb8, 0xfc, 0x0f, 0x26, 0x2f, 0xc4, 0xe8, 0x07, 0x6e, 0x87,
0x35, 0xf5, 0x9f, 0x72, 0x2a, 0x4e, 0x37, 0xca, 0xae, 0x71, 0x58, 0x46, 0x0e, 0xde, 0x9f, 0xbc,
0x2b, 0xae, 0xad, 0xb8, 0x32, 0x94, 0x1d, 0xf4, 0x47, 0x54, 0x1b, 0x17, 0x91, 0x9d, 0xb2, 0xd7,
0x17, 0x47, 0x66, 0xd2, 0x74, 0xab, 0xe2, 0xfb, 0x47, 0x4d, 0xa4, 0x2d, 0xd6, 0x9a, 0xac, 0x72,
0xba, 0x26, 0xd6, 0x4d, 0x3a, 0x2e, 0x60, 0xbc, 0x43, 0x17, 0x8d, 0xb2, 0x21, 0x9d, 0x62, 0xb3,
0xeb, 0xd2, 0xa4, 0x2b, 0x48, 0x50, 0xa7, 0xc1, 0xdf, 0x9e, 0xa5, 0xf8, 0xed, 0xef, 0x00, 0x00,
0x00, 0xff, 0xff, 0x9b, 0x8c, 0xe4, 0x6a, 0xba, 0x02, 0x00, 0x00,
}

File diff suppressed because it is too large

View File

@ -1,178 +0,0 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: grpc/gcp/transport_security_common.proto
package grpc_gcp // import "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
// The security level of the created channel. The list is sorted in increasing
// level of security. This order must always be maintained.
type SecurityLevel int32
const (
SecurityLevel_SECURITY_NONE SecurityLevel = 0
SecurityLevel_INTEGRITY_ONLY SecurityLevel = 1
SecurityLevel_INTEGRITY_AND_PRIVACY SecurityLevel = 2
)
var SecurityLevel_name = map[int32]string{
0: "SECURITY_NONE",
1: "INTEGRITY_ONLY",
2: "INTEGRITY_AND_PRIVACY",
}
var SecurityLevel_value = map[string]int32{
"SECURITY_NONE": 0,
"INTEGRITY_ONLY": 1,
"INTEGRITY_AND_PRIVACY": 2,
}
func (x SecurityLevel) String() string {
return proto.EnumName(SecurityLevel_name, int32(x))
}
func (SecurityLevel) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_transport_security_common_71945991f2c3b4a6, []int{0}
}
// Max and min supported RPC protocol versions.
type RpcProtocolVersions struct {
// Maximum supported RPC version.
MaxRpcVersion *RpcProtocolVersions_Version `protobuf:"bytes,1,opt,name=max_rpc_version,json=maxRpcVersion,proto3" json:"max_rpc_version,omitempty"`
// Minimum supported RPC version.
MinRpcVersion *RpcProtocolVersions_Version `protobuf:"bytes,2,opt,name=min_rpc_version,json=minRpcVersion,proto3" json:"min_rpc_version,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *RpcProtocolVersions) Reset() { *m = RpcProtocolVersions{} }
func (m *RpcProtocolVersions) String() string { return proto.CompactTextString(m) }
func (*RpcProtocolVersions) ProtoMessage() {}
func (*RpcProtocolVersions) Descriptor() ([]byte, []int) {
return fileDescriptor_transport_security_common_71945991f2c3b4a6, []int{0}
}
func (m *RpcProtocolVersions) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_RpcProtocolVersions.Unmarshal(m, b)
}
func (m *RpcProtocolVersions) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_RpcProtocolVersions.Marshal(b, m, deterministic)
}
func (dst *RpcProtocolVersions) XXX_Merge(src proto.Message) {
xxx_messageInfo_RpcProtocolVersions.Merge(dst, src)
}
func (m *RpcProtocolVersions) XXX_Size() int {
return xxx_messageInfo_RpcProtocolVersions.Size(m)
}
func (m *RpcProtocolVersions) XXX_DiscardUnknown() {
xxx_messageInfo_RpcProtocolVersions.DiscardUnknown(m)
}
var xxx_messageInfo_RpcProtocolVersions proto.InternalMessageInfo
func (m *RpcProtocolVersions) GetMaxRpcVersion() *RpcProtocolVersions_Version {
if m != nil {
return m.MaxRpcVersion
}
return nil
}
func (m *RpcProtocolVersions) GetMinRpcVersion() *RpcProtocolVersions_Version {
if m != nil {
return m.MinRpcVersion
}
return nil
}
// RPC version contains a major version and a minor version.
type RpcProtocolVersions_Version struct {
Major uint32 `protobuf:"varint,1,opt,name=major,proto3" json:"major,omitempty"`
Minor uint32 `protobuf:"varint,2,opt,name=minor,proto3" json:"minor,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *RpcProtocolVersions_Version) Reset() { *m = RpcProtocolVersions_Version{} }
func (m *RpcProtocolVersions_Version) String() string { return proto.CompactTextString(m) }
func (*RpcProtocolVersions_Version) ProtoMessage() {}
func (*RpcProtocolVersions_Version) Descriptor() ([]byte, []int) {
return fileDescriptor_transport_security_common_71945991f2c3b4a6, []int{0, 0}
}
func (m *RpcProtocolVersions_Version) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_RpcProtocolVersions_Version.Unmarshal(m, b)
}
func (m *RpcProtocolVersions_Version) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_RpcProtocolVersions_Version.Marshal(b, m, deterministic)
}
func (dst *RpcProtocolVersions_Version) XXX_Merge(src proto.Message) {
xxx_messageInfo_RpcProtocolVersions_Version.Merge(dst, src)
}
func (m *RpcProtocolVersions_Version) XXX_Size() int {
return xxx_messageInfo_RpcProtocolVersions_Version.Size(m)
}
func (m *RpcProtocolVersions_Version) XXX_DiscardUnknown() {
xxx_messageInfo_RpcProtocolVersions_Version.DiscardUnknown(m)
}
var xxx_messageInfo_RpcProtocolVersions_Version proto.InternalMessageInfo
func (m *RpcProtocolVersions_Version) GetMajor() uint32 {
if m != nil {
return m.Major
}
return 0
}
func (m *RpcProtocolVersions_Version) GetMinor() uint32 {
if m != nil {
return m.Minor
}
return 0
}
func init() {
proto.RegisterType((*RpcProtocolVersions)(nil), "grpc.gcp.RpcProtocolVersions")
proto.RegisterType((*RpcProtocolVersions_Version)(nil), "grpc.gcp.RpcProtocolVersions.Version")
proto.RegisterEnum("grpc.gcp.SecurityLevel", SecurityLevel_name, SecurityLevel_value)
}
func init() {
proto.RegisterFile("grpc/gcp/transport_security_common.proto", fileDescriptor_transport_security_common_71945991f2c3b4a6)
}
var fileDescriptor_transport_security_common_71945991f2c3b4a6 = []byte{
// 323 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x94, 0x91, 0x41, 0x4b, 0x3b, 0x31,
0x10, 0xc5, 0xff, 0x5b, 0xf8, 0xab, 0x44, 0x56, 0xeb, 0x6a, 0x41, 0xc5, 0x83, 0x08, 0x42, 0xf1,
0x90, 0x05, 0xc5, 0xb3, 0xb4, 0xb5, 0x48, 0xa1, 0x6e, 0xeb, 0xb6, 0x16, 0xea, 0x25, 0xc4, 0x18,
0x42, 0x24, 0x9b, 0x09, 0xb3, 0xb1, 0xd4, 0xaf, 0xec, 0xa7, 0x90, 0x4d, 0xbb, 0x14, 0xc1, 0x8b,
0xb7, 0xbc, 0xc7, 0xcc, 0x6f, 0x32, 0xf3, 0x48, 0x5b, 0xa1, 0x13, 0xa9, 0x12, 0x2e, 0xf5, 0xc8,
0x6d, 0xe9, 0x00, 0x3d, 0x2b, 0xa5, 0xf8, 0x40, 0xed, 0x3f, 0x99, 0x80, 0xa2, 0x00, 0x4b, 0x1d,
0x82, 0x87, 0x64, 0xa7, 0xaa, 0xa4, 0x4a, 0xb8, 0x8b, 0xaf, 0x88, 0x1c, 0xe6, 0x4e, 0x8c, 0x2b,
0x5b, 0x80, 0x99, 0x49, 0x2c, 0x35, 0xd8, 0x32, 0x79, 0x24, 0xfb, 0x05, 0x5f, 0x32, 0x74, 0x82,
0x2d, 0x56, 0xde, 0x71, 0x74, 0x1e, 0xb5, 0x77, 0xaf, 0x2f, 0x69, 0xdd, 0x4b, 0x7f, 0xe9, 0xa3,
0xeb, 0x47, 0x1e, 0x17, 0x7c, 0x99, 0x3b, 0xb1, 0x96, 0x01, 0xa7, 0xed, 0x0f, 0x5c, 0xe3, 0x6f,
0x38, 0x6d, 0x37, 0xb8, 0xd3, 0x5b, 0xb2, 0x5d, 0x93, 0x8f, 0xc8, 0xff, 0x82, 0xbf, 0x03, 0x86,
0xef, 0xc5, 0xf9, 0x4a, 0x04, 0x57, 0x5b, 0xc0, 0x30, 0xa5, 0x72, 0x2b, 0x71, 0xf5, 0x44, 0xe2,
0xc9, 0xfa, 0x1e, 0x43, 0xb9, 0x90, 0x26, 0x39, 0x20, 0xf1, 0xa4, 0xdf, 0x7b, 0xce, 0x07, 0xd3,
0x39, 0xcb, 0x46, 0x59, 0xbf, 0xf9, 0x2f, 0x49, 0xc8, 0xde, 0x20, 0x9b, 0xf6, 0x1f, 0x82, 0x37,
0xca, 0x86, 0xf3, 0x66, 0x94, 0x9c, 0x90, 0xd6, 0xc6, 0xeb, 0x64, 0xf7, 0x6c, 0x9c, 0x0f, 0x66,
0x9d, 0xde, 0xbc, 0xd9, 0xe8, 0x2e, 0x49, 0x4b, 0xc3, 0x6a, 0x07, 0x6e, 0x7c, 0x49, 0xb5, 0xf5,
0x12, 0x2d, 0x37, 0xdd, 0xb3, 0x69, 0x9d, 0x41, 0x3d, 0xb2, 0x17, 0x12, 0x08, 0x2b, 0x8e, 0xa3,
0x97, 0x3b, 0x05, 0xa0, 0x8c, 0xa4, 0x0a, 0x0c, 0xb7, 0x8a, 0x02, 0xaa, 0x34, 0xc4, 0x27, 0x50,
0xbe, 0x49, 0xeb, 0x35, 0x37, 0x65, 0x5a, 0x11, 0xd3, 0x9a, 0x98, 0x86, 0xe8, 0x42, 0x11, 0x53,
0xc2, 0xbd, 0x6e, 0x05, 0x7d, 0xf3, 0x1d, 0x00, 0x00, 0xff, 0xff, 0x31, 0x14, 0xb4, 0x11, 0xf6,
0x01, 0x00, 0x00,
}
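
RpcProtocolVersions carries the range of RPC protocol versions an endpoint supports during the ALTS handshake. A short sketch of populating it follows; the import path is taken from the generated-file comment above, and the 2.1 values are placeholders rather than mandated numbers.

package example

import (
	grpc_gcp "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
)

// clientVersions advertises the RPC protocol versions a client supports.
// The Major/Minor values here are illustrative.
var clientVersions = &grpc_gcp.RpcProtocolVersions{
	MaxRpcVersion: &grpc_gcp.RpcProtocolVersions_Version{Major: 2, Minor: 1},
	MinRpcVersion: &grpc_gcp.RpcProtocolVersions_Version{Major: 2, Minor: 1},
}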

View File

@ -1,35 +0,0 @@
#!/bin/bash
# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -eux -o pipefail
TMP=$(mktemp -d)
function finish {
rm -rf "$TMP"
}
trap finish EXIT
pushd "$TMP"
mkdir -p grpc/gcp
curl https://raw.githubusercontent.com/grpc/grpc-proto/master/grpc/gcp/altscontext.proto > grpc/gcp/altscontext.proto
curl https://raw.githubusercontent.com/grpc/grpc-proto/master/grpc/gcp/handshaker.proto > grpc/gcp/handshaker.proto
curl https://raw.githubusercontent.com/grpc/grpc-proto/master/grpc/gcp/transport_security_common.proto > grpc/gcp/transport_security_common.proto
protoc --go_out=plugins=grpc,paths=source_relative:. -I. grpc/gcp/*.proto
popd
rm -f proto/grpc_gcp/*.pb.go
cp "$TMP"/grpc/gcp/*.pb.go proto/grpc_gcp/

View File

@ -1,125 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package testutil includes useful test utilities for the handshaker.
package testutil
import (
"bytes"
"encoding/binary"
"io"
"net"
"sync"
"google.golang.org/grpc/credentials/alts/internal/conn"
)
// Stats is used to collect statistics about concurrent handshake calls.
type Stats struct {
mu sync.Mutex
calls int
MaxConcurrentCalls int
}
// Update updates the statistics by adding one call.
func (s *Stats) Update() func() {
s.mu.Lock()
s.calls++
if s.calls > s.MaxConcurrentCalls {
s.MaxConcurrentCalls = s.calls
}
s.mu.Unlock()
return func() {
s.mu.Lock()
s.calls--
s.mu.Unlock()
}
}
// Reset resets the statistics.
func (s *Stats) Reset() {
s.mu.Lock()
defer s.mu.Unlock()
s.calls = 0
s.MaxConcurrentCalls = 0
}
// testConn mimics a net.Conn to the peer.
type testConn struct {
net.Conn
in *bytes.Buffer
out *bytes.Buffer
}
// NewTestConn creates a new instance of the testConn object.
func NewTestConn(in *bytes.Buffer, out *bytes.Buffer) net.Conn {
return &testConn{
in: in,
out: out,
}
}
// Read reads from the in buffer.
func (c *testConn) Read(b []byte) (n int, err error) {
return c.in.Read(b)
}
// Write writes to the out buffer.
func (c *testConn) Write(b []byte) (n int, err error) {
return c.out.Write(b)
}
// Close closes the testConn object.
func (c *testConn) Close() error {
return nil
}
// unresponsiveTestConn mimics a net.Conn for an unresponsive peer. It is used
// for testing the PeerNotResponding case.
type unresponsiveTestConn struct {
net.Conn
}
// NewUnresponsiveTestConn creates a new instance of the unresponsiveTestConn object.
func NewUnresponsiveTestConn() net.Conn {
return &unresponsiveTestConn{}
}
// Read always returns io.EOF, simulating an unresponsive peer.
func (c *unresponsiveTestConn) Read(b []byte) (n int, err error) {
return 0, io.EOF
}
// Write discards the data and reports zero bytes written.
func (c *unresponsiveTestConn) Write(b []byte) (n int, err error) {
return 0, nil
}
// Close closes the unresponsiveTestConn object.
func (c *unresponsiveTestConn) Close() error {
return nil
}
// MakeFrame creates a handshake frame.
func MakeFrame(pl string) []byte {
f := make([]byte, len(pl)+conn.MsgLenFieldSize)
binary.LittleEndian.PutUint32(f, uint32(len(pl)))
copy(f[conn.MsgLenFieldSize:], []byte(pl))
return f
}
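
MakeFrame prepends a conn.MsgLenFieldSize-byte little-endian length prefix to the payload. Assuming that prefix is 4 bytes, the 10-byte "ServerInit" payload becomes a 14-byte frame, which is where the BytesConsumed: 14 values in the handshaker tests above come from. A minimal decoding sketch, with the import path assumed:

package main

import (
	"encoding/binary"
	"fmt"

	"google.golang.org/grpc/credentials/alts/internal/testutil" // assumed import path
)

func main() {
	f := testutil.MakeFrame("ServerInit")
	// The first 4 bytes hold the payload length (assuming conn.MsgLenFieldSize
	// is 4); the payload follows immediately after.
	n := binary.LittleEndian.Uint32(f[:4])
	fmt.Println(len(f), n, string(f[4:4+n])) // 14 10 ServerInit
}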

View File

@ -1,141 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package alts
import (
"context"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"os"
"os/exec"
"regexp"
"runtime"
"strings"
"google.golang.org/grpc/peer"
)
const (
linuxProductNameFile = "/sys/class/dmi/id/product_name"
windowsCheckCommand = "powershell.exe"
windowsCheckCommandArgs = "Get-WmiObject -Class Win32_BIOS"
powershellOutputFilter = "Manufacturer"
windowsManufacturerRegex = ":(.*)"
)
type platformError string
func (k platformError) Error() string {
return fmt.Sprintf("%s is not supported", string(k))
}
var (
// The following two variables will be reassigned in tests.
runningOS = runtime.GOOS
manufacturerReader = func() (io.Reader, error) {
switch runningOS {
case "linux":
return os.Open(linuxProductNameFile)
case "windows":
cmd := exec.Command(windowsCheckCommand, windowsCheckCommandArgs)
out, err := cmd.Output()
if err != nil {
return nil, err
}
for _, line := range strings.Split(strings.TrimSuffix(string(out), "\n"), "\n") {
if strings.HasPrefix(line, powershellOutputFilter) {
re := regexp.MustCompile(windowsManufacturerRegex)
name := re.FindString(line)
name = strings.TrimLeft(name, ":")
return strings.NewReader(name), nil
}
}
return nil, errors.New("cannot determine the machine's manufacturer")
default:
return nil, platformError(runningOS)
}
}
vmOnGCP bool
)
// isRunningOnGCP checks, without doing a network request, whether the local
// system is running on GCP.
func isRunningOnGCP() bool {
manufacturer, err := readManufacturer()
if err != nil {
log.Fatalf("failure to read manufacturer information: %v", err)
}
name := string(manufacturer)
switch runningOS {
case "linux":
name = strings.TrimSpace(name)
return name == "Google" || name == "Google Compute Engine"
case "windows":
name = strings.Replace(name, " ", "", -1)
name = strings.Replace(name, "\n", "", -1)
name = strings.Replace(name, "\r", "", -1)
return name == "Google"
default:
log.Fatal(platformError(runningOS))
}
return false
}
func readManufacturer() ([]byte, error) {
reader, err := manufacturerReader()
if err != nil {
return nil, err
}
if reader == nil {
return nil, errors.New("got nil reader")
}
manufacturer, err := ioutil.ReadAll(reader)
if err != nil {
return nil, fmt.Errorf("failed reading %v: %v", linuxProductNameFile, err)
}
return manufacturer, nil
}
// AuthInfoFromContext extracts the alts.AuthInfo object from the given context,
// if it exists. This API should be used by gRPC server RPC handlers to get
// information about the communicating peer. For client-side, use grpc.Peer()
// CallOption.
func AuthInfoFromContext(ctx context.Context) (AuthInfo, error) {
p, ok := peer.FromContext(ctx)
if !ok {
return nil, errors.New("no Peer found in Context")
}
return AuthInfoFromPeer(p)
}
// AuthInfoFromPeer extracts the alts.AuthInfo object from the given peer, if it
// exists. This API should be used by gRPC clients after obtaining a peer object
// using the grpc.Peer() CallOption.
func AuthInfoFromPeer(p *peer.Peer) (AuthInfo, error) {
altsAuthInfo, ok := p.AuthInfo.(AuthInfo)
if !ok {
return nil, errors.New("no alts.AuthInfo found in Peer")
}
return altsAuthInfo, nil
}
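
AuthInfoFromContext is intended for server-side handlers of an ALTS-secured service. A minimal sketch follows; the handler shape and the logged fields are illustrative, and the accessor methods used are those exercised by the test file below.

package example

import (
	"context"
	"log"

	"google.golang.org/grpc/credentials/alts"
)

// handleRPC sketches how a server handler might inspect the ALTS peer; in a
// real service this logic would sit inside a generated RPC method body.
func handleRPC(ctx context.Context) error {
	authInfo, err := alts.AuthInfoFromContext(ctx)
	if err != nil {
		// No peer in the context, or the peer's AuthInfo is not ALTS.
		return err
	}
	log.Printf("peer service account: %s, record protocol: %s",
		authInfo.PeerServiceAccount(), authInfo.RecordProtocol())
	return nil
}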

View File

@ -1,139 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package alts
import (
"context"
"io"
"strings"
"testing"
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
"google.golang.org/grpc/peer"
)
func TestIsRunningOnGCP(t *testing.T) {
for _, tc := range []struct {
description string
testOS string
testReader io.Reader
out bool
}{
// Linux tests.
{"linux: not a GCP platform", "linux", strings.NewReader("not GCP"), false},
{"Linux: GCP platform (Google)", "linux", strings.NewReader("Google"), true},
{"Linux: GCP platform (Google Compute Engine)", "linux", strings.NewReader("Google Compute Engine"), true},
{"Linux: GCP platform (Google Compute Engine) with extra spaces", "linux", strings.NewReader(" Google Compute Engine "), true},
// Windows tests.
{"windows: not a GCP platform", "windows", strings.NewReader("not GCP"), false},
{"windows: GCP platform (Google)", "windows", strings.NewReader("Google"), true},
{"windows: GCP platform (Google) with extra spaces", "windows", strings.NewReader(" Google "), true},
} {
reverseFunc := setup(tc.testOS, tc.testReader)
if got, want := isRunningOnGCP(), tc.out; got != want {
t.Errorf("%v: isRunningOnGCP()=%v, want %v", tc.description, got, want)
}
reverseFunc()
}
}
func setup(testOS string, testReader io.Reader) func() {
tmpOS := runningOS
tmpReader := manufacturerReader
// Set test OS and reader function.
runningOS = testOS
manufacturerReader = func() (io.Reader, error) {
return testReader, nil
}
return func() {
runningOS = tmpOS
manufacturerReader = tmpReader
}
}
func TestAuthInfoFromContext(t *testing.T) {
ctx := context.Background()
altsAuthInfo := &fakeALTSAuthInfo{}
p := &peer.Peer{
AuthInfo: altsAuthInfo,
}
for _, tc := range []struct {
desc string
ctx context.Context
success bool
out AuthInfo
}{
{
"working case",
peer.NewContext(ctx, p),
true,
altsAuthInfo,
},
} {
authInfo, err := AuthInfoFromContext(tc.ctx)
if got, want := (err == nil), tc.success; got != want {
t.Errorf("%v: AuthInfoFromContext(_)=(err=nil)=%v, want %v", tc.desc, got, want)
}
if got, want := authInfo, tc.out; got != want {
t.Errorf("%v:, AuthInfoFromContext(_)=(%v, _), want (%v, _)", tc.desc, got, want)
}
}
}
func TestAuthInfoFromPeer(t *testing.T) {
altsAuthInfo := &fakeALTSAuthInfo{}
p := &peer.Peer{
AuthInfo: altsAuthInfo,
}
for _, tc := range []struct {
desc string
p *peer.Peer
success bool
out AuthInfo
}{
{
"working case",
p,
true,
altsAuthInfo,
},
} {
authInfo, err := AuthInfoFromPeer(tc.p)
if got, want := (err == nil), tc.success; got != want {
t.Errorf("%v: AuthInfoFromPeer(_)=(err=nil)=%v, want %v", tc.desc, got, want)
}
if got, want := authInfo, tc.out; got != want {
t.Errorf("%v:, AuthInfoFromPeer(_)=(%v, _), want (%v, _)", tc.desc, got, want)
}
}
}
type fakeALTSAuthInfo struct{}
func (*fakeALTSAuthInfo) AuthType() string { return "" }
func (*fakeALTSAuthInfo) ApplicationProtocol() string { return "" }
func (*fakeALTSAuthInfo) RecordProtocol() string { return "" }
func (*fakeALTSAuthInfo) SecurityLevel() altspb.SecurityLevel {
return altspb.SecurityLevel_SECURITY_NONE
}
func (*fakeALTSAuthInfo) PeerServiceAccount() string { return "" }
func (*fakeALTSAuthInfo) LocalServiceAccount() string { return "" }
func (*fakeALTSAuthInfo) PeerRPCVersions() *altspb.RpcProtocolVersions { return nil }

Some files were not shown because too many files have changed in this diff