build: move e2e dependencies into e2e/go.mod

Several packages are only used while running the e2e suite. These
packages are less important to update, as the they can not influence the
final executable that is part of the Ceph-CSI container-image.

By moving these dependencies out of the main Ceph-CSI go.mod, it is
easier to identify if a reported CVE affects Ceph-CSI, or only the
testing (like most of the Kubernetes CVEs).

Signed-off-by: Niels de Vos <ndevos@ibm.com>
This commit is contained in:
Niels de Vos
2025-03-04 08:57:28 +01:00
committed by mergify[bot]
parent 15da101b1b
commit bec6090996
8047 changed files with 1407827 additions and 3453 deletions

View File

@ -0,0 +1,201 @@
#vendor
vendor/
# Created by .ignore support plugin (hsz.mobi)
coverage.txt
### Go template
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
*.test
*.prof
### Windows template
# Windows image file caches
Thumbs.db
ehthumbs.db
# Folder config file
Desktop.ini
# Recycle Bin used on file shares
$RECYCLE.BIN/
# Windows Installer files
*.cab
*.msi
*.msm
*.msp
# Windows shortcuts
*.lnk
### Kate template
# Swap Files #
.*.kate-swp
.swp.*
### SublimeText template
# cache files for sublime text
*.tmlanguage.cache
*.tmPreferences.cache
*.stTheme.cache
# workspace files are user-specific
*.sublime-workspace
# project files should be checked into the repository, unless a significant
# proportion of contributors will probably not be using SublimeText
# *.sublime-project
# sftp configuration file
sftp-config.json
### Linux template
*~
# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*
# KDE directory preferences
.directory
# Linux trash folder which might appear on any partition or disk
.Trash-*
### JetBrains template
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
# User-specific stuff:
.idea
.idea/tasks.xml
.idea/dictionaries
.idea/vcs.xml
.idea/jsLibraryMappings.xml
# Sensitive or high-churn files:
.idea/dataSources.ids
.idea/dataSources.xml
.idea/dataSources.local.xml
.idea/sqlDataSources.xml
.idea/dynamic.xml
.idea/uiDesigner.xml
# Gradle:
.idea/gradle.xml
.idea/libraries
# Mongo Explorer plugin:
.idea/mongoSettings.xml
## File-based project format:
*.iws
## Plugin-specific files:
# IntelliJ
/out/
# mpeltonen/sbt-idea plugin
.idea_modules/
# JIRA plugin
atlassian-ide-plugin.xml
# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties
### Xcode template
# Xcode
#
# gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore
## Build generated
build/
DerivedData/
## Various settings
*.pbxuser
!default.pbxuser
*.mode1v3
!default.mode1v3
*.mode2v3
!default.mode2v3
*.perspectivev3
!default.perspectivev3
xcuserdata/
## Other
*.moved-aside
*.xccheckout
*.xcscmblueprint
### Eclipse template
.metadata
bin/
tmp/
*.tmp
*.bak
*.swp
*~.nib
local.properties
.settings/
.loadpath
.recommenders
# Eclipse Core
.project
# External tool builders
.externalToolBuilders/
# Locally stored "Eclipse launch configurations"
*.launch
# PyDev specific (Python IDE for Eclipse)
*.pydevproject
# CDT-specific (C/C++ Development Tooling)
.cproject
# JDT-specific (Eclipse Java Development Tools)
.classpath
# Java annotation processor (APT)
.factorypath
# PDT-specific (PHP Development Tools)
.buildpath
# sbteclipse plugin
.target
# Tern plugin
.tern-project
# TeXlipse plugin
.texlipse
# STS (Spring Tool Suite)
.springBeans
# Code Recommenders
.recommenders/

View File

@ -0,0 +1,25 @@
sudo: false
language: go
# * github.com/grpc/grpc-go still supports go1.6
# - When we drop support for go1.6 we can remove golang.org/x/net/context
# below as it is part of the Go std library since go1.7
# * github.com/prometheus/client_golang already requires at least go1.7 since
# September 2017
go:
- 1.6.x
- 1.7.x
- 1.8.x
- 1.9.x
- 1.10.x
- master
install:
- go get github.com/prometheus/client_golang/prometheus
- go get google.golang.org/grpc
- go get golang.org/x/net/context
- go get github.com/stretchr/testify
script:
- make test
after_success:
- bash <(curl -s https://codecov.io/bash)

View File

@ -0,0 +1,24 @@
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
## [Unreleased]
## [1.2.0](https://github.com/grpc-ecosystem/go-grpc-prometheus/releases/tag/v1.2.0) - 2018-06-04
### Added
* Provide metrics object as `prometheus.Collector`, for conventional metric registration.
* Support non-default/global Prometheus registry.
* Allow configuring counters with `prometheus.CounterOpts`.
### Changed
* Remove usage of deprecated `grpc.Code()`.
* Remove usage of deprecated `grpc.Errorf` and replace with `status.Errorf`.
---
This changelog was started with version `v1.2.0`, for earlier versions refer to the respective [GitHub releases](https://github.com/grpc-ecosystem/go-grpc-prometheus/releases).

View File

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@ -0,0 +1,247 @@
# Go gRPC Interceptors for Prometheus monitoring
[![Travis Build](https://travis-ci.org/grpc-ecosystem/go-grpc-prometheus.svg)](https://travis-ci.org/grpc-ecosystem/go-grpc-prometheus)
[![Go Report Card](https://goreportcard.com/badge/github.com/grpc-ecosystem/go-grpc-prometheus)](http://goreportcard.com/report/grpc-ecosystem/go-grpc-prometheus)
[![GoDoc](http://img.shields.io/badge/GoDoc-Reference-blue.svg)](https://godoc.org/github.com/grpc-ecosystem/go-grpc-prometheus)
[![SourceGraph](https://sourcegraph.com/github.com/grpc-ecosystem/go-grpc-prometheus/-/badge.svg)](https://sourcegraph.com/github.com/grpc-ecosystem/go-grpc-prometheus/?badge)
[![codecov](https://codecov.io/gh/grpc-ecosystem/go-grpc-prometheus/branch/master/graph/badge.svg)](https://codecov.io/gh/grpc-ecosystem/go-grpc-prometheus)
[![Apache 2.0 License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](LICENSE)
[Prometheus](https://prometheus.io/) monitoring for your [gRPC Go](https://github.com/grpc/grpc-go) servers and clients.
A sister implementation for [gRPC Java](https://github.com/grpc/grpc-java) (same metrics, same semantics) is in [grpc-ecosystem/java-grpc-prometheus](https://github.com/grpc-ecosystem/java-grpc-prometheus).
## Interceptors
[gRPC Go](https://github.com/grpc/grpc-go) recently acquired support for Interceptors, i.e. middleware that is executed
by a gRPC Server before the request is passed onto the user's application logic. It is a perfect way to implement
common patterns: auth, logging and... monitoring.
To use Interceptors in chains, please see [`go-grpc-middleware`](https://github.com/mwitkow/go-grpc-middleware).
## Usage
There are two types of interceptors: client-side and server-side. This package provides monitoring Interceptors for both.
### Server-side
```go
import "github.com/grpc-ecosystem/go-grpc-prometheus"
...
// Initialize your gRPC server's interceptor.
myServer := grpc.NewServer(
grpc.StreamInterceptor(grpc_prometheus.StreamServerInterceptor),
grpc.UnaryInterceptor(grpc_prometheus.UnaryServerInterceptor),
)
// Register your gRPC service implementations.
myservice.RegisterMyServiceServer(s.server, &myServiceImpl{})
// After all your registrations, make sure all of the Prometheus metrics are initialized.
grpc_prometheus.Register(myServer)
// Register Prometheus metrics handler.
http.Handle("/metrics", promhttp.Handler())
...
```
### Client-side
```go
import "github.com/grpc-ecosystem/go-grpc-prometheus"
...
clientConn, err = grpc.Dial(
address,
grpc.WithUnaryInterceptor(grpc_prometheus.UnaryClientInterceptor),
grpc.WithStreamInterceptor(grpc_prometheus.StreamClientInterceptor)
)
client = pb_testproto.NewTestServiceClient(clientConn)
resp, err := client.PingEmpty(s.ctx, &myservice.Request{Msg: "hello"})
...
```
# Metrics
## Labels
All server-side metrics start with `grpc_server` as Prometheus subsystem name. All client-side metrics start with `grpc_client`. Both of them have mirror-concepts. Similarly all methods
contain the same rich labels:
* `grpc_service` - the [gRPC service](http://www.grpc.io/docs/#defining-a-service) name, which is the combination of protobuf `package` and
the `grpc_service` section name. E.g. for `package = mwitkow.testproto` and
`service TestService` the label will be `grpc_service="mwitkow.testproto.TestService"`
* `grpc_method` - the name of the method called on the gRPC service. E.g.
`grpc_method="Ping"`
* `grpc_type` - the gRPC [type of request](http://www.grpc.io/docs/guides/concepts.html#rpc-life-cycle).
Differentiating between the two is important especially for latency measurements.
- `unary` is single request, single response RPC
- `client_stream` is a multi-request, single response RPC
- `server_stream` is a single request, multi-response RPC
- `bidi_stream` is a multi-request, multi-response RPC
Additionally for completed RPCs, the following labels are used:
* `grpc_code` - the human-readable [gRPC status code](https://github.com/grpc/grpc-go/blob/master/codes/codes.go).
The list of all statuses is to long, but here are some common ones:
- `OK` - means the RPC was successful
- `IllegalArgument` - RPC contained bad values
- `Internal` - server-side error not disclosed to the clients
## Counters
The counters and their up to date documentation is in [server_reporter.go](server_reporter.go) and [client_reporter.go](client_reporter.go)
the respective Prometheus handler (usually `/metrics`).
For the purpose of this documentation we will only discuss `grpc_server` metrics. The `grpc_client` ones contain mirror concepts.
For simplicity, let's assume we're tracking a single server-side RPC call of [`mwitkow.testproto.TestService`](examples/testproto/test.proto),
calling the method `PingList`. The call succeeds and returns 20 messages in the stream.
First, immediately after the server receives the call it will increment the
`grpc_server_started_total` and start the handling time clock (if histograms are enabled).
```jsoniq
grpc_server_started_total{grpc_method="PingList",grpc_service="mwitkow.testproto.TestService",grpc_type="server_stream"} 1
```
Then the user logic gets invoked. It receives one message from the client containing the request
(it's a `server_stream`):
```jsoniq
grpc_server_msg_received_total{grpc_method="PingList",grpc_service="mwitkow.testproto.TestService",grpc_type="server_stream"} 1
```
The user logic may return an error, or send multiple messages back to the client. In this case, on
each of the 20 messages sent back, a counter will be incremented:
```jsoniq
grpc_server_msg_sent_total{grpc_method="PingList",grpc_service="mwitkow.testproto.TestService",grpc_type="server_stream"} 20
```
After the call completes, its status (`OK` or other [gRPC status code](https://github.com/grpc/grpc-go/blob/master/codes/codes.go))
and the relevant call labels increment the `grpc_server_handled_total` counter.
```jsoniq
grpc_server_handled_total{grpc_code="OK",grpc_method="PingList",grpc_service="mwitkow.testproto.TestService",grpc_type="server_stream"} 1
```
## Histograms
[Prometheus histograms](https://prometheus.io/docs/concepts/metric_types/#histogram) are a great way
to measure latency distributions of your RPCs. However, since it is bad practice to have metrics
of [high cardinality](https://prometheus.io/docs/practices/instrumentation/#do-not-overuse-labels)
the latency monitoring metrics are disabled by default. To enable them please call the following
in your server initialization code:
```jsoniq
grpc_prometheus.EnableHandlingTimeHistogram()
```
After the call completes, its handling time will be recorded in a [Prometheus histogram](https://prometheus.io/docs/concepts/metric_types/#histogram)
variable `grpc_server_handling_seconds`. The histogram variable contains three sub-metrics:
* `grpc_server_handling_seconds_count` - the count of all completed RPCs by status and method
* `grpc_server_handling_seconds_sum` - cumulative time of RPCs by status and method, useful for
calculating average handling times
* `grpc_server_handling_seconds_bucket` - contains the counts of RPCs by status and method in respective
handling-time buckets. These buckets can be used by Prometheus to estimate SLAs (see [here](https://prometheus.io/docs/practices/histograms/))
The counter values will look as follows:
```jsoniq
grpc_server_handling_seconds_bucket{grpc_code="OK",grpc_method="PingList",grpc_service="mwitkow.testproto.TestService",grpc_type="server_stream",le="0.005"} 1
grpc_server_handling_seconds_bucket{grpc_code="OK",grpc_method="PingList",grpc_service="mwitkow.testproto.TestService",grpc_type="server_stream",le="0.01"} 1
grpc_server_handling_seconds_bucket{grpc_code="OK",grpc_method="PingList",grpc_service="mwitkow.testproto.TestService",grpc_type="server_stream",le="0.025"} 1
grpc_server_handling_seconds_bucket{grpc_code="OK",grpc_method="PingList",grpc_service="mwitkow.testproto.TestService",grpc_type="server_stream",le="0.05"} 1
grpc_server_handling_seconds_bucket{grpc_code="OK",grpc_method="PingList",grpc_service="mwitkow.testproto.TestService",grpc_type="server_stream",le="0.1"} 1
grpc_server_handling_seconds_bucket{grpc_code="OK",grpc_method="PingList",grpc_service="mwitkow.testproto.TestService",grpc_type="server_stream",le="0.25"} 1
grpc_server_handling_seconds_bucket{grpc_code="OK",grpc_method="PingList",grpc_service="mwitkow.testproto.TestService",grpc_type="server_stream",le="0.5"} 1
grpc_server_handling_seconds_bucket{grpc_code="OK",grpc_method="PingList",grpc_service="mwitkow.testproto.TestService",grpc_type="server_stream",le="1"} 1
grpc_server_handling_seconds_bucket{grpc_code="OK",grpc_method="PingList",grpc_service="mwitkow.testproto.TestService",grpc_type="server_stream",le="2.5"} 1
grpc_server_handling_seconds_bucket{grpc_code="OK",grpc_method="PingList",grpc_service="mwitkow.testproto.TestService",grpc_type="server_stream",le="5"} 1
grpc_server_handling_seconds_bucket{grpc_code="OK",grpc_method="PingList",grpc_service="mwitkow.testproto.TestService",grpc_type="server_stream",le="10"} 1
grpc_server_handling_seconds_bucket{grpc_code="OK",grpc_method="PingList",grpc_service="mwitkow.testproto.TestService",grpc_type="server_stream",le="+Inf"} 1
grpc_server_handling_seconds_sum{grpc_code="OK",grpc_method="PingList",grpc_service="mwitkow.testproto.TestService",grpc_type="server_stream"} 0.0003866430000000001
grpc_server_handling_seconds_count{grpc_code="OK",grpc_method="PingList",grpc_service="mwitkow.testproto.TestService",grpc_type="server_stream"} 1
```
## Useful query examples
Prometheus philosophy is to provide raw metrics to the monitoring system, and
let the aggregations be handled there. The verbosity of above metrics make it possible to have that
flexibility. Here's a couple of useful monitoring queries:
### request inbound rate
```jsoniq
sum(rate(grpc_server_started_total{job="foo"}[1m])) by (grpc_service)
```
For `job="foo"` (common label to differentiate between Prometheus monitoring targets), calculate the
rate of requests per second (1 minute window) for each gRPC `grpc_service` that the job has. Please note
how the `grpc_method` is being omitted here: all methods of a given gRPC service will be summed together.
### unary request error rate
```jsoniq
sum(rate(grpc_server_handled_total{job="foo",grpc_type="unary",grpc_code!="OK"}[1m])) by (grpc_service)
```
For `job="foo"`, calculate the per-`grpc_service` rate of `unary` (1:1) RPCs that failed, i.e. the
ones that didn't finish with `OK` code.
### unary request error percentage
```jsoniq
sum(rate(grpc_server_handled_total{job="foo",grpc_type="unary",grpc_code!="OK"}[1m])) by (grpc_service)
/
sum(rate(grpc_server_started_total{job="foo",grpc_type="unary"}[1m])) by (grpc_service)
* 100.0
```
For `job="foo"`, calculate the percentage of failed requests by service. It's easy to notice that
this is a combination of the two above examples. This is an example of a query you would like to
[alert on](https://prometheus.io/docs/alerting/rules/) in your system for SLA violations, e.g.
"no more than 1% requests should fail".
### average response stream size
```jsoniq
sum(rate(grpc_server_msg_sent_total{job="foo",grpc_type="server_stream"}[10m])) by (grpc_service)
/
sum(rate(grpc_server_started_total{job="foo",grpc_type="server_stream"}[10m])) by (grpc_service)
```
For `job="foo"` what is the `grpc_service`-wide `10m` average of messages returned for all `
server_stream` RPCs. This allows you to track the stream sizes returned by your system, e.g. allows
you to track when clients started to send "wide" queries that ret
Note the divisor is the number of started RPCs, in order to account for in-flight requests.
### 99%-tile latency of unary requests
```jsoniq
histogram_quantile(0.99,
sum(rate(grpc_server_handling_seconds_bucket{job="foo",grpc_type="unary"}[5m])) by (grpc_service,le)
)
```
For `job="foo"`, returns an 99%-tile [quantile estimation](https://prometheus.io/docs/practices/histograms/#quantiles)
of the handling time of RPCs per service. Please note the `5m` rate, this means that the quantile
estimation will take samples in a rolling `5m` window. When combined with other quantiles
(e.g. 50%, 90%), this query gives you tremendous insight into the responsiveness of your system
(e.g. impact of caching).
### percentage of slow unary queries (>250ms)
```jsoniq
100.0 - (
sum(rate(grpc_server_handling_seconds_bucket{job="foo",grpc_type="unary",le="0.25"}[5m])) by (grpc_service)
/
sum(rate(grpc_server_handling_seconds_count{job="foo",grpc_type="unary"}[5m])) by (grpc_service)
) * 100.0
```
For `job="foo"` calculate the by-`grpc_service` fraction of slow requests that took longer than `0.25`
seconds. This query is relatively complex, since the Prometheus aggregations use `le` (less or equal)
buckets, meaning that counting "fast" requests fractions is easier. However, simple maths helps.
This is an example of a query you would like to alert on in your system for SLA violations,
e.g. "less than 1% of requests are slower than 250ms".
## Status
This code has been used since August 2015 as the basis for monitoring of *production* gRPC micro services at [Improbable](https://improbable.io).
## License
`go-grpc-prometheus` is released under the Apache 2.0 license. See the [LICENSE](LICENSE) file for details.

View File

@ -0,0 +1,39 @@
// Copyright 2016 Michal Witkowski. All Rights Reserved.
// See LICENSE for licensing terms.
// gRPC Prometheus monitoring interceptors for client-side gRPC.
package grpc_prometheus
import (
prom "github.com/prometheus/client_golang/prometheus"
)
var (
// DefaultClientMetrics is the default instance of ClientMetrics. It is
// intended to be used in conjunction the default Prometheus metrics
// registry.
DefaultClientMetrics = NewClientMetrics()
// UnaryClientInterceptor is a gRPC client-side interceptor that provides Prometheus monitoring for Unary RPCs.
UnaryClientInterceptor = DefaultClientMetrics.UnaryClientInterceptor()
// StreamClientInterceptor is a gRPC client-side interceptor that provides Prometheus monitoring for Streaming RPCs.
StreamClientInterceptor = DefaultClientMetrics.StreamClientInterceptor()
)
func init() {
prom.MustRegister(DefaultClientMetrics.clientStartedCounter)
prom.MustRegister(DefaultClientMetrics.clientHandledCounter)
prom.MustRegister(DefaultClientMetrics.clientStreamMsgReceived)
prom.MustRegister(DefaultClientMetrics.clientStreamMsgSent)
}
// EnableClientHandlingTimeHistogram turns on recording of handling time of
// RPCs. Histogram metrics can be very expensive for Prometheus to retain and
// query. This function acts on the DefaultClientMetrics variable and the
// default Prometheus metrics registry.
func EnableClientHandlingTimeHistogram(opts ...HistogramOption) {
DefaultClientMetrics.EnableClientHandlingTimeHistogram(opts...)
prom.Register(DefaultClientMetrics.clientHandledHistogram)
}

View File

@ -0,0 +1,170 @@
package grpc_prometheus
import (
"io"
prom "github.com/prometheus/client_golang/prometheus"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// ClientMetrics represents a collection of metrics to be registered on a
// Prometheus metrics registry for a gRPC client.
type ClientMetrics struct {
clientStartedCounter *prom.CounterVec
clientHandledCounter *prom.CounterVec
clientStreamMsgReceived *prom.CounterVec
clientStreamMsgSent *prom.CounterVec
clientHandledHistogramEnabled bool
clientHandledHistogramOpts prom.HistogramOpts
clientHandledHistogram *prom.HistogramVec
}
// NewClientMetrics returns a ClientMetrics object. Use a new instance of
// ClientMetrics when not using the default Prometheus metrics registry, for
// example when wanting to control which metrics are added to a registry as
// opposed to automatically adding metrics via init functions.
func NewClientMetrics(counterOpts ...CounterOption) *ClientMetrics {
opts := counterOptions(counterOpts)
return &ClientMetrics{
clientStartedCounter: prom.NewCounterVec(
opts.apply(prom.CounterOpts{
Name: "grpc_client_started_total",
Help: "Total number of RPCs started on the client.",
}), []string{"grpc_type", "grpc_service", "grpc_method"}),
clientHandledCounter: prom.NewCounterVec(
opts.apply(prom.CounterOpts{
Name: "grpc_client_handled_total",
Help: "Total number of RPCs completed by the client, regardless of success or failure.",
}), []string{"grpc_type", "grpc_service", "grpc_method", "grpc_code"}),
clientStreamMsgReceived: prom.NewCounterVec(
opts.apply(prom.CounterOpts{
Name: "grpc_client_msg_received_total",
Help: "Total number of RPC stream messages received by the client.",
}), []string{"grpc_type", "grpc_service", "grpc_method"}),
clientStreamMsgSent: prom.NewCounterVec(
opts.apply(prom.CounterOpts{
Name: "grpc_client_msg_sent_total",
Help: "Total number of gRPC stream messages sent by the client.",
}), []string{"grpc_type", "grpc_service", "grpc_method"}),
clientHandledHistogramEnabled: false,
clientHandledHistogramOpts: prom.HistogramOpts{
Name: "grpc_client_handling_seconds",
Help: "Histogram of response latency (seconds) of the gRPC until it is finished by the application.",
Buckets: prom.DefBuckets,
},
clientHandledHistogram: nil,
}
}
// Describe sends the super-set of all possible descriptors of metrics
// collected by this Collector to the provided channel and returns once
// the last descriptor has been sent.
func (m *ClientMetrics) Describe(ch chan<- *prom.Desc) {
m.clientStartedCounter.Describe(ch)
m.clientHandledCounter.Describe(ch)
m.clientStreamMsgReceived.Describe(ch)
m.clientStreamMsgSent.Describe(ch)
if m.clientHandledHistogramEnabled {
m.clientHandledHistogram.Describe(ch)
}
}
// Collect is called by the Prometheus registry when collecting
// metrics. The implementation sends each collected metric via the
// provided channel and returns once the last metric has been sent.
func (m *ClientMetrics) Collect(ch chan<- prom.Metric) {
m.clientStartedCounter.Collect(ch)
m.clientHandledCounter.Collect(ch)
m.clientStreamMsgReceived.Collect(ch)
m.clientStreamMsgSent.Collect(ch)
if m.clientHandledHistogramEnabled {
m.clientHandledHistogram.Collect(ch)
}
}
// EnableClientHandlingTimeHistogram turns on recording of handling time of RPCs.
// Histogram metrics can be very expensive for Prometheus to retain and query.
func (m *ClientMetrics) EnableClientHandlingTimeHistogram(opts ...HistogramOption) {
for _, o := range opts {
o(&m.clientHandledHistogramOpts)
}
if !m.clientHandledHistogramEnabled {
m.clientHandledHistogram = prom.NewHistogramVec(
m.clientHandledHistogramOpts,
[]string{"grpc_type", "grpc_service", "grpc_method"},
)
}
m.clientHandledHistogramEnabled = true
}
// UnaryClientInterceptor is a gRPC client-side interceptor that provides Prometheus monitoring for Unary RPCs.
func (m *ClientMetrics) UnaryClientInterceptor() func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
monitor := newClientReporter(m, Unary, method)
monitor.SentMessage()
err := invoker(ctx, method, req, reply, cc, opts...)
if err != nil {
monitor.ReceivedMessage()
}
st, _ := status.FromError(err)
monitor.Handled(st.Code())
return err
}
}
// StreamClientInterceptor is a gRPC client-side interceptor that provides Prometheus monitoring for Streaming RPCs.
func (m *ClientMetrics) StreamClientInterceptor() func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
monitor := newClientReporter(m, clientStreamType(desc), method)
clientStream, err := streamer(ctx, desc, cc, method, opts...)
if err != nil {
st, _ := status.FromError(err)
monitor.Handled(st.Code())
return nil, err
}
return &monitoredClientStream{clientStream, monitor}, nil
}
}
func clientStreamType(desc *grpc.StreamDesc) grpcType {
if desc.ClientStreams && !desc.ServerStreams {
return ClientStream
} else if !desc.ClientStreams && desc.ServerStreams {
return ServerStream
}
return BidiStream
}
// monitoredClientStream wraps grpc.ClientStream allowing each Sent/Recv of message to increment counters.
type monitoredClientStream struct {
grpc.ClientStream
monitor *clientReporter
}
func (s *monitoredClientStream) SendMsg(m interface{}) error {
err := s.ClientStream.SendMsg(m)
if err == nil {
s.monitor.SentMessage()
}
return err
}
func (s *monitoredClientStream) RecvMsg(m interface{}) error {
err := s.ClientStream.RecvMsg(m)
if err == nil {
s.monitor.ReceivedMessage()
} else if err == io.EOF {
s.monitor.Handled(codes.OK)
} else {
st, _ := status.FromError(err)
s.monitor.Handled(st.Code())
}
return err
}

View File

@ -0,0 +1,46 @@
// Copyright 2016 Michal Witkowski. All Rights Reserved.
// See LICENSE for licensing terms.
package grpc_prometheus
import (
"time"
"google.golang.org/grpc/codes"
)
type clientReporter struct {
metrics *ClientMetrics
rpcType grpcType
serviceName string
methodName string
startTime time.Time
}
func newClientReporter(m *ClientMetrics, rpcType grpcType, fullMethod string) *clientReporter {
r := &clientReporter{
metrics: m,
rpcType: rpcType,
}
if r.metrics.clientHandledHistogramEnabled {
r.startTime = time.Now()
}
r.serviceName, r.methodName = splitMethodName(fullMethod)
r.metrics.clientStartedCounter.WithLabelValues(string(r.rpcType), r.serviceName, r.methodName).Inc()
return r
}
func (r *clientReporter) ReceivedMessage() {
r.metrics.clientStreamMsgReceived.WithLabelValues(string(r.rpcType), r.serviceName, r.methodName).Inc()
}
func (r *clientReporter) SentMessage() {
r.metrics.clientStreamMsgSent.WithLabelValues(string(r.rpcType), r.serviceName, r.methodName).Inc()
}
func (r *clientReporter) Handled(code codes.Code) {
r.metrics.clientHandledCounter.WithLabelValues(string(r.rpcType), r.serviceName, r.methodName, code.String()).Inc()
if r.metrics.clientHandledHistogramEnabled {
r.metrics.clientHandledHistogram.WithLabelValues(string(r.rpcType), r.serviceName, r.methodName).Observe(time.Since(r.startTime).Seconds())
}
}

View File

@ -0,0 +1,16 @@
SHELL="/bin/bash"
GOFILES_NOVENDOR = $(shell go list ./... | grep -v /vendor/)
all: vet fmt test
fmt:
go fmt $(GOFILES_NOVENDOR)
vet:
go vet $(GOFILES_NOVENDOR)
test: vet
./scripts/test_all.sh
.PHONY: all vet test

View File

@ -0,0 +1,41 @@
package grpc_prometheus
import (
prom "github.com/prometheus/client_golang/prometheus"
)
// A CounterOption lets you add options to Counter metrics using With* funcs.
type CounterOption func(*prom.CounterOpts)
type counterOptions []CounterOption
func (co counterOptions) apply(o prom.CounterOpts) prom.CounterOpts {
for _, f := range co {
f(&o)
}
return o
}
// WithConstLabels allows you to add ConstLabels to Counter metrics.
func WithConstLabels(labels prom.Labels) CounterOption {
return func(o *prom.CounterOpts) {
o.ConstLabels = labels
}
}
// A HistogramOption lets you add options to Histogram metrics using With*
// funcs.
type HistogramOption func(*prom.HistogramOpts)
// WithHistogramBuckets allows you to specify custom bucket ranges for histograms if EnableHandlingTimeHistogram is on.
func WithHistogramBuckets(buckets []float64) HistogramOption {
return func(o *prom.HistogramOpts) { o.Buckets = buckets }
}
// WithHistogramConstLabels allows you to add custom ConstLabels to
// histograms metrics.
func WithHistogramConstLabels(labels prom.Labels) HistogramOption {
return func(o *prom.HistogramOpts) {
o.ConstLabels = labels
}
}

View File

@ -0,0 +1,48 @@
// Copyright 2016 Michal Witkowski. All Rights Reserved.
// See LICENSE for licensing terms.
// gRPC Prometheus monitoring interceptors for server-side gRPC.
package grpc_prometheus
import (
prom "github.com/prometheus/client_golang/prometheus"
"google.golang.org/grpc"
)
var (
// DefaultServerMetrics is the default instance of ServerMetrics. It is
// intended to be used in conjunction the default Prometheus metrics
// registry.
DefaultServerMetrics = NewServerMetrics()
// UnaryServerInterceptor is a gRPC server-side interceptor that provides Prometheus monitoring for Unary RPCs.
UnaryServerInterceptor = DefaultServerMetrics.UnaryServerInterceptor()
// StreamServerInterceptor is a gRPC server-side interceptor that provides Prometheus monitoring for Streaming RPCs.
StreamServerInterceptor = DefaultServerMetrics.StreamServerInterceptor()
)
func init() {
prom.MustRegister(DefaultServerMetrics.serverStartedCounter)
prom.MustRegister(DefaultServerMetrics.serverHandledCounter)
prom.MustRegister(DefaultServerMetrics.serverStreamMsgReceived)
prom.MustRegister(DefaultServerMetrics.serverStreamMsgSent)
}
// Register takes a gRPC server and pre-initializes all counters to 0. This
// allows for easier monitoring in Prometheus (no missing metrics), and should
// be called *after* all services have been registered with the server. This
// function acts on the DefaultServerMetrics variable.
func Register(server *grpc.Server) {
DefaultServerMetrics.InitializeMetrics(server)
}
// EnableHandlingTimeHistogram turns on recording of handling time
// of RPCs. Histogram metrics can be very expensive for Prometheus
// to retain and query. This function acts on the DefaultServerMetrics
// variable and the default Prometheus metrics registry.
func EnableHandlingTimeHistogram(opts ...HistogramOption) {
DefaultServerMetrics.EnableHandlingTimeHistogram(opts...)
prom.Register(DefaultServerMetrics.serverHandledHistogram)
}

View File

@ -0,0 +1,185 @@
package grpc_prometheus
import (
prom "github.com/prometheus/client_golang/prometheus"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/status"
)
// ServerMetrics represents a collection of metrics to be registered on a
// Prometheus metrics registry for a gRPC server.
type ServerMetrics struct {
serverStartedCounter *prom.CounterVec
serverHandledCounter *prom.CounterVec
serverStreamMsgReceived *prom.CounterVec
serverStreamMsgSent *prom.CounterVec
serverHandledHistogramEnabled bool
serverHandledHistogramOpts prom.HistogramOpts
serverHandledHistogram *prom.HistogramVec
}
// NewServerMetrics returns a ServerMetrics object. Use a new instance of
// ServerMetrics when not using the default Prometheus metrics registry, for
// example when wanting to control which metrics are added to a registry as
// opposed to automatically adding metrics via init functions.
func NewServerMetrics(counterOpts ...CounterOption) *ServerMetrics {
opts := counterOptions(counterOpts)
return &ServerMetrics{
serverStartedCounter: prom.NewCounterVec(
opts.apply(prom.CounterOpts{
Name: "grpc_server_started_total",
Help: "Total number of RPCs started on the server.",
}), []string{"grpc_type", "grpc_service", "grpc_method"}),
serverHandledCounter: prom.NewCounterVec(
opts.apply(prom.CounterOpts{
Name: "grpc_server_handled_total",
Help: "Total number of RPCs completed on the server, regardless of success or failure.",
}), []string{"grpc_type", "grpc_service", "grpc_method", "grpc_code"}),
serverStreamMsgReceived: prom.NewCounterVec(
opts.apply(prom.CounterOpts{
Name: "grpc_server_msg_received_total",
Help: "Total number of RPC stream messages received on the server.",
}), []string{"grpc_type", "grpc_service", "grpc_method"}),
serverStreamMsgSent: prom.NewCounterVec(
opts.apply(prom.CounterOpts{
Name: "grpc_server_msg_sent_total",
Help: "Total number of gRPC stream messages sent by the server.",
}), []string{"grpc_type", "grpc_service", "grpc_method"}),
serverHandledHistogramEnabled: false,
serverHandledHistogramOpts: prom.HistogramOpts{
Name: "grpc_server_handling_seconds",
Help: "Histogram of response latency (seconds) of gRPC that had been application-level handled by the server.",
Buckets: prom.DefBuckets,
},
serverHandledHistogram: nil,
}
}
// EnableHandlingTimeHistogram enables histograms being registered when
// registering the ServerMetrics on a Prometheus registry. Histograms can be
// expensive on Prometheus servers. It takes options to configure histogram
// options such as the defined buckets.
func (m *ServerMetrics) EnableHandlingTimeHistogram(opts ...HistogramOption) {
for _, o := range opts {
o(&m.serverHandledHistogramOpts)
}
if !m.serverHandledHistogramEnabled {
m.serverHandledHistogram = prom.NewHistogramVec(
m.serverHandledHistogramOpts,
[]string{"grpc_type", "grpc_service", "grpc_method"},
)
}
m.serverHandledHistogramEnabled = true
}
// Describe sends the super-set of all possible descriptors of metrics
// collected by this Collector to the provided channel and returns once
// the last descriptor has been sent.
func (m *ServerMetrics) Describe(ch chan<- *prom.Desc) {
m.serverStartedCounter.Describe(ch)
m.serverHandledCounter.Describe(ch)
m.serverStreamMsgReceived.Describe(ch)
m.serverStreamMsgSent.Describe(ch)
if m.serverHandledHistogramEnabled {
m.serverHandledHistogram.Describe(ch)
}
}
// Collect is called by the Prometheus registry when collecting
// metrics. The implementation sends each collected metric via the
// provided channel and returns once the last metric has been sent.
func (m *ServerMetrics) Collect(ch chan<- prom.Metric) {
m.serverStartedCounter.Collect(ch)
m.serverHandledCounter.Collect(ch)
m.serverStreamMsgReceived.Collect(ch)
m.serverStreamMsgSent.Collect(ch)
if m.serverHandledHistogramEnabled {
m.serverHandledHistogram.Collect(ch)
}
}
// UnaryServerInterceptor is a gRPC server-side interceptor that provides Prometheus monitoring for Unary RPCs.
func (m *ServerMetrics) UnaryServerInterceptor() func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
monitor := newServerReporter(m, Unary, info.FullMethod)
monitor.ReceivedMessage()
resp, err := handler(ctx, req)
st, _ := status.FromError(err)
monitor.Handled(st.Code())
if err == nil {
monitor.SentMessage()
}
return resp, err
}
}
// StreamServerInterceptor is a gRPC server-side interceptor that provides Prometheus monitoring for Streaming RPCs.
func (m *ServerMetrics) StreamServerInterceptor() func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
monitor := newServerReporter(m, streamRPCType(info), info.FullMethod)
err := handler(srv, &monitoredServerStream{ss, monitor})
st, _ := status.FromError(err)
monitor.Handled(st.Code())
return err
}
}
// InitializeMetrics initializes all metrics, with their appropriate null
// value, for all gRPC methods registered on a gRPC server. This is useful, to
// ensure that all metrics exist when collecting and querying.
func (m *ServerMetrics) InitializeMetrics(server *grpc.Server) {
serviceInfo := server.GetServiceInfo()
for serviceName, info := range serviceInfo {
for _, mInfo := range info.Methods {
preRegisterMethod(m, serviceName, &mInfo)
}
}
}
func streamRPCType(info *grpc.StreamServerInfo) grpcType {
if info.IsClientStream && !info.IsServerStream {
return ClientStream
} else if !info.IsClientStream && info.IsServerStream {
return ServerStream
}
return BidiStream
}
// monitoredStream wraps grpc.ServerStream allowing each Sent/Recv of message to increment counters.
type monitoredServerStream struct {
grpc.ServerStream
monitor *serverReporter
}
func (s *monitoredServerStream) SendMsg(m interface{}) error {
err := s.ServerStream.SendMsg(m)
if err == nil {
s.monitor.SentMessage()
}
return err
}
func (s *monitoredServerStream) RecvMsg(m interface{}) error {
err := s.ServerStream.RecvMsg(m)
if err == nil {
s.monitor.ReceivedMessage()
}
return err
}
// preRegisterMethod is invoked on Register of a Server, allowing all gRPC services labels to be pre-populated.
func preRegisterMethod(metrics *ServerMetrics, serviceName string, mInfo *grpc.MethodInfo) {
methodName := mInfo.Name
methodType := string(typeFromMethodInfo(mInfo))
// These are just references (no increments), as just referencing will create the labels but not set values.
metrics.serverStartedCounter.GetMetricWithLabelValues(methodType, serviceName, methodName)
metrics.serverStreamMsgReceived.GetMetricWithLabelValues(methodType, serviceName, methodName)
metrics.serverStreamMsgSent.GetMetricWithLabelValues(methodType, serviceName, methodName)
if metrics.serverHandledHistogramEnabled {
metrics.serverHandledHistogram.GetMetricWithLabelValues(methodType, serviceName, methodName)
}
for _, code := range allCodes {
metrics.serverHandledCounter.GetMetricWithLabelValues(methodType, serviceName, methodName, code.String())
}
}

View File

@ -0,0 +1,46 @@
// Copyright 2016 Michal Witkowski. All Rights Reserved.
// See LICENSE for licensing terms.
package grpc_prometheus
import (
"time"
"google.golang.org/grpc/codes"
)
type serverReporter struct {
metrics *ServerMetrics
rpcType grpcType
serviceName string
methodName string
startTime time.Time
}
func newServerReporter(m *ServerMetrics, rpcType grpcType, fullMethod string) *serverReporter {
r := &serverReporter{
metrics: m,
rpcType: rpcType,
}
if r.metrics.serverHandledHistogramEnabled {
r.startTime = time.Now()
}
r.serviceName, r.methodName = splitMethodName(fullMethod)
r.metrics.serverStartedCounter.WithLabelValues(string(r.rpcType), r.serviceName, r.methodName).Inc()
return r
}
func (r *serverReporter) ReceivedMessage() {
r.metrics.serverStreamMsgReceived.WithLabelValues(string(r.rpcType), r.serviceName, r.methodName).Inc()
}
func (r *serverReporter) SentMessage() {
r.metrics.serverStreamMsgSent.WithLabelValues(string(r.rpcType), r.serviceName, r.methodName).Inc()
}
func (r *serverReporter) Handled(code codes.Code) {
r.metrics.serverHandledCounter.WithLabelValues(string(r.rpcType), r.serviceName, r.methodName, code.String()).Inc()
if r.metrics.serverHandledHistogramEnabled {
r.metrics.serverHandledHistogram.WithLabelValues(string(r.rpcType), r.serviceName, r.methodName).Observe(time.Since(r.startTime).Seconds())
}
}

View File

@ -0,0 +1,50 @@
// Copyright 2016 Michal Witkowski. All Rights Reserved.
// See LICENSE for licensing terms.
package grpc_prometheus
import (
"strings"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
)
type grpcType string
const (
Unary grpcType = "unary"
ClientStream grpcType = "client_stream"
ServerStream grpcType = "server_stream"
BidiStream grpcType = "bidi_stream"
)
var (
allCodes = []codes.Code{
codes.OK, codes.Canceled, codes.Unknown, codes.InvalidArgument, codes.DeadlineExceeded, codes.NotFound,
codes.AlreadyExists, codes.PermissionDenied, codes.Unauthenticated, codes.ResourceExhausted,
codes.FailedPrecondition, codes.Aborted, codes.OutOfRange, codes.Unimplemented, codes.Internal,
codes.Unavailable, codes.DataLoss,
}
)
func splitMethodName(fullMethodName string) (string, string) {
fullMethodName = strings.TrimPrefix(fullMethodName, "/") // remove leading slash
if i := strings.Index(fullMethodName, "/"); i >= 0 {
return fullMethodName[:i], fullMethodName[i+1:]
}
return "unknown", "unknown"
}
func typeFromMethodInfo(mInfo *grpc.MethodInfo) grpcType {
if !mInfo.IsClientStream && !mInfo.IsServerStream {
return Unary
}
if mInfo.IsClientStream && !mInfo.IsServerStream {
return ClientStream
}
if !mInfo.IsClientStream && mInfo.IsServerStream {
return ServerStream
}
return BidiStream
}

View File

@ -0,0 +1,27 @@
Copyright (c) 2015, Gengo, Inc.
All rights reserved.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of Gengo, Inc. nor the names of its
contributors may be used to endorse or promote products derived from this
software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -0,0 +1,35 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
package(default_visibility = ["//visibility:public"])
go_library(
name = "httprule",
srcs = [
"compile.go",
"parse.go",
"types.go",
],
importpath = "github.com/grpc-ecosystem/grpc-gateway/v2/internal/httprule",
deps = ["//utilities"],
)
go_test(
name = "httprule_test",
size = "small",
srcs = [
"compile_test.go",
"parse_test.go",
"types_test.go",
],
embed = [":httprule"],
deps = [
"//utilities",
"@org_golang_google_grpc//grpclog",
],
)
alias(
name = "go_default_library",
actual = ":httprule",
visibility = ["//:__subpackages__"],
)

View File

@ -0,0 +1,121 @@
package httprule
import (
"github.com/grpc-ecosystem/grpc-gateway/v2/utilities"
)
const (
opcodeVersion = 1
)
// Template is a compiled representation of path templates.
type Template struct {
// Version is the version number of the format.
Version int
// OpCodes is a sequence of operations.
OpCodes []int
// Pool is a constant pool
Pool []string
// Verb is a VERB part in the template.
Verb string
// Fields is a list of field paths bound in this template.
Fields []string
// Original template (example: /v1/a_bit_of_everything)
Template string
}
// Compiler compiles utilities representation of path templates into marshallable operations.
// They can be unmarshalled by runtime.NewPattern.
type Compiler interface {
Compile() Template
}
type op struct {
// code is the opcode of the operation
code utilities.OpCode
// str is a string operand of the code.
// num is ignored if str is not empty.
str string
// num is a numeric operand of the code.
num int
}
func (w wildcard) compile() []op {
return []op{
{code: utilities.OpPush},
}
}
func (w deepWildcard) compile() []op {
return []op{
{code: utilities.OpPushM},
}
}
func (l literal) compile() []op {
return []op{
{
code: utilities.OpLitPush,
str: string(l),
},
}
}
func (v variable) compile() []op {
var ops []op
for _, s := range v.segments {
ops = append(ops, s.compile()...)
}
ops = append(ops, op{
code: utilities.OpConcatN,
num: len(v.segments),
}, op{
code: utilities.OpCapture,
str: v.path,
})
return ops
}
func (t template) Compile() Template {
var rawOps []op
for _, s := range t.segments {
rawOps = append(rawOps, s.compile()...)
}
var (
ops []int
pool []string
fields []string
)
consts := make(map[string]int)
for _, op := range rawOps {
ops = append(ops, int(op.code))
if op.str == "" {
ops = append(ops, op.num)
} else {
// eof segment literal represents the "/" path pattern
if op.str == eof {
op.str = ""
}
if _, ok := consts[op.str]; !ok {
consts[op.str] = len(pool)
pool = append(pool, op.str)
}
ops = append(ops, consts[op.str])
}
if op.code == utilities.OpCapture {
fields = append(fields, op.str)
}
}
return Template{
Version: opcodeVersion,
OpCodes: ops,
Pool: pool,
Verb: t.verb,
Fields: fields,
Template: t.template,
}
}

View File

@ -0,0 +1,11 @@
//go:build gofuzz
// +build gofuzz
package httprule
func Fuzz(data []byte) int {
if _, err := Parse(string(data)); err != nil {
return 0
}
return 0
}

View File

@ -0,0 +1,368 @@
package httprule
import (
"errors"
"fmt"
"strings"
)
// InvalidTemplateError indicates that the path template is not valid.
type InvalidTemplateError struct {
tmpl string
msg string
}
func (e InvalidTemplateError) Error() string {
return fmt.Sprintf("%s: %s", e.msg, e.tmpl)
}
// Parse parses the string representation of path template
func Parse(tmpl string) (Compiler, error) {
if !strings.HasPrefix(tmpl, "/") {
return template{}, InvalidTemplateError{tmpl: tmpl, msg: "no leading /"}
}
tokens, verb := tokenize(tmpl[1:])
p := parser{tokens: tokens}
segs, err := p.topLevelSegments()
if err != nil {
return template{}, InvalidTemplateError{tmpl: tmpl, msg: err.Error()}
}
return template{
segments: segs,
verb: verb,
template: tmpl,
}, nil
}
func tokenize(path string) (tokens []string, verb string) {
if path == "" {
return []string{eof}, ""
}
const (
init = iota
field
nested
)
st := init
for path != "" {
var idx int
switch st {
case init:
idx = strings.IndexAny(path, "/{")
case field:
idx = strings.IndexAny(path, ".=}")
case nested:
idx = strings.IndexAny(path, "/}")
}
if idx < 0 {
tokens = append(tokens, path)
break
}
switch r := path[idx]; r {
case '/', '.':
case '{':
st = field
case '=':
st = nested
case '}':
st = init
}
if idx == 0 {
tokens = append(tokens, path[idx:idx+1])
} else {
tokens = append(tokens, path[:idx], path[idx:idx+1])
}
path = path[idx+1:]
}
l := len(tokens)
// See
// https://github.com/grpc-ecosystem/grpc-gateway/pull/1947#issuecomment-774523693 ;
// although normal and backwards-compat logic here is to use the last index
// of a colon, if the final segment is a variable followed by a colon, the
// part following the colon must be a verb. Hence if the previous token is
// an end var marker, we switch the index we're looking for to Index instead
// of LastIndex, so that we correctly grab the remaining part of the path as
// the verb.
var penultimateTokenIsEndVar bool
switch l {
case 0, 1:
// Not enough to be variable so skip this logic and don't result in an
// invalid index
default:
penultimateTokenIsEndVar = tokens[l-2] == "}"
}
t := tokens[l-1]
var idx int
if penultimateTokenIsEndVar {
idx = strings.Index(t, ":")
} else {
idx = strings.LastIndex(t, ":")
}
if idx == 0 {
tokens, verb = tokens[:l-1], t[1:]
} else if idx > 0 {
tokens[l-1], verb = t[:idx], t[idx+1:]
}
tokens = append(tokens, eof)
return tokens, verb
}
// parser is a parser of the template syntax defined in github.com/googleapis/googleapis/google/api/http.proto.
type parser struct {
tokens []string
accepted []string
}
// topLevelSegments is the target of this parser.
func (p *parser) topLevelSegments() ([]segment, error) {
if _, err := p.accept(typeEOF); err == nil {
p.tokens = p.tokens[:0]
return []segment{literal(eof)}, nil
}
segs, err := p.segments()
if err != nil {
return nil, err
}
if _, err := p.accept(typeEOF); err != nil {
return nil, fmt.Errorf("unexpected token %q after segments %q", p.tokens[0], strings.Join(p.accepted, ""))
}
return segs, nil
}
func (p *parser) segments() ([]segment, error) {
s, err := p.segment()
if err != nil {
return nil, err
}
segs := []segment{s}
for {
if _, err := p.accept("/"); err != nil {
return segs, nil
}
s, err := p.segment()
if err != nil {
return segs, err
}
segs = append(segs, s)
}
}
func (p *parser) segment() (segment, error) {
if _, err := p.accept("*"); err == nil {
return wildcard{}, nil
}
if _, err := p.accept("**"); err == nil {
return deepWildcard{}, nil
}
if l, err := p.literal(); err == nil {
return l, nil
}
v, err := p.variable()
if err != nil {
return nil, fmt.Errorf("segment neither wildcards, literal or variable: %w", err)
}
return v, nil
}
func (p *parser) literal() (segment, error) {
lit, err := p.accept(typeLiteral)
if err != nil {
return nil, err
}
return literal(lit), nil
}
func (p *parser) variable() (segment, error) {
if _, err := p.accept("{"); err != nil {
return nil, err
}
path, err := p.fieldPath()
if err != nil {
return nil, err
}
var segs []segment
if _, err := p.accept("="); err == nil {
segs, err = p.segments()
if err != nil {
return nil, fmt.Errorf("invalid segment in variable %q: %w", path, err)
}
} else {
segs = []segment{wildcard{}}
}
if _, err := p.accept("}"); err != nil {
return nil, fmt.Errorf("unterminated variable segment: %s", path)
}
return variable{
path: path,
segments: segs,
}, nil
}
func (p *parser) fieldPath() (string, error) {
c, err := p.accept(typeIdent)
if err != nil {
return "", err
}
components := []string{c}
for {
if _, err := p.accept("."); err != nil {
return strings.Join(components, "."), nil
}
c, err := p.accept(typeIdent)
if err != nil {
return "", fmt.Errorf("invalid field path component: %w", err)
}
components = append(components, c)
}
}
// A termType is a type of terminal symbols.
type termType string
// These constants define some of valid values of termType.
// They improve readability of parse functions.
//
// You can also use "/", "*", "**", "." or "=" as valid values.
const (
typeIdent = termType("ident")
typeLiteral = termType("literal")
typeEOF = termType("$")
)
// eof is the terminal symbol which always appears at the end of token sequence.
const eof = "\u0000"
// accept tries to accept a token in "p".
// This function consumes a token and returns it if it matches to the specified "term".
// If it doesn't match, the function does not consume any tokens and return an error.
func (p *parser) accept(term termType) (string, error) {
t := p.tokens[0]
switch term {
case "/", "*", "**", ".", "=", "{", "}":
if t != string(term) && t != "/" {
return "", fmt.Errorf("expected %q but got %q", term, t)
}
case typeEOF:
if t != eof {
return "", fmt.Errorf("expected EOF but got %q", t)
}
case typeIdent:
if err := expectIdent(t); err != nil {
return "", err
}
case typeLiteral:
if err := expectPChars(t); err != nil {
return "", err
}
default:
return "", fmt.Errorf("unknown termType %q", term)
}
p.tokens = p.tokens[1:]
p.accepted = append(p.accepted, t)
return t, nil
}
// expectPChars determines if "t" consists of only pchars defined in RFC3986.
//
// https://www.ietf.org/rfc/rfc3986.txt, P.49
//
// pchar = unreserved / pct-encoded / sub-delims / ":" / "@"
// unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~"
// sub-delims = "!" / "$" / "&" / "'" / "(" / ")"
// / "*" / "+" / "," / ";" / "="
// pct-encoded = "%" HEXDIG HEXDIG
func expectPChars(t string) error {
const (
init = iota
pct1
pct2
)
st := init
for _, r := range t {
if st != init {
if !isHexDigit(r) {
return fmt.Errorf("invalid hexdigit: %c(%U)", r, r)
}
switch st {
case pct1:
st = pct2
case pct2:
st = init
}
continue
}
// unreserved
switch {
case 'A' <= r && r <= 'Z':
continue
case 'a' <= r && r <= 'z':
continue
case '0' <= r && r <= '9':
continue
}
switch r {
case '-', '.', '_', '~':
// unreserved
case '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=':
// sub-delims
case ':', '@':
// rest of pchar
case '%':
// pct-encoded
st = pct1
default:
return fmt.Errorf("invalid character in path segment: %q(%U)", r, r)
}
}
if st != init {
return fmt.Errorf("invalid percent-encoding in %q", t)
}
return nil
}
// expectIdent determines if "ident" is a valid identifier in .proto schema ([[:alpha:]_][[:alphanum:]_]*).
func expectIdent(ident string) error {
if ident == "" {
return errors.New("empty identifier")
}
for pos, r := range ident {
switch {
case '0' <= r && r <= '9':
if pos == 0 {
return fmt.Errorf("identifier starting with digit: %s", ident)
}
continue
case 'A' <= r && r <= 'Z':
continue
case 'a' <= r && r <= 'z':
continue
case r == '_':
continue
default:
return fmt.Errorf("invalid character %q(%U) in identifier: %s", r, r, ident)
}
}
return nil
}
func isHexDigit(r rune) bool {
switch {
case '0' <= r && r <= '9':
return true
case 'A' <= r && r <= 'F':
return true
case 'a' <= r && r <= 'f':
return true
}
return false
}

View File

@ -0,0 +1,60 @@
package httprule
import (
"fmt"
"strings"
)
type template struct {
segments []segment
verb string
template string
}
type segment interface {
fmt.Stringer
compile() (ops []op)
}
type wildcard struct{}
type deepWildcard struct{}
type literal string
type variable struct {
path string
segments []segment
}
func (wildcard) String() string {
return "*"
}
func (deepWildcard) String() string {
return "**"
}
func (l literal) String() string {
return string(l)
}
func (v variable) String() string {
var segs []string
for _, s := range v.segments {
segs = append(segs, s.String())
}
return fmt.Sprintf("{%s=%s}", v.path, strings.Join(segs, "/"))
}
func (t template) String() string {
var segs []string
for _, s := range t.segments {
segs = append(segs, s.String())
}
str := strings.Join(segs, "/")
if t.verb != "" {
str = fmt.Sprintf("%s:%s", str, t.verb)
}
return "/" + str
}

View File

@ -0,0 +1,97 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
package(default_visibility = ["//visibility:public"])
go_library(
name = "runtime",
srcs = [
"context.go",
"convert.go",
"doc.go",
"errors.go",
"fieldmask.go",
"handler.go",
"marshal_httpbodyproto.go",
"marshal_json.go",
"marshal_jsonpb.go",
"marshal_proto.go",
"marshaler.go",
"marshaler_registry.go",
"mux.go",
"pattern.go",
"proto2_convert.go",
"query.go",
],
importpath = "github.com/grpc-ecosystem/grpc-gateway/v2/runtime",
deps = [
"//internal/httprule",
"//utilities",
"@org_golang_google_genproto_googleapis_api//httpbody",
"@org_golang_google_grpc//codes",
"@org_golang_google_grpc//grpclog",
"@org_golang_google_grpc//health/grpc_health_v1",
"@org_golang_google_grpc//metadata",
"@org_golang_google_grpc//status",
"@org_golang_google_protobuf//encoding/protojson",
"@org_golang_google_protobuf//proto",
"@org_golang_google_protobuf//reflect/protoreflect",
"@org_golang_google_protobuf//reflect/protoregistry",
"@org_golang_google_protobuf//types/known/durationpb",
"@org_golang_google_protobuf//types/known/fieldmaskpb",
"@org_golang_google_protobuf//types/known/structpb",
"@org_golang_google_protobuf//types/known/timestamppb",
"@org_golang_google_protobuf//types/known/wrapperspb",
],
)
go_test(
name = "runtime_test",
size = "small",
srcs = [
"context_test.go",
"convert_test.go",
"errors_test.go",
"fieldmask_test.go",
"handler_test.go",
"marshal_httpbodyproto_test.go",
"marshal_json_test.go",
"marshal_jsonpb_test.go",
"marshal_proto_test.go",
"marshaler_registry_test.go",
"mux_internal_test.go",
"mux_test.go",
"pattern_test.go",
"query_fuzz_test.go",
"query_test.go",
],
embed = [":runtime"],
deps = [
"//runtime/internal/examplepb",
"//utilities",
"@com_github_google_go_cmp//cmp",
"@com_github_google_go_cmp//cmp/cmpopts",
"@org_golang_google_genproto_googleapis_api//httpbody",
"@org_golang_google_genproto_googleapis_rpc//errdetails",
"@org_golang_google_genproto_googleapis_rpc//status",
"@org_golang_google_grpc//:go_default_library",
"@org_golang_google_grpc//codes",
"@org_golang_google_grpc//health/grpc_health_v1",
"@org_golang_google_grpc//metadata",
"@org_golang_google_grpc//status",
"@org_golang_google_protobuf//encoding/protojson",
"@org_golang_google_protobuf//proto",
"@org_golang_google_protobuf//testing/protocmp",
"@org_golang_google_protobuf//types/known/durationpb",
"@org_golang_google_protobuf//types/known/emptypb",
"@org_golang_google_protobuf//types/known/fieldmaskpb",
"@org_golang_google_protobuf//types/known/structpb",
"@org_golang_google_protobuf//types/known/timestamppb",
"@org_golang_google_protobuf//types/known/wrapperspb",
],
)
alias(
name = "go_default_library",
actual = ":runtime",
visibility = ["//visibility:public"],
)

View File

@ -0,0 +1,406 @@
package runtime
import (
"context"
"encoding/base64"
"fmt"
"net"
"net/http"
"net/textproto"
"strconv"
"strings"
"sync"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
// MetadataHeaderPrefix is the http prefix that represents custom metadata
// parameters to or from a gRPC call.
const MetadataHeaderPrefix = "Grpc-Metadata-"
// MetadataPrefix is prepended to permanent HTTP header keys (as specified
// by the IANA) when added to the gRPC context.
const MetadataPrefix = "grpcgateway-"
// MetadataTrailerPrefix is prepended to gRPC metadata as it is converted to
// HTTP headers in a response handled by grpc-gateway
const MetadataTrailerPrefix = "Grpc-Trailer-"
const metadataGrpcTimeout = "Grpc-Timeout"
const metadataHeaderBinarySuffix = "-Bin"
const xForwardedFor = "X-Forwarded-For"
const xForwardedHost = "X-Forwarded-Host"
// DefaultContextTimeout is used for gRPC call context.WithTimeout whenever a Grpc-Timeout inbound
// header isn't present. If the value is 0 the sent `context` will not have a timeout.
var DefaultContextTimeout = 0 * time.Second
// malformedHTTPHeaders lists the headers that the gRPC server may reject outright as malformed.
// See https://github.com/grpc/grpc-go/pull/4803#issuecomment-986093310 for more context.
var malformedHTTPHeaders = map[string]struct{}{
"connection": {},
}
type (
rpcMethodKey struct{}
httpPathPatternKey struct{}
AnnotateContextOption func(ctx context.Context) context.Context
)
func WithHTTPPathPattern(pattern string) AnnotateContextOption {
return func(ctx context.Context) context.Context {
return withHTTPPathPattern(ctx, pattern)
}
}
func decodeBinHeader(v string) ([]byte, error) {
if len(v)%4 == 0 {
// Input was padded, or padding was not necessary.
return base64.StdEncoding.DecodeString(v)
}
return base64.RawStdEncoding.DecodeString(v)
}
/*
AnnotateContext adds context information such as metadata from the request.
At a minimum, the RemoteAddr is included in the fashion of "X-Forwarded-For",
except that the forwarded destination is not another HTTP service but rather
a gRPC service.
*/
func AnnotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcMethodName string, options ...AnnotateContextOption) (context.Context, error) {
ctx, md, err := annotateContext(ctx, mux, req, rpcMethodName, options...)
if err != nil {
return nil, err
}
if md == nil {
return ctx, nil
}
return metadata.NewOutgoingContext(ctx, md), nil
}
// AnnotateIncomingContext adds context information such as metadata from the request.
// Attach metadata as incoming context.
func AnnotateIncomingContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcMethodName string, options ...AnnotateContextOption) (context.Context, error) {
ctx, md, err := annotateContext(ctx, mux, req, rpcMethodName, options...)
if err != nil {
return nil, err
}
if md == nil {
return ctx, nil
}
return metadata.NewIncomingContext(ctx, md), nil
}
func isValidGRPCMetadataKey(key string) bool {
// Must be a valid gRPC "Header-Name" as defined here:
// https://github.com/grpc/grpc/blob/4b05dc88b724214d0c725c8e7442cbc7a61b1374/doc/PROTOCOL-HTTP2.md
// This means 0-9 a-z _ - .
// Only lowercase letters are valid in the wire protocol, but the client library will normalize
// uppercase ASCII to lowercase, so uppercase ASCII is also acceptable.
bytes := []byte(key) // gRPC validates strings on the byte level, not Unicode.
for _, ch := range bytes {
validLowercaseLetter := ch >= 'a' && ch <= 'z'
validUppercaseLetter := ch >= 'A' && ch <= 'Z'
validDigit := ch >= '0' && ch <= '9'
validOther := ch == '.' || ch == '-' || ch == '_'
if !validLowercaseLetter && !validUppercaseLetter && !validDigit && !validOther {
return false
}
}
return true
}
func isValidGRPCMetadataTextValue(textValue string) bool {
// Must be a valid gRPC "ASCII-Value" as defined here:
// https://github.com/grpc/grpc/blob/4b05dc88b724214d0c725c8e7442cbc7a61b1374/doc/PROTOCOL-HTTP2.md
// This means printable ASCII (including/plus spaces); 0x20 to 0x7E inclusive.
bytes := []byte(textValue) // gRPC validates strings on the byte level, not Unicode.
for _, ch := range bytes {
if ch < 0x20 || ch > 0x7E {
return false
}
}
return true
}
func annotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcMethodName string, options ...AnnotateContextOption) (context.Context, metadata.MD, error) {
ctx = withRPCMethod(ctx, rpcMethodName)
for _, o := range options {
ctx = o(ctx)
}
timeout := DefaultContextTimeout
if tm := req.Header.Get(metadataGrpcTimeout); tm != "" {
var err error
timeout, err = timeoutDecode(tm)
if err != nil {
return nil, nil, status.Errorf(codes.InvalidArgument, "invalid grpc-timeout: %s", tm)
}
}
var pairs []string
for key, vals := range req.Header {
key = textproto.CanonicalMIMEHeaderKey(key)
switch key {
case xForwardedFor, xForwardedHost:
// Handled separately below
continue
}
for _, val := range vals {
// For backwards-compatibility, pass through 'authorization' header with no prefix.
if key == "Authorization" {
pairs = append(pairs, "authorization", val)
}
if h, ok := mux.incomingHeaderMatcher(key); ok {
if !isValidGRPCMetadataKey(h) {
grpclog.Errorf("HTTP header name %q is not valid as gRPC metadata key; skipping", h)
continue
}
// Handles "-bin" metadata in grpc, since grpc will do another base64
// encode before sending to server, we need to decode it first.
if strings.HasSuffix(key, metadataHeaderBinarySuffix) {
b, err := decodeBinHeader(val)
if err != nil {
return nil, nil, status.Errorf(codes.InvalidArgument, "invalid binary header %s: %s", key, err)
}
val = string(b)
} else if !isValidGRPCMetadataTextValue(val) {
grpclog.Errorf("Value of HTTP header %q contains non-ASCII value (not valid as gRPC metadata): skipping", h)
continue
}
pairs = append(pairs, h, val)
}
}
}
if host := req.Header.Get(xForwardedHost); host != "" {
pairs = append(pairs, strings.ToLower(xForwardedHost), host)
} else if req.Host != "" {
pairs = append(pairs, strings.ToLower(xForwardedHost), req.Host)
}
xff := req.Header.Values(xForwardedFor)
if addr := req.RemoteAddr; addr != "" {
if remoteIP, _, err := net.SplitHostPort(addr); err == nil {
xff = append(xff, remoteIP)
}
}
if len(xff) > 0 {
pairs = append(pairs, strings.ToLower(xForwardedFor), strings.Join(xff, ", "))
}
if timeout != 0 {
ctx, _ = context.WithTimeout(ctx, timeout)
}
if len(pairs) == 0 {
return ctx, nil, nil
}
md := metadata.Pairs(pairs...)
for _, mda := range mux.metadataAnnotators {
md = metadata.Join(md, mda(ctx, req))
}
return ctx, md, nil
}
// ServerMetadata consists of metadata sent from gRPC server.
type ServerMetadata struct {
HeaderMD metadata.MD
TrailerMD metadata.MD
}
type serverMetadataKey struct{}
// NewServerMetadataContext creates a new context with ServerMetadata
func NewServerMetadataContext(ctx context.Context, md ServerMetadata) context.Context {
if ctx == nil {
ctx = context.Background()
}
return context.WithValue(ctx, serverMetadataKey{}, md)
}
// ServerMetadataFromContext returns the ServerMetadata in ctx
func ServerMetadataFromContext(ctx context.Context) (md ServerMetadata, ok bool) {
if ctx == nil {
return md, false
}
md, ok = ctx.Value(serverMetadataKey{}).(ServerMetadata)
return
}
// ServerTransportStream implements grpc.ServerTransportStream.
// It should only be used by the generated files to support grpc.SendHeader
// outside of gRPC server use.
type ServerTransportStream struct {
mu sync.Mutex
header metadata.MD
trailer metadata.MD
}
// Method returns the method for the stream.
func (s *ServerTransportStream) Method() string {
return ""
}
// Header returns the header metadata of the stream.
func (s *ServerTransportStream) Header() metadata.MD {
s.mu.Lock()
defer s.mu.Unlock()
return s.header.Copy()
}
// SetHeader sets the header metadata.
func (s *ServerTransportStream) SetHeader(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
s.mu.Lock()
s.header = metadata.Join(s.header, md)
s.mu.Unlock()
return nil
}
// SendHeader sets the header metadata.
func (s *ServerTransportStream) SendHeader(md metadata.MD) error {
return s.SetHeader(md)
}
// Trailer returns the cached trailer metadata.
func (s *ServerTransportStream) Trailer() metadata.MD {
s.mu.Lock()
defer s.mu.Unlock()
return s.trailer.Copy()
}
// SetTrailer sets the trailer metadata.
func (s *ServerTransportStream) SetTrailer(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
s.mu.Lock()
s.trailer = metadata.Join(s.trailer, md)
s.mu.Unlock()
return nil
}
func timeoutDecode(s string) (time.Duration, error) {
size := len(s)
if size < 2 {
return 0, fmt.Errorf("timeout string is too short: %q", s)
}
d, ok := timeoutUnitToDuration(s[size-1])
if !ok {
return 0, fmt.Errorf("timeout unit is not recognized: %q", s)
}
t, err := strconv.ParseInt(s[:size-1], 10, 64)
if err != nil {
return 0, err
}
return d * time.Duration(t), nil
}
func timeoutUnitToDuration(u uint8) (d time.Duration, ok bool) {
switch u {
case 'H':
return time.Hour, true
case 'M':
return time.Minute, true
case 'S':
return time.Second, true
case 'm':
return time.Millisecond, true
case 'u':
return time.Microsecond, true
case 'n':
return time.Nanosecond, true
default:
return
}
}
// isPermanentHTTPHeader checks whether hdr belongs to the list of
// permanent request headers maintained by IANA.
// http://www.iana.org/assignments/message-headers/message-headers.xml
func isPermanentHTTPHeader(hdr string) bool {
switch hdr {
case
"Accept",
"Accept-Charset",
"Accept-Language",
"Accept-Ranges",
"Authorization",
"Cache-Control",
"Content-Type",
"Cookie",
"Date",
"Expect",
"From",
"Host",
"If-Match",
"If-Modified-Since",
"If-None-Match",
"If-Schedule-Tag-Match",
"If-Unmodified-Since",
"Max-Forwards",
"Origin",
"Pragma",
"Referer",
"User-Agent",
"Via",
"Warning":
return true
}
return false
}
// isMalformedHTTPHeader checks whether header belongs to the list of
// "malformed headers" and would be rejected by the gRPC server.
func isMalformedHTTPHeader(header string) bool {
_, isMalformed := malformedHTTPHeaders[strings.ToLower(header)]
return isMalformed
}
// RPCMethod returns the method string for the server context. The returned
// string is in the format of "/package.service/method".
func RPCMethod(ctx context.Context) (string, bool) {
m := ctx.Value(rpcMethodKey{})
if m == nil {
return "", false
}
ms, ok := m.(string)
if !ok {
return "", false
}
return ms, true
}
func withRPCMethod(ctx context.Context, rpcMethodName string) context.Context {
return context.WithValue(ctx, rpcMethodKey{}, rpcMethodName)
}
// HTTPPathPattern returns the HTTP path pattern string relating to the HTTP handler, if one exists.
// The format of the returned string is defined by the google.api.http path template type.
func HTTPPathPattern(ctx context.Context) (string, bool) {
m := ctx.Value(httpPathPatternKey{})
if m == nil {
return "", false
}
ms, ok := m.(string)
if !ok {
return "", false
}
return ms, true
}
func withHTTPPathPattern(ctx context.Context, httpPathPattern string) context.Context {
return context.WithValue(ctx, httpPathPatternKey{}, httpPathPattern)
}

View File

@ -0,0 +1,318 @@
package runtime
import (
"encoding/base64"
"fmt"
"strconv"
"strings"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"google.golang.org/protobuf/types/known/wrapperspb"
)
// String just returns the given string.
// It is just for compatibility to other types.
func String(val string) (string, error) {
return val, nil
}
// StringSlice converts 'val' where individual strings are separated by
// 'sep' into a string slice.
func StringSlice(val, sep string) ([]string, error) {
return strings.Split(val, sep), nil
}
// Bool converts the given string representation of a boolean value into bool.
func Bool(val string) (bool, error) {
return strconv.ParseBool(val)
}
// BoolSlice converts 'val' where individual booleans are separated by
// 'sep' into a bool slice.
func BoolSlice(val, sep string) ([]bool, error) {
s := strings.Split(val, sep)
values := make([]bool, len(s))
for i, v := range s {
value, err := Bool(v)
if err != nil {
return nil, err
}
values[i] = value
}
return values, nil
}
// Float64 converts the given string representation into representation of a floating point number into float64.
func Float64(val string) (float64, error) {
return strconv.ParseFloat(val, 64)
}
// Float64Slice converts 'val' where individual floating point numbers are separated by
// 'sep' into a float64 slice.
func Float64Slice(val, sep string) ([]float64, error) {
s := strings.Split(val, sep)
values := make([]float64, len(s))
for i, v := range s {
value, err := Float64(v)
if err != nil {
return nil, err
}
values[i] = value
}
return values, nil
}
// Float32 converts the given string representation of a floating point number into float32.
func Float32(val string) (float32, error) {
f, err := strconv.ParseFloat(val, 32)
if err != nil {
return 0, err
}
return float32(f), nil
}
// Float32Slice converts 'val' where individual floating point numbers are separated by
// 'sep' into a float32 slice.
func Float32Slice(val, sep string) ([]float32, error) {
s := strings.Split(val, sep)
values := make([]float32, len(s))
for i, v := range s {
value, err := Float32(v)
if err != nil {
return nil, err
}
values[i] = value
}
return values, nil
}
// Int64 converts the given string representation of an integer into int64.
func Int64(val string) (int64, error) {
return strconv.ParseInt(val, 0, 64)
}
// Int64Slice converts 'val' where individual integers are separated by
// 'sep' into a int64 slice.
func Int64Slice(val, sep string) ([]int64, error) {
s := strings.Split(val, sep)
values := make([]int64, len(s))
for i, v := range s {
value, err := Int64(v)
if err != nil {
return nil, err
}
values[i] = value
}
return values, nil
}
// Int32 converts the given string representation of an integer into int32.
func Int32(val string) (int32, error) {
i, err := strconv.ParseInt(val, 0, 32)
if err != nil {
return 0, err
}
return int32(i), nil
}
// Int32Slice converts 'val' where individual integers are separated by
// 'sep' into a int32 slice.
func Int32Slice(val, sep string) ([]int32, error) {
s := strings.Split(val, sep)
values := make([]int32, len(s))
for i, v := range s {
value, err := Int32(v)
if err != nil {
return nil, err
}
values[i] = value
}
return values, nil
}
// Uint64 converts the given string representation of an integer into uint64.
func Uint64(val string) (uint64, error) {
return strconv.ParseUint(val, 0, 64)
}
// Uint64Slice converts 'val' where individual integers are separated by
// 'sep' into a uint64 slice.
func Uint64Slice(val, sep string) ([]uint64, error) {
s := strings.Split(val, sep)
values := make([]uint64, len(s))
for i, v := range s {
value, err := Uint64(v)
if err != nil {
return nil, err
}
values[i] = value
}
return values, nil
}
// Uint32 converts the given string representation of an integer into uint32.
func Uint32(val string) (uint32, error) {
i, err := strconv.ParseUint(val, 0, 32)
if err != nil {
return 0, err
}
return uint32(i), nil
}
// Uint32Slice converts 'val' where individual integers are separated by
// 'sep' into a uint32 slice.
func Uint32Slice(val, sep string) ([]uint32, error) {
s := strings.Split(val, sep)
values := make([]uint32, len(s))
for i, v := range s {
value, err := Uint32(v)
if err != nil {
return nil, err
}
values[i] = value
}
return values, nil
}
// Bytes converts the given string representation of a byte sequence into a slice of bytes
// A bytes sequence is encoded in URL-safe base64 without padding
func Bytes(val string) ([]byte, error) {
b, err := base64.StdEncoding.DecodeString(val)
if err != nil {
b, err = base64.URLEncoding.DecodeString(val)
if err != nil {
return nil, err
}
}
return b, nil
}
// BytesSlice converts 'val' where individual bytes sequences, encoded in URL-safe
// base64 without padding, are separated by 'sep' into a slice of bytes slices slice.
func BytesSlice(val, sep string) ([][]byte, error) {
s := strings.Split(val, sep)
values := make([][]byte, len(s))
for i, v := range s {
value, err := Bytes(v)
if err != nil {
return nil, err
}
values[i] = value
}
return values, nil
}
// Timestamp converts the given RFC3339 formatted string into a timestamp.Timestamp.
func Timestamp(val string) (*timestamppb.Timestamp, error) {
var r timestamppb.Timestamp
val = strconv.Quote(strings.Trim(val, `"`))
unmarshaler := &protojson.UnmarshalOptions{}
if err := unmarshaler.Unmarshal([]byte(val), &r); err != nil {
return nil, err
}
return &r, nil
}
// Duration converts the given string into a timestamp.Duration.
func Duration(val string) (*durationpb.Duration, error) {
var r durationpb.Duration
val = strconv.Quote(strings.Trim(val, `"`))
unmarshaler := &protojson.UnmarshalOptions{}
if err := unmarshaler.Unmarshal([]byte(val), &r); err != nil {
return nil, err
}
return &r, nil
}
// Enum converts the given string into an int32 that should be type casted into the
// correct enum proto type.
func Enum(val string, enumValMap map[string]int32) (int32, error) {
e, ok := enumValMap[val]
if ok {
return e, nil
}
i, err := Int32(val)
if err != nil {
return 0, fmt.Errorf("%s is not valid", val)
}
for _, v := range enumValMap {
if v == i {
return i, nil
}
}
return 0, fmt.Errorf("%s is not valid", val)
}
// EnumSlice converts 'val' where individual enums are separated by 'sep'
// into a int32 slice. Each individual int32 should be type casted into the
// correct enum proto type.
func EnumSlice(val, sep string, enumValMap map[string]int32) ([]int32, error) {
s := strings.Split(val, sep)
values := make([]int32, len(s))
for i, v := range s {
value, err := Enum(v, enumValMap)
if err != nil {
return nil, err
}
values[i] = value
}
return values, nil
}
// Support for google.protobuf.wrappers on top of primitive types
// StringValue well-known type support as wrapper around string type
func StringValue(val string) (*wrapperspb.StringValue, error) {
return wrapperspb.String(val), nil
}
// FloatValue well-known type support as wrapper around float32 type
func FloatValue(val string) (*wrapperspb.FloatValue, error) {
parsedVal, err := Float32(val)
return wrapperspb.Float(parsedVal), err
}
// DoubleValue well-known type support as wrapper around float64 type
func DoubleValue(val string) (*wrapperspb.DoubleValue, error) {
parsedVal, err := Float64(val)
return wrapperspb.Double(parsedVal), err
}
// BoolValue well-known type support as wrapper around bool type
func BoolValue(val string) (*wrapperspb.BoolValue, error) {
parsedVal, err := Bool(val)
return wrapperspb.Bool(parsedVal), err
}
// Int32Value well-known type support as wrapper around int32 type
func Int32Value(val string) (*wrapperspb.Int32Value, error) {
parsedVal, err := Int32(val)
return wrapperspb.Int32(parsedVal), err
}
// UInt32Value well-known type support as wrapper around uint32 type
func UInt32Value(val string) (*wrapperspb.UInt32Value, error) {
parsedVal, err := Uint32(val)
return wrapperspb.UInt32(parsedVal), err
}
// Int64Value well-known type support as wrapper around int64 type
func Int64Value(val string) (*wrapperspb.Int64Value, error) {
parsedVal, err := Int64(val)
return wrapperspb.Int64(parsedVal), err
}
// UInt64Value well-known type support as wrapper around uint64 type
func UInt64Value(val string) (*wrapperspb.UInt64Value, error) {
parsedVal, err := Uint64(val)
return wrapperspb.UInt64(parsedVal), err
}
// BytesValue well-known type support as wrapper around bytes[] type
func BytesValue(val string) (*wrapperspb.BytesValue, error) {
parsedVal, err := Bytes(val)
return wrapperspb.Bytes(parsedVal), err
}

View File

@ -0,0 +1,5 @@
/*
Package runtime contains runtime helper functions used by
servers which protoc-gen-grpc-gateway generates.
*/
package runtime

View File

@ -0,0 +1,181 @@
package runtime
import (
"context"
"errors"
"io"
"net/http"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/status"
)
// ErrorHandlerFunc is the signature used to configure error handling.
type ErrorHandlerFunc func(context.Context, *ServeMux, Marshaler, http.ResponseWriter, *http.Request, error)
// StreamErrorHandlerFunc is the signature used to configure stream error handling.
type StreamErrorHandlerFunc func(context.Context, error) *status.Status
// RoutingErrorHandlerFunc is the signature used to configure error handling for routing errors.
type RoutingErrorHandlerFunc func(context.Context, *ServeMux, Marshaler, http.ResponseWriter, *http.Request, int)
// HTTPStatusError is the error to use when needing to provide a different HTTP status code for an error
// passed to the DefaultRoutingErrorHandler.
type HTTPStatusError struct {
HTTPStatus int
Err error
}
func (e *HTTPStatusError) Error() string {
return e.Err.Error()
}
// HTTPStatusFromCode converts a gRPC error code into the corresponding HTTP response status.
// See: https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto
func HTTPStatusFromCode(code codes.Code) int {
switch code {
case codes.OK:
return http.StatusOK
case codes.Canceled:
return 499
case codes.Unknown:
return http.StatusInternalServerError
case codes.InvalidArgument:
return http.StatusBadRequest
case codes.DeadlineExceeded:
return http.StatusGatewayTimeout
case codes.NotFound:
return http.StatusNotFound
case codes.AlreadyExists:
return http.StatusConflict
case codes.PermissionDenied:
return http.StatusForbidden
case codes.Unauthenticated:
return http.StatusUnauthorized
case codes.ResourceExhausted:
return http.StatusTooManyRequests
case codes.FailedPrecondition:
// Note, this deliberately doesn't translate to the similarly named '412 Precondition Failed' HTTP response status.
return http.StatusBadRequest
case codes.Aborted:
return http.StatusConflict
case codes.OutOfRange:
return http.StatusBadRequest
case codes.Unimplemented:
return http.StatusNotImplemented
case codes.Internal:
return http.StatusInternalServerError
case codes.Unavailable:
return http.StatusServiceUnavailable
case codes.DataLoss:
return http.StatusInternalServerError
default:
grpclog.Warningf("Unknown gRPC error code: %v", code)
return http.StatusInternalServerError
}
}
// HTTPError uses the mux-configured error handler.
func HTTPError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, err error) {
mux.errorHandler(ctx, mux, marshaler, w, r, err)
}
// DefaultHTTPErrorHandler is the default error handler.
// If "err" is a gRPC Status, the function replies with the status code mapped by HTTPStatusFromCode.
// If "err" is a HTTPStatusError, the function replies with the status code provide by that struct. This is
// intended to allow passing through of specific statuses via the function set via WithRoutingErrorHandler
// for the ServeMux constructor to handle edge cases which the standard mappings in HTTPStatusFromCode
// are insufficient for.
// If otherwise, it replies with http.StatusInternalServerError.
//
// The response body written by this function is a Status message marshaled by the Marshaler.
func DefaultHTTPErrorHandler(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, err error) {
// return Internal when Marshal failed
const fallback = `{"code": 13, "message": "failed to marshal error message"}`
var customStatus *HTTPStatusError
if errors.As(err, &customStatus) {
err = customStatus.Err
}
s := status.Convert(err)
pb := s.Proto()
w.Header().Del("Trailer")
w.Header().Del("Transfer-Encoding")
contentType := marshaler.ContentType(pb)
w.Header().Set("Content-Type", contentType)
if s.Code() == codes.Unauthenticated {
w.Header().Set("WWW-Authenticate", s.Message())
}
buf, merr := marshaler.Marshal(pb)
if merr != nil {
grpclog.Errorf("Failed to marshal error message %q: %v", s, merr)
w.WriteHeader(http.StatusInternalServerError)
if _, err := io.WriteString(w, fallback); err != nil {
grpclog.Errorf("Failed to write response: %v", err)
}
return
}
md, ok := ServerMetadataFromContext(ctx)
if !ok {
grpclog.Error("Failed to extract ServerMetadata from context")
}
handleForwardResponseServerMetadata(w, mux, md)
// RFC 7230 https://tools.ietf.org/html/rfc7230#section-4.1.2
// Unless the request includes a TE header field indicating "trailers"
// is acceptable, as described in Section 4.3, a server SHOULD NOT
// generate trailer fields that it believes are necessary for the user
// agent to receive.
doForwardTrailers := requestAcceptsTrailers(r)
if doForwardTrailers {
handleForwardResponseTrailerHeader(w, mux, md)
w.Header().Set("Transfer-Encoding", "chunked")
}
st := HTTPStatusFromCode(s.Code())
if customStatus != nil {
st = customStatus.HTTPStatus
}
w.WriteHeader(st)
if _, err := w.Write(buf); err != nil {
grpclog.Errorf("Failed to write response: %v", err)
}
if doForwardTrailers {
handleForwardResponseTrailer(w, mux, md)
}
}
func DefaultStreamErrorHandler(_ context.Context, err error) *status.Status {
return status.Convert(err)
}
// DefaultRoutingErrorHandler is our default handler for routing errors.
// By default http error codes mapped on the following error codes:
//
// NotFound -> grpc.NotFound
// StatusBadRequest -> grpc.InvalidArgument
// MethodNotAllowed -> grpc.Unimplemented
// Other -> grpc.Internal, method is not expecting to be called for anything else
func DefaultRoutingErrorHandler(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, httpStatus int) {
sterr := status.Error(codes.Internal, "Unexpected routing error")
switch httpStatus {
case http.StatusBadRequest:
sterr = status.Error(codes.InvalidArgument, http.StatusText(httpStatus))
case http.StatusMethodNotAllowed:
sterr = status.Error(codes.Unimplemented, http.StatusText(httpStatus))
case http.StatusNotFound:
sterr = status.Error(codes.NotFound, http.StatusText(httpStatus))
}
mux.errorHandler(ctx, mux, marshaler, w, r, sterr)
}

View File

@ -0,0 +1,168 @@
package runtime
import (
"encoding/json"
"errors"
"fmt"
"io"
"sort"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
field_mask "google.golang.org/protobuf/types/known/fieldmaskpb"
)
func getFieldByName(fields protoreflect.FieldDescriptors, name string) protoreflect.FieldDescriptor {
fd := fields.ByName(protoreflect.Name(name))
if fd != nil {
return fd
}
return fields.ByJSONName(name)
}
// FieldMaskFromRequestBody creates a FieldMask printing all complete paths from the JSON body.
func FieldMaskFromRequestBody(r io.Reader, msg proto.Message) (*field_mask.FieldMask, error) {
fm := &field_mask.FieldMask{}
var root interface{}
if err := json.NewDecoder(r).Decode(&root); err != nil {
if errors.Is(err, io.EOF) {
return fm, nil
}
return nil, err
}
queue := []fieldMaskPathItem{{node: root, msg: msg.ProtoReflect()}}
for len(queue) > 0 {
// dequeue an item
item := queue[0]
queue = queue[1:]
m, ok := item.node.(map[string]interface{})
switch {
case ok && len(m) > 0:
// if the item is an object, then enqueue all of its children
for k, v := range m {
if item.msg == nil {
return nil, errors.New("JSON structure did not match request type")
}
fd := getFieldByName(item.msg.Descriptor().Fields(), k)
if fd == nil {
return nil, fmt.Errorf("could not find field %q in %q", k, item.msg.Descriptor().FullName())
}
if isDynamicProtoMessage(fd.Message()) {
for _, p := range buildPathsBlindly(string(fd.FullName().Name()), v) {
newPath := p
if item.path != "" {
newPath = item.path + "." + newPath
}
queue = append(queue, fieldMaskPathItem{path: newPath})
}
continue
}
if isProtobufAnyMessage(fd.Message()) && !fd.IsList() {
_, hasTypeField := v.(map[string]interface{})["@type"]
if hasTypeField {
queue = append(queue, fieldMaskPathItem{path: k})
continue
} else {
return nil, fmt.Errorf("could not find field @type in %q in message %q", k, item.msg.Descriptor().FullName())
}
}
child := fieldMaskPathItem{
node: v,
}
if item.path == "" {
child.path = string(fd.FullName().Name())
} else {
child.path = item.path + "." + string(fd.FullName().Name())
}
switch {
case fd.IsList(), fd.IsMap():
// As per: https://github.com/protocolbuffers/protobuf/blob/master/src/google/protobuf/field_mask.proto#L85-L86
// Do not recurse into repeated fields. The repeated field goes on the end of the path and we stop.
fm.Paths = append(fm.Paths, child.path)
case fd.Message() != nil:
child.msg = item.msg.Get(fd).Message()
fallthrough
default:
queue = append(queue, child)
}
}
case ok && len(m) == 0:
fallthrough
case len(item.path) > 0:
// otherwise, it's a leaf node so print its path
fm.Paths = append(fm.Paths, item.path)
}
}
// Sort for deterministic output in the presence
// of repeated fields.
sort.Strings(fm.Paths)
return fm, nil
}
func isProtobufAnyMessage(md protoreflect.MessageDescriptor) bool {
return md != nil && (md.FullName() == "google.protobuf.Any")
}
func isDynamicProtoMessage(md protoreflect.MessageDescriptor) bool {
return md != nil && (md.FullName() == "google.protobuf.Struct" || md.FullName() == "google.protobuf.Value")
}
// buildPathsBlindly does not attempt to match proto field names to the
// json value keys. Instead it relies completely on the structure of
// the unmarshalled json contained within in.
// Returns a slice containing all subpaths with the root at the
// passed in name and json value.
func buildPathsBlindly(name string, in interface{}) []string {
m, ok := in.(map[string]interface{})
if !ok {
return []string{name}
}
var paths []string
queue := []fieldMaskPathItem{{path: name, node: m}}
for len(queue) > 0 {
cur := queue[0]
queue = queue[1:]
m, ok := cur.node.(map[string]interface{})
if !ok {
// This should never happen since we should always check that we only add
// nodes of type map[string]interface{} to the queue.
continue
}
for k, v := range m {
if mi, ok := v.(map[string]interface{}); ok {
queue = append(queue, fieldMaskPathItem{path: cur.path + "." + k, node: mi})
} else {
// This is not a struct, so there are no more levels to descend.
curPath := cur.path + "." + k
paths = append(paths, curPath)
}
}
}
return paths
}
// fieldMaskPathItem stores a in-progress deconstruction of a path for a fieldmask
type fieldMaskPathItem struct {
// the list of prior fields leading up to node connected by dots
path string
// a generic decoded json object the current item to inspect for further path extraction
node interface{}
// parent message
msg protoreflect.Message
}

View File

@ -0,0 +1,235 @@
package runtime
import (
"context"
"errors"
"io"
"net/http"
"net/textproto"
"strconv"
"strings"
"google.golang.org/genproto/googleapis/api/httpbody"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)
// ForwardResponseStream forwards the stream from gRPC server to REST client.
func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, recv func() (proto.Message, error), opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
rc := http.NewResponseController(w)
md, ok := ServerMetadataFromContext(ctx)
if !ok {
grpclog.Error("Failed to extract ServerMetadata from context")
http.Error(w, "unexpected error", http.StatusInternalServerError)
return
}
handleForwardResponseServerMetadata(w, mux, md)
w.Header().Set("Transfer-Encoding", "chunked")
if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil {
HTTPError(ctx, mux, marshaler, w, req, err)
return
}
var delimiter []byte
if d, ok := marshaler.(Delimited); ok {
delimiter = d.Delimiter()
} else {
delimiter = []byte("\n")
}
var wroteHeader bool
for {
resp, err := recv()
if errors.Is(err, io.EOF) {
return
}
if err != nil {
handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err, delimiter)
return
}
if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err, delimiter)
return
}
if !wroteHeader {
w.Header().Set("Content-Type", marshaler.ContentType(resp))
}
var buf []byte
httpBody, isHTTPBody := resp.(*httpbody.HttpBody)
switch {
case resp == nil:
buf, err = marshaler.Marshal(errorChunk(status.New(codes.Internal, "empty response")))
case isHTTPBody:
buf = httpBody.GetData()
default:
result := map[string]interface{}{"result": resp}
if rb, ok := resp.(responseBody); ok {
result["result"] = rb.XXX_ResponseBody()
}
buf, err = marshaler.Marshal(result)
}
if err != nil {
grpclog.Errorf("Failed to marshal response chunk: %v", err)
handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err, delimiter)
return
}
if _, err := w.Write(buf); err != nil {
grpclog.Errorf("Failed to send response chunk: %v", err)
return
}
wroteHeader = true
if _, err := w.Write(delimiter); err != nil {
grpclog.Errorf("Failed to send delimiter chunk: %v", err)
return
}
err = rc.Flush()
if err != nil {
if errors.Is(err, http.ErrNotSupported) {
grpclog.Errorf("Flush not supported in %T", w)
http.Error(w, "unexpected type of web server", http.StatusInternalServerError)
return
}
grpclog.Errorf("Failed to flush response to client: %v", err)
return
}
}
}
func handleForwardResponseServerMetadata(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
for k, vs := range md.HeaderMD {
if h, ok := mux.outgoingHeaderMatcher(k); ok {
for _, v := range vs {
w.Header().Add(h, v)
}
}
}
}
func handleForwardResponseTrailerHeader(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
for k := range md.TrailerMD {
if h, ok := mux.outgoingTrailerMatcher(k); ok {
w.Header().Add("Trailer", textproto.CanonicalMIMEHeaderKey(h))
}
}
}
func handleForwardResponseTrailer(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
for k, vs := range md.TrailerMD {
if h, ok := mux.outgoingTrailerMatcher(k); ok {
for _, v := range vs {
w.Header().Add(h, v)
}
}
}
}
// responseBody interface contains method for getting field for marshaling to the response body
// this method is generated for response struct from the value of `response_body` in the `google.api.HttpRule`
type responseBody interface {
XXX_ResponseBody() interface{}
}
// ForwardResponseMessage forwards the message "resp" from gRPC server to REST client.
func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
md, ok := ServerMetadataFromContext(ctx)
if !ok {
grpclog.Error("Failed to extract ServerMetadata from context")
}
handleForwardResponseServerMetadata(w, mux, md)
// RFC 7230 https://tools.ietf.org/html/rfc7230#section-4.1.2
// Unless the request includes a TE header field indicating "trailers"
// is acceptable, as described in Section 4.3, a server SHOULD NOT
// generate trailer fields that it believes are necessary for the user
// agent to receive.
doForwardTrailers := requestAcceptsTrailers(req)
if doForwardTrailers {
handleForwardResponseTrailerHeader(w, mux, md)
w.Header().Set("Transfer-Encoding", "chunked")
}
contentType := marshaler.ContentType(resp)
w.Header().Set("Content-Type", contentType)
if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
HTTPError(ctx, mux, marshaler, w, req, err)
return
}
var buf []byte
var err error
if rb, ok := resp.(responseBody); ok {
buf, err = marshaler.Marshal(rb.XXX_ResponseBody())
} else {
buf, err = marshaler.Marshal(resp)
}
if err != nil {
grpclog.Errorf("Marshal error: %v", err)
HTTPError(ctx, mux, marshaler, w, req, err)
return
}
if !doForwardTrailers {
w.Header().Set("Content-Length", strconv.Itoa(len(buf)))
}
if _, err = w.Write(buf); err != nil {
grpclog.Errorf("Failed to write response: %v", err)
}
if doForwardTrailers {
handleForwardResponseTrailer(w, mux, md)
}
}
func requestAcceptsTrailers(req *http.Request) bool {
te := req.Header.Get("TE")
return strings.Contains(strings.ToLower(te), "trailers")
}
func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, resp proto.Message, opts []func(context.Context, http.ResponseWriter, proto.Message) error) error {
if len(opts) == 0 {
return nil
}
for _, opt := range opts {
if err := opt(ctx, w, resp); err != nil {
grpclog.Errorf("Error handling ForwardResponseOptions: %v", err)
return err
}
}
return nil
}
func handleForwardResponseStreamError(ctx context.Context, wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, req *http.Request, mux *ServeMux, err error, delimiter []byte) {
st := mux.streamErrorHandler(ctx, err)
msg := errorChunk(st)
if !wroteHeader {
w.Header().Set("Content-Type", marshaler.ContentType(msg))
w.WriteHeader(HTTPStatusFromCode(st.Code()))
}
buf, err := marshaler.Marshal(msg)
if err != nil {
grpclog.Errorf("Failed to marshal an error: %v", err)
return
}
if _, err := w.Write(buf); err != nil {
grpclog.Errorf("Failed to notify error to client: %v", err)
return
}
if _, err := w.Write(delimiter); err != nil {
grpclog.Errorf("Failed to send delimiter chunk: %v", err)
return
}
}
func errorChunk(st *status.Status) map[string]proto.Message {
return map[string]proto.Message{"error": st.Proto()}
}

View File

@ -0,0 +1,32 @@
package runtime
import (
"google.golang.org/genproto/googleapis/api/httpbody"
)
// HTTPBodyMarshaler is a Marshaler which supports marshaling of a
// google.api.HttpBody message as the full response body if it is
// the actual message used as the response. If not, then this will
// simply fallback to the Marshaler specified as its default Marshaler.
type HTTPBodyMarshaler struct {
Marshaler
}
// ContentType returns its specified content type in case v is a
// google.api.HttpBody message, otherwise it will fall back to the default Marshalers
// content type.
func (h *HTTPBodyMarshaler) ContentType(v interface{}) string {
if httpBody, ok := v.(*httpbody.HttpBody); ok {
return httpBody.GetContentType()
}
return h.Marshaler.ContentType(v)
}
// Marshal marshals "v" by returning the body bytes if v is a
// google.api.HttpBody message, otherwise it falls back to the default Marshaler.
func (h *HTTPBodyMarshaler) Marshal(v interface{}) ([]byte, error) {
if httpBody, ok := v.(*httpbody.HttpBody); ok {
return httpBody.GetData(), nil
}
return h.Marshaler.Marshal(v)
}

View File

@ -0,0 +1,50 @@
package runtime
import (
"encoding/json"
"io"
)
// JSONBuiltin is a Marshaler which marshals/unmarshals into/from JSON
// with the standard "encoding/json" package of Golang.
// Although it is generally faster for simple proto messages than JSONPb,
// it does not support advanced features of protobuf, e.g. map, oneof, ....
//
// The NewEncoder and NewDecoder types return *json.Encoder and
// *json.Decoder respectively.
type JSONBuiltin struct{}
// ContentType always Returns "application/json".
func (*JSONBuiltin) ContentType(_ interface{}) string {
return "application/json"
}
// Marshal marshals "v" into JSON
func (j *JSONBuiltin) Marshal(v interface{}) ([]byte, error) {
return json.Marshal(v)
}
// MarshalIndent is like Marshal but applies Indent to format the output
func (j *JSONBuiltin) MarshalIndent(v interface{}, prefix, indent string) ([]byte, error) {
return json.MarshalIndent(v, prefix, indent)
}
// Unmarshal unmarshals JSON data into "v".
func (j *JSONBuiltin) Unmarshal(data []byte, v interface{}) error {
return json.Unmarshal(data, v)
}
// NewDecoder returns a Decoder which reads JSON stream from "r".
func (j *JSONBuiltin) NewDecoder(r io.Reader) Decoder {
return json.NewDecoder(r)
}
// NewEncoder returns an Encoder which writes JSON stream into "w".
func (j *JSONBuiltin) NewEncoder(w io.Writer) Encoder {
return json.NewEncoder(w)
}
// Delimiter for newline encoded JSON streams.
func (j *JSONBuiltin) Delimiter() []byte {
return []byte("\n")
}

View File

@ -0,0 +1,349 @@
package runtime
import (
"bytes"
"encoding/json"
"fmt"
"io"
"reflect"
"strconv"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
)
// JSONPb is a Marshaler which marshals/unmarshals into/from JSON
// with the "google.golang.org/protobuf/encoding/protojson" marshaler.
// It supports the full functionality of protobuf unlike JSONBuiltin.
//
// The NewDecoder method returns a DecoderWrapper, so the underlying
// *json.Decoder methods can be used.
type JSONPb struct {
protojson.MarshalOptions
protojson.UnmarshalOptions
}
// ContentType always returns "application/json".
func (*JSONPb) ContentType(_ interface{}) string {
return "application/json"
}
// Marshal marshals "v" into JSON.
func (j *JSONPb) Marshal(v interface{}) ([]byte, error) {
var buf bytes.Buffer
if err := j.marshalTo(&buf, v); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func (j *JSONPb) marshalTo(w io.Writer, v interface{}) error {
p, ok := v.(proto.Message)
if !ok {
buf, err := j.marshalNonProtoField(v)
if err != nil {
return err
}
if j.Indent != "" {
b := &bytes.Buffer{}
if err := json.Indent(b, buf, "", j.Indent); err != nil {
return err
}
buf = b.Bytes()
}
_, err = w.Write(buf)
return err
}
b, err := j.MarshalOptions.Marshal(p)
if err != nil {
return err
}
_, err = w.Write(b)
return err
}
var (
// protoMessageType is stored to prevent constant lookup of the same type at runtime.
protoMessageType = reflect.TypeOf((*proto.Message)(nil)).Elem()
)
// marshalNonProto marshals a non-message field of a protobuf message.
// This function does not correctly marshal arbitrary data structures into JSON,
// it is only capable of marshaling non-message field values of protobuf,
// i.e. primitive types, enums; pointers to primitives or enums; maps from
// integer/string types to primitives/enums/pointers to messages.
func (j *JSONPb) marshalNonProtoField(v interface{}) ([]byte, error) {
if v == nil {
return []byte("null"), nil
}
rv := reflect.ValueOf(v)
for rv.Kind() == reflect.Ptr {
if rv.IsNil() {
return []byte("null"), nil
}
rv = rv.Elem()
}
if rv.Kind() == reflect.Slice {
if rv.IsNil() {
if j.EmitUnpopulated {
return []byte("[]"), nil
}
return []byte("null"), nil
}
if rv.Type().Elem().Implements(protoMessageType) {
var buf bytes.Buffer
if err := buf.WriteByte('['); err != nil {
return nil, err
}
for i := 0; i < rv.Len(); i++ {
if i != 0 {
if err := buf.WriteByte(','); err != nil {
return nil, err
}
}
if err := j.marshalTo(&buf, rv.Index(i).Interface().(proto.Message)); err != nil {
return nil, err
}
}
if err := buf.WriteByte(']'); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
if rv.Type().Elem().Implements(typeProtoEnum) {
var buf bytes.Buffer
if err := buf.WriteByte('['); err != nil {
return nil, err
}
for i := 0; i < rv.Len(); i++ {
if i != 0 {
if err := buf.WriteByte(','); err != nil {
return nil, err
}
}
var err error
if j.UseEnumNumbers {
_, err = buf.WriteString(strconv.FormatInt(rv.Index(i).Int(), 10))
} else {
_, err = buf.WriteString("\"" + rv.Index(i).Interface().(protoEnum).String() + "\"")
}
if err != nil {
return nil, err
}
}
if err := buf.WriteByte(']'); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
}
if rv.Kind() == reflect.Map {
m := make(map[string]*json.RawMessage)
for _, k := range rv.MapKeys() {
buf, err := j.Marshal(rv.MapIndex(k).Interface())
if err != nil {
return nil, err
}
m[fmt.Sprintf("%v", k.Interface())] = (*json.RawMessage)(&buf)
}
return json.Marshal(m)
}
if enum, ok := rv.Interface().(protoEnum); ok && !j.UseEnumNumbers {
return json.Marshal(enum.String())
}
return json.Marshal(rv.Interface())
}
// Unmarshal unmarshals JSON "data" into "v"
func (j *JSONPb) Unmarshal(data []byte, v interface{}) error {
return unmarshalJSONPb(data, j.UnmarshalOptions, v)
}
// NewDecoder returns a Decoder which reads JSON stream from "r".
func (j *JSONPb) NewDecoder(r io.Reader) Decoder {
d := json.NewDecoder(r)
return DecoderWrapper{
Decoder: d,
UnmarshalOptions: j.UnmarshalOptions,
}
}
// DecoderWrapper is a wrapper around a *json.Decoder that adds
// support for protos to the Decode method.
type DecoderWrapper struct {
*json.Decoder
protojson.UnmarshalOptions
}
// Decode wraps the embedded decoder's Decode method to support
// protos using a jsonpb.Unmarshaler.
func (d DecoderWrapper) Decode(v interface{}) error {
return decodeJSONPb(d.Decoder, d.UnmarshalOptions, v)
}
// NewEncoder returns an Encoder which writes JSON stream into "w".
func (j *JSONPb) NewEncoder(w io.Writer) Encoder {
return EncoderFunc(func(v interface{}) error {
if err := j.marshalTo(w, v); err != nil {
return err
}
// mimic json.Encoder by adding a newline (makes output
// easier to read when it contains multiple encoded items)
_, err := w.Write(j.Delimiter())
return err
})
}
func unmarshalJSONPb(data []byte, unmarshaler protojson.UnmarshalOptions, v interface{}) error {
d := json.NewDecoder(bytes.NewReader(data))
return decodeJSONPb(d, unmarshaler, v)
}
func decodeJSONPb(d *json.Decoder, unmarshaler protojson.UnmarshalOptions, v interface{}) error {
p, ok := v.(proto.Message)
if !ok {
return decodeNonProtoField(d, unmarshaler, v)
}
// Decode into bytes for marshalling
var b json.RawMessage
if err := d.Decode(&b); err != nil {
return err
}
return unmarshaler.Unmarshal([]byte(b), p)
}
func decodeNonProtoField(d *json.Decoder, unmarshaler protojson.UnmarshalOptions, v interface{}) error {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr {
return fmt.Errorf("%T is not a pointer", v)
}
for rv.Kind() == reflect.Ptr {
if rv.IsNil() {
rv.Set(reflect.New(rv.Type().Elem()))
}
if rv.Type().ConvertibleTo(typeProtoMessage) {
// Decode into bytes for marshalling
var b json.RawMessage
if err := d.Decode(&b); err != nil {
return err
}
return unmarshaler.Unmarshal([]byte(b), rv.Interface().(proto.Message))
}
rv = rv.Elem()
}
if rv.Kind() == reflect.Map {
if rv.IsNil() {
rv.Set(reflect.MakeMap(rv.Type()))
}
conv, ok := convFromType[rv.Type().Key().Kind()]
if !ok {
return fmt.Errorf("unsupported type of map field key: %v", rv.Type().Key())
}
m := make(map[string]*json.RawMessage)
if err := d.Decode(&m); err != nil {
return err
}
for k, v := range m {
result := conv.Call([]reflect.Value{reflect.ValueOf(k)})
if err := result[1].Interface(); err != nil {
return err.(error)
}
bk := result[0]
bv := reflect.New(rv.Type().Elem())
if v == nil {
null := json.RawMessage("null")
v = &null
}
if err := unmarshalJSONPb([]byte(*v), unmarshaler, bv.Interface()); err != nil {
return err
}
rv.SetMapIndex(bk, bv.Elem())
}
return nil
}
if rv.Kind() == reflect.Slice {
if rv.Type().Elem().Kind() == reflect.Uint8 {
var sl []byte
if err := d.Decode(&sl); err != nil {
return err
}
if sl != nil {
rv.SetBytes(sl)
}
return nil
}
var sl []json.RawMessage
if err := d.Decode(&sl); err != nil {
return err
}
if sl != nil {
rv.Set(reflect.MakeSlice(rv.Type(), 0, 0))
}
for _, item := range sl {
bv := reflect.New(rv.Type().Elem())
if err := unmarshalJSONPb([]byte(item), unmarshaler, bv.Interface()); err != nil {
return err
}
rv.Set(reflect.Append(rv, bv.Elem()))
}
return nil
}
if _, ok := rv.Interface().(protoEnum); ok {
var repr interface{}
if err := d.Decode(&repr); err != nil {
return err
}
switch v := repr.(type) {
case string:
// TODO(yugui) Should use proto.StructProperties?
return fmt.Errorf("unmarshaling of symbolic enum %q not supported: %T", repr, rv.Interface())
case float64:
rv.Set(reflect.ValueOf(int32(v)).Convert(rv.Type()))
return nil
default:
return fmt.Errorf("cannot assign %#v into Go type %T", repr, rv.Interface())
}
}
return d.Decode(v)
}
type protoEnum interface {
fmt.Stringer
EnumDescriptor() ([]byte, []int)
}
var typeProtoEnum = reflect.TypeOf((*protoEnum)(nil)).Elem()
var typeProtoMessage = reflect.TypeOf((*proto.Message)(nil)).Elem()
// Delimiter for newline encoded JSON streams.
func (j *JSONPb) Delimiter() []byte {
return []byte("\n")
}
var (
convFromType = map[reflect.Kind]reflect.Value{
reflect.String: reflect.ValueOf(String),
reflect.Bool: reflect.ValueOf(Bool),
reflect.Float64: reflect.ValueOf(Float64),
reflect.Float32: reflect.ValueOf(Float32),
reflect.Int64: reflect.ValueOf(Int64),
reflect.Int32: reflect.ValueOf(Int32),
reflect.Uint64: reflect.ValueOf(Uint64),
reflect.Uint32: reflect.ValueOf(Uint32),
reflect.Slice: reflect.ValueOf(Bytes),
}
)

View File

@ -0,0 +1,60 @@
package runtime
import (
"errors"
"io"
"google.golang.org/protobuf/proto"
)
// ProtoMarshaller is a Marshaller which marshals/unmarshals into/from serialize proto bytes
type ProtoMarshaller struct{}
// ContentType always returns "application/octet-stream".
func (*ProtoMarshaller) ContentType(_ interface{}) string {
return "application/octet-stream"
}
// Marshal marshals "value" into Proto
func (*ProtoMarshaller) Marshal(value interface{}) ([]byte, error) {
message, ok := value.(proto.Message)
if !ok {
return nil, errors.New("unable to marshal non proto field")
}
return proto.Marshal(message)
}
// Unmarshal unmarshals proto "data" into "value"
func (*ProtoMarshaller) Unmarshal(data []byte, value interface{}) error {
message, ok := value.(proto.Message)
if !ok {
return errors.New("unable to unmarshal non proto field")
}
return proto.Unmarshal(data, message)
}
// NewDecoder returns a Decoder which reads proto stream from "reader".
func (marshaller *ProtoMarshaller) NewDecoder(reader io.Reader) Decoder {
return DecoderFunc(func(value interface{}) error {
buffer, err := io.ReadAll(reader)
if err != nil {
return err
}
return marshaller.Unmarshal(buffer, value)
})
}
// NewEncoder returns an Encoder which writes proto stream into "writer".
func (marshaller *ProtoMarshaller) NewEncoder(writer io.Writer) Encoder {
return EncoderFunc(func(value interface{}) error {
buffer, err := marshaller.Marshal(value)
if err != nil {
return err
}
if _, err := writer.Write(buffer); err != nil {
return err
}
return nil
})
}

View File

@ -0,0 +1,50 @@
package runtime
import (
"io"
)
// Marshaler defines a conversion between byte sequence and gRPC payloads / fields.
type Marshaler interface {
// Marshal marshals "v" into byte sequence.
Marshal(v interface{}) ([]byte, error)
// Unmarshal unmarshals "data" into "v".
// "v" must be a pointer value.
Unmarshal(data []byte, v interface{}) error
// NewDecoder returns a Decoder which reads byte sequence from "r".
NewDecoder(r io.Reader) Decoder
// NewEncoder returns an Encoder which writes bytes sequence into "w".
NewEncoder(w io.Writer) Encoder
// ContentType returns the Content-Type which this marshaler is responsible for.
// The parameter describes the type which is being marshalled, which can sometimes
// affect the content type returned.
ContentType(v interface{}) string
}
// Decoder decodes a byte sequence
type Decoder interface {
Decode(v interface{}) error
}
// Encoder encodes gRPC payloads / fields into byte sequence.
type Encoder interface {
Encode(v interface{}) error
}
// DecoderFunc adapts an decoder function into Decoder.
type DecoderFunc func(v interface{}) error
// Decode delegates invocations to the underlying function itself.
func (f DecoderFunc) Decode(v interface{}) error { return f(v) }
// EncoderFunc adapts an encoder function into Encoder
type EncoderFunc func(v interface{}) error
// Encode delegates invocations to the underlying function itself.
func (f EncoderFunc) Encode(v interface{}) error { return f(v) }
// Delimited defines the streaming delimiter.
type Delimited interface {
// Delimiter returns the record separator for the stream.
Delimiter() []byte
}

View File

@ -0,0 +1,109 @@
package runtime
import (
"errors"
"mime"
"net/http"
"google.golang.org/grpc/grpclog"
"google.golang.org/protobuf/encoding/protojson"
)
// MIMEWildcard is the fallback MIME type used for requests which do not match
// a registered MIME type.
const MIMEWildcard = "*"
var (
acceptHeader = http.CanonicalHeaderKey("Accept")
contentTypeHeader = http.CanonicalHeaderKey("Content-Type")
defaultMarshaler = &HTTPBodyMarshaler{
Marshaler: &JSONPb{
MarshalOptions: protojson.MarshalOptions{
EmitUnpopulated: true,
},
UnmarshalOptions: protojson.UnmarshalOptions{
DiscardUnknown: true,
},
},
}
)
// MarshalerForRequest returns the inbound/outbound marshalers for this request.
// It checks the registry on the ServeMux for the MIME type set by the Content-Type header.
// If it isn't set (or the request Content-Type is empty), checks for "*".
// If there are multiple Content-Type headers set, choose the first one that it can
// exactly match in the registry.
// Otherwise, it follows the above logic for "*"/InboundMarshaler/OutboundMarshaler.
func MarshalerForRequest(mux *ServeMux, r *http.Request) (inbound Marshaler, outbound Marshaler) {
for _, acceptVal := range r.Header[acceptHeader] {
if m, ok := mux.marshalers.mimeMap[acceptVal]; ok {
outbound = m
break
}
}
for _, contentTypeVal := range r.Header[contentTypeHeader] {
contentType, _, err := mime.ParseMediaType(contentTypeVal)
if err != nil {
grpclog.Errorf("Failed to parse Content-Type %s: %v", contentTypeVal, err)
continue
}
if m, ok := mux.marshalers.mimeMap[contentType]; ok {
inbound = m
break
}
}
if inbound == nil {
inbound = mux.marshalers.mimeMap[MIMEWildcard]
}
if outbound == nil {
outbound = inbound
}
return inbound, outbound
}
// marshalerRegistry is a mapping from MIME types to Marshalers.
type marshalerRegistry struct {
mimeMap map[string]Marshaler
}
// add adds a marshaler for a case-sensitive MIME type string ("*" to match any
// MIME type).
func (m marshalerRegistry) add(mime string, marshaler Marshaler) error {
if len(mime) == 0 {
return errors.New("empty MIME type")
}
m.mimeMap[mime] = marshaler
return nil
}
// makeMarshalerMIMERegistry returns a new registry of marshalers.
// It allows for a mapping of case-sensitive Content-Type MIME type string to runtime.Marshaler interfaces.
//
// For example, you could allow the client to specify the use of the runtime.JSONPb marshaler
// with a "application/jsonpb" Content-Type and the use of the runtime.JSONBuiltin marshaler
// with a "application/json" Content-Type.
// "*" can be used to match any Content-Type.
// This can be attached to a ServerMux with the marshaler option.
func makeMarshalerMIMERegistry() marshalerRegistry {
return marshalerRegistry{
mimeMap: map[string]Marshaler{
MIMEWildcard: defaultMarshaler,
},
}
}
// WithMarshalerOption returns a ServeMuxOption which associates inbound and outbound
// Marshalers to a MIME type in mux.
func WithMarshalerOption(mime string, marshaler Marshaler) ServeMuxOption {
return func(mux *ServeMux) {
if err := mux.marshalers.add(mime, marshaler); err != nil {
panic(err)
}
}
}

View File

@ -0,0 +1,486 @@
package runtime
import (
"context"
"errors"
"fmt"
"net/http"
"net/textproto"
"regexp"
"strings"
"github.com/grpc-ecosystem/grpc-gateway/v2/internal/httprule"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)
// UnescapingMode defines the behavior of ServeMux when unescaping path parameters.
type UnescapingMode int
const (
// UnescapingModeLegacy is the default V2 behavior, which escapes the entire
// path string before doing any routing.
UnescapingModeLegacy UnescapingMode = iota
// UnescapingModeAllExceptReserved unescapes all path parameters except RFC 6570
// reserved characters.
UnescapingModeAllExceptReserved
// UnescapingModeAllExceptSlash unescapes URL path parameters except path
// separators, which will be left as "%2F".
UnescapingModeAllExceptSlash
// UnescapingModeAllCharacters unescapes all URL path parameters.
UnescapingModeAllCharacters
// UnescapingModeDefault is the default escaping type.
// TODO(v3): default this to UnescapingModeAllExceptReserved per grpc-httpjson-transcoding's
// reference implementation
UnescapingModeDefault = UnescapingModeLegacy
)
var encodedPathSplitter = regexp.MustCompile("(/|%2F)")
// A HandlerFunc handles a specific pair of path pattern and HTTP method.
type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string)
// ServeMux is a request multiplexer for grpc-gateway.
// It matches http requests to patterns and invokes the corresponding handler.
type ServeMux struct {
// handlers maps HTTP method to a list of handlers.
handlers map[string][]handler
forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error
marshalers marshalerRegistry
incomingHeaderMatcher HeaderMatcherFunc
outgoingHeaderMatcher HeaderMatcherFunc
outgoingTrailerMatcher HeaderMatcherFunc
metadataAnnotators []func(context.Context, *http.Request) metadata.MD
errorHandler ErrorHandlerFunc
streamErrorHandler StreamErrorHandlerFunc
routingErrorHandler RoutingErrorHandlerFunc
disablePathLengthFallback bool
unescapingMode UnescapingMode
}
// ServeMuxOption is an option that can be given to a ServeMux on construction.
type ServeMuxOption func(*ServeMux)
// WithForwardResponseOption returns a ServeMuxOption representing the forwardResponseOption.
//
// forwardResponseOption is an option that will be called on the relevant context.Context,
// http.ResponseWriter, and proto.Message before every forwarded response.
//
// The message may be nil in the case where just a header is being sent.
func WithForwardResponseOption(forwardResponseOption func(context.Context, http.ResponseWriter, proto.Message) error) ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.forwardResponseOptions = append(serveMux.forwardResponseOptions, forwardResponseOption)
}
}
// WithUnescapingMode sets the escaping type. See the definitions of UnescapingMode
// for more information.
func WithUnescapingMode(mode UnescapingMode) ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.unescapingMode = mode
}
}
// SetQueryParameterParser sets the query parameter parser, used to populate message from query parameters.
// Configuring this will mean the generated OpenAPI output is no longer correct, and it should be
// done with careful consideration.
func SetQueryParameterParser(queryParameterParser QueryParameterParser) ServeMuxOption {
return func(serveMux *ServeMux) {
currentQueryParser = queryParameterParser
}
}
// HeaderMatcherFunc checks whether a header key should be forwarded to/from gRPC context.
type HeaderMatcherFunc func(string) (string, bool)
// DefaultHeaderMatcher is used to pass http request headers to/from gRPC context. This adds permanent HTTP header
// keys (as specified by the IANA, e.g: Accept, Cookie, Host) to the gRPC metadata with the grpcgateway- prefix. If you want to know which headers are considered permanent, you can view the isPermanentHTTPHeader function.
// HTTP headers that start with 'Grpc-Metadata-' are mapped to gRPC metadata after removing the prefix 'Grpc-Metadata-'.
// Other headers are not added to the gRPC metadata.
func DefaultHeaderMatcher(key string) (string, bool) {
switch key = textproto.CanonicalMIMEHeaderKey(key); {
case isPermanentHTTPHeader(key):
return MetadataPrefix + key, true
case strings.HasPrefix(key, MetadataHeaderPrefix):
return key[len(MetadataHeaderPrefix):], true
}
return "", false
}
func defaultOutgoingHeaderMatcher(key string) (string, bool) {
return fmt.Sprintf("%s%s", MetadataHeaderPrefix, key), true
}
func defaultOutgoingTrailerMatcher(key string) (string, bool) {
return fmt.Sprintf("%s%s", MetadataTrailerPrefix, key), true
}
// WithIncomingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for incoming request to gateway.
//
// This matcher will be called with each header in http.Request. If matcher returns true, that header will be
// passed to gRPC context. To transform the header before passing to gRPC context, matcher should return the modified header.
func WithIncomingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
for _, header := range fn.matchedMalformedHeaders() {
grpclog.Warningf("The configured forwarding filter would allow %q to be sent to the gRPC server, which will likely cause errors. See https://github.com/grpc/grpc-go/pull/4803#issuecomment-986093310 for more information.", header)
}
return func(mux *ServeMux) {
mux.incomingHeaderMatcher = fn
}
}
// matchedMalformedHeaders returns the malformed headers that would be forwarded to gRPC server.
func (fn HeaderMatcherFunc) matchedMalformedHeaders() []string {
if fn == nil {
return nil
}
headers := make([]string, 0)
for header := range malformedHTTPHeaders {
out, accept := fn(header)
if accept && isMalformedHTTPHeader(out) {
headers = append(headers, out)
}
}
return headers
}
// WithOutgoingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for outgoing response from gateway.
//
// This matcher will be called with each header in response header metadata. If matcher returns true, that header will be
// passed to http response returned from gateway. To transform the header before passing to response,
// matcher should return the modified header.
func WithOutgoingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
return func(mux *ServeMux) {
mux.outgoingHeaderMatcher = fn
}
}
// WithOutgoingTrailerMatcher returns a ServeMuxOption representing a headerMatcher for outgoing response from gateway.
//
// This matcher will be called with each header in response trailer metadata. If matcher returns true, that header will be
// passed to http response returned from gateway. To transform the header before passing to response,
// matcher should return the modified header.
func WithOutgoingTrailerMatcher(fn HeaderMatcherFunc) ServeMuxOption {
return func(mux *ServeMux) {
mux.outgoingTrailerMatcher = fn
}
}
// WithMetadata returns a ServeMuxOption for passing metadata to a gRPC context.
//
// This can be used by services that need to read from http.Request and modify gRPC context. A common use case
// is reading token from cookie and adding it in gRPC context.
func WithMetadata(annotator func(context.Context, *http.Request) metadata.MD) ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.metadataAnnotators = append(serveMux.metadataAnnotators, annotator)
}
}
// WithErrorHandler returns a ServeMuxOption for configuring a custom error handler.
//
// This can be used to configure a custom error response.
func WithErrorHandler(fn ErrorHandlerFunc) ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.errorHandler = fn
}
}
// WithStreamErrorHandler returns a ServeMuxOption that will use the given custom stream
// error handler, which allows for customizing the error trailer for server-streaming
// calls.
//
// For stream errors that occur before any response has been written, the mux's
// ErrorHandler will be invoked. However, once data has been written, the errors must
// be handled differently: they must be included in the response body. The response body's
// final message will include the error details returned by the stream error handler.
func WithStreamErrorHandler(fn StreamErrorHandlerFunc) ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.streamErrorHandler = fn
}
}
// WithRoutingErrorHandler returns a ServeMuxOption for configuring a custom error handler to handle http routing errors.
//
// Method called for errors which can happen before gRPC route selected or executed.
// The following error codes: StatusMethodNotAllowed StatusNotFound StatusBadRequest
func WithRoutingErrorHandler(fn RoutingErrorHandlerFunc) ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.routingErrorHandler = fn
}
}
// WithDisablePathLengthFallback returns a ServeMuxOption for disable path length fallback.
func WithDisablePathLengthFallback() ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.disablePathLengthFallback = true
}
}
// WithHealthEndpointAt returns a ServeMuxOption that will add an endpoint to the created ServeMux at the path specified by endpointPath.
// When called the handler will forward the request to the upstream grpc service health check (defined in the
// gRPC Health Checking Protocol).
//
// See here https://grpc-ecosystem.github.io/grpc-gateway/docs/operations/health_check/ for more information on how
// to setup the protocol in the grpc server.
//
// If you define a service as query parameter, this will also be forwarded as service in the HealthCheckRequest.
func WithHealthEndpointAt(healthCheckClient grpc_health_v1.HealthClient, endpointPath string) ServeMuxOption {
return func(s *ServeMux) {
// error can be ignored since pattern is definitely valid
_ = s.HandlePath(
http.MethodGet, endpointPath, func(w http.ResponseWriter, r *http.Request, _ map[string]string,
) {
_, outboundMarshaler := MarshalerForRequest(s, r)
resp, err := healthCheckClient.Check(r.Context(), &grpc_health_v1.HealthCheckRequest{
Service: r.URL.Query().Get("service"),
})
if err != nil {
s.errorHandler(r.Context(), s, outboundMarshaler, w, r, err)
return
}
w.Header().Set("Content-Type", "application/json")
if resp.GetStatus() != grpc_health_v1.HealthCheckResponse_SERVING {
switch resp.GetStatus() {
case grpc_health_v1.HealthCheckResponse_NOT_SERVING, grpc_health_v1.HealthCheckResponse_UNKNOWN:
err = status.Error(codes.Unavailable, resp.String())
case grpc_health_v1.HealthCheckResponse_SERVICE_UNKNOWN:
err = status.Error(codes.NotFound, resp.String())
}
s.errorHandler(r.Context(), s, outboundMarshaler, w, r, err)
return
}
_ = outboundMarshaler.NewEncoder(w).Encode(resp)
})
}
}
// WithHealthzEndpoint returns a ServeMuxOption that will add a /healthz endpoint to the created ServeMux.
//
// See WithHealthEndpointAt for the general implementation.
func WithHealthzEndpoint(healthCheckClient grpc_health_v1.HealthClient) ServeMuxOption {
return WithHealthEndpointAt(healthCheckClient, "/healthz")
}
// NewServeMux returns a new ServeMux whose internal mapping is empty.
func NewServeMux(opts ...ServeMuxOption) *ServeMux {
serveMux := &ServeMux{
handlers: make(map[string][]handler),
forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0),
marshalers: makeMarshalerMIMERegistry(),
errorHandler: DefaultHTTPErrorHandler,
streamErrorHandler: DefaultStreamErrorHandler,
routingErrorHandler: DefaultRoutingErrorHandler,
unescapingMode: UnescapingModeDefault,
}
for _, opt := range opts {
opt(serveMux)
}
if serveMux.incomingHeaderMatcher == nil {
serveMux.incomingHeaderMatcher = DefaultHeaderMatcher
}
if serveMux.outgoingHeaderMatcher == nil {
serveMux.outgoingHeaderMatcher = defaultOutgoingHeaderMatcher
}
if serveMux.outgoingTrailerMatcher == nil {
serveMux.outgoingTrailerMatcher = defaultOutgoingTrailerMatcher
}
return serveMux
}
// Handle associates "h" to the pair of HTTP method and path pattern.
func (s *ServeMux) Handle(meth string, pat Pattern, h HandlerFunc) {
s.handlers[meth] = append([]handler{{pat: pat, h: h}}, s.handlers[meth]...)
}
// HandlePath allows users to configure custom path handlers.
// refer: https://grpc-ecosystem.github.io/grpc-gateway/docs/operations/inject_router/
func (s *ServeMux) HandlePath(meth string, pathPattern string, h HandlerFunc) error {
compiler, err := httprule.Parse(pathPattern)
if err != nil {
return fmt.Errorf("parsing path pattern: %w", err)
}
tp := compiler.Compile()
pattern, err := NewPattern(tp.Version, tp.OpCodes, tp.Pool, tp.Verb)
if err != nil {
return fmt.Errorf("creating new pattern: %w", err)
}
s.Handle(meth, pattern, h)
return nil
}
// ServeHTTP dispatches the request to the first handler whose pattern matches to r.Method and r.URL.Path.
func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
path := r.URL.Path
if !strings.HasPrefix(path, "/") {
_, outboundMarshaler := MarshalerForRequest(s, r)
s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusBadRequest)
return
}
// TODO(v3): remove UnescapingModeLegacy
if s.unescapingMode != UnescapingModeLegacy && r.URL.RawPath != "" {
path = r.URL.RawPath
}
if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && s.isPathLengthFallback(r) {
if err := r.ParseForm(); err != nil {
_, outboundMarshaler := MarshalerForRequest(s, r)
sterr := status.Error(codes.InvalidArgument, err.Error())
s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr)
return
}
r.Method = strings.ToUpper(override)
}
var pathComponents []string
// since in UnescapeModeLegacy, the URL will already have been fully unescaped, if we also split on "%2F"
// in this escaping mode we would be double unescaping but in UnescapingModeAllCharacters, we still do as the
// path is the RawPath (i.e. unescaped). That does mean that the behavior of this function will change its default
// behavior when the UnescapingModeDefault gets changed from UnescapingModeLegacy to UnescapingModeAllExceptReserved
if s.unescapingMode == UnescapingModeAllCharacters {
pathComponents = encodedPathSplitter.Split(path[1:], -1)
} else {
pathComponents = strings.Split(path[1:], "/")
}
lastPathComponent := pathComponents[len(pathComponents)-1]
for _, h := range s.handlers[r.Method] {
// If the pattern has a verb, explicitly look for a suffix in the last
// component that matches a colon plus the verb. This allows us to
// handle some cases that otherwise can't be correctly handled by the
// former LastIndex case, such as when the verb literal itself contains
// a colon. This should work for all cases that have run through the
// parser because we know what verb we're looking for, however, there
// are still some cases that the parser itself cannot disambiguate. See
// the comment there if interested.
var verb string
patVerb := h.pat.Verb()
idx := -1
if patVerb != "" && strings.HasSuffix(lastPathComponent, ":"+patVerb) {
idx = len(lastPathComponent) - len(patVerb) - 1
}
if idx == 0 {
_, outboundMarshaler := MarshalerForRequest(s, r)
s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusNotFound)
return
}
comps := make([]string, len(pathComponents))
copy(comps, pathComponents)
if idx > 0 {
comps[len(comps)-1], verb = lastPathComponent[:idx], lastPathComponent[idx+1:]
}
pathParams, err := h.pat.MatchAndEscape(comps, verb, s.unescapingMode)
if err != nil {
var mse MalformedSequenceError
if ok := errors.As(err, &mse); ok {
_, outboundMarshaler := MarshalerForRequest(s, r)
s.errorHandler(ctx, s, outboundMarshaler, w, r, &HTTPStatusError{
HTTPStatus: http.StatusBadRequest,
Err: mse,
})
}
continue
}
h.h(w, r, pathParams)
return
}
// if no handler has found for the request, lookup for other methods
// to handle POST -> GET fallback if the request is subject to path
// length fallback.
// Note we are not eagerly checking the request here as we want to return the
// right HTTP status code, and we need to process the fallback candidates in
// order to do that.
for m, handlers := range s.handlers {
if m == r.Method {
continue
}
for _, h := range handlers {
var verb string
patVerb := h.pat.Verb()
idx := -1
if patVerb != "" && strings.HasSuffix(lastPathComponent, ":"+patVerb) {
idx = len(lastPathComponent) - len(patVerb) - 1
}
comps := make([]string, len(pathComponents))
copy(comps, pathComponents)
if idx > 0 {
comps[len(comps)-1], verb = lastPathComponent[:idx], lastPathComponent[idx+1:]
}
pathParams, err := h.pat.MatchAndEscape(comps, verb, s.unescapingMode)
if err != nil {
var mse MalformedSequenceError
if ok := errors.As(err, &mse); ok {
_, outboundMarshaler := MarshalerForRequest(s, r)
s.errorHandler(ctx, s, outboundMarshaler, w, r, &HTTPStatusError{
HTTPStatus: http.StatusBadRequest,
Err: mse,
})
}
continue
}
// X-HTTP-Method-Override is optional. Always allow fallback to POST.
// Also, only consider POST -> GET fallbacks, and avoid falling back to
// potentially dangerous operations like DELETE.
if s.isPathLengthFallback(r) && m == http.MethodGet {
if err := r.ParseForm(); err != nil {
_, outboundMarshaler := MarshalerForRequest(s, r)
sterr := status.Error(codes.InvalidArgument, err.Error())
s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr)
return
}
h.h(w, r, pathParams)
return
}
_, outboundMarshaler := MarshalerForRequest(s, r)
s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusMethodNotAllowed)
return
}
}
_, outboundMarshaler := MarshalerForRequest(s, r)
s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusNotFound)
}
// GetForwardResponseOptions returns the ForwardResponseOptions associated with this ServeMux.
func (s *ServeMux) GetForwardResponseOptions() []func(context.Context, http.ResponseWriter, proto.Message) error {
return s.forwardResponseOptions
}
func (s *ServeMux) isPathLengthFallback(r *http.Request) bool {
return !s.disablePathLengthFallback && r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded"
}
type handler struct {
pat Pattern
h HandlerFunc
}

View File

@ -0,0 +1,381 @@
package runtime
import (
"errors"
"fmt"
"strconv"
"strings"
"github.com/grpc-ecosystem/grpc-gateway/v2/utilities"
"google.golang.org/grpc/grpclog"
)
var (
// ErrNotMatch indicates that the given HTTP request path does not match to the pattern.
ErrNotMatch = errors.New("not match to the path pattern")
// ErrInvalidPattern indicates that the given definition of Pattern is not valid.
ErrInvalidPattern = errors.New("invalid pattern")
)
type MalformedSequenceError string
func (e MalformedSequenceError) Error() string {
return "malformed path escape " + strconv.Quote(string(e))
}
type op struct {
code utilities.OpCode
operand int
}
// Pattern is a template pattern of http request paths defined in
// https://github.com/googleapis/googleapis/blob/master/google/api/http.proto
type Pattern struct {
// ops is a list of operations
ops []op
// pool is a constant pool indexed by the operands or vars.
pool []string
// vars is a list of variables names to be bound by this pattern
vars []string
// stacksize is the max depth of the stack
stacksize int
// tailLen is the length of the fixed-size segments after a deep wildcard
tailLen int
// verb is the VERB part of the path pattern. It is empty if the pattern does not have VERB part.
verb string
}
// NewPattern returns a new Pattern from the given definition values.
// "ops" is a sequence of op codes. "pool" is a constant pool.
// "verb" is the verb part of the pattern. It is empty if the pattern does not have the part.
// "version" must be 1 for now.
// It returns an error if the given definition is invalid.
func NewPattern(version int, ops []int, pool []string, verb string) (Pattern, error) {
if version != 1 {
grpclog.Errorf("unsupported version: %d", version)
return Pattern{}, ErrInvalidPattern
}
l := len(ops)
if l%2 != 0 {
grpclog.Errorf("odd number of ops codes: %d", l)
return Pattern{}, ErrInvalidPattern
}
var (
typedOps []op
stack, maxstack int
tailLen int
pushMSeen bool
vars []string
)
for i := 0; i < l; i += 2 {
op := op{code: utilities.OpCode(ops[i]), operand: ops[i+1]}
switch op.code {
case utilities.OpNop:
continue
case utilities.OpPush:
if pushMSeen {
tailLen++
}
stack++
case utilities.OpPushM:
if pushMSeen {
grpclog.Error("pushM appears twice")
return Pattern{}, ErrInvalidPattern
}
pushMSeen = true
stack++
case utilities.OpLitPush:
if op.operand < 0 || len(pool) <= op.operand {
grpclog.Errorf("negative literal index: %d", op.operand)
return Pattern{}, ErrInvalidPattern
}
if pushMSeen {
tailLen++
}
stack++
case utilities.OpConcatN:
if op.operand <= 0 {
grpclog.Errorf("negative concat size: %d", op.operand)
return Pattern{}, ErrInvalidPattern
}
stack -= op.operand
if stack < 0 {
grpclog.Error("stack underflow")
return Pattern{}, ErrInvalidPattern
}
stack++
case utilities.OpCapture:
if op.operand < 0 || len(pool) <= op.operand {
grpclog.Errorf("variable name index out of bound: %d", op.operand)
return Pattern{}, ErrInvalidPattern
}
v := pool[op.operand]
op.operand = len(vars)
vars = append(vars, v)
stack--
if stack < 0 {
grpclog.Error("stack underflow")
return Pattern{}, ErrInvalidPattern
}
default:
grpclog.Errorf("invalid opcode: %d", op.code)
return Pattern{}, ErrInvalidPattern
}
if maxstack < stack {
maxstack = stack
}
typedOps = append(typedOps, op)
}
return Pattern{
ops: typedOps,
pool: pool,
vars: vars,
stacksize: maxstack,
tailLen: tailLen,
verb: verb,
}, nil
}
// MustPattern is a helper function which makes it easier to call NewPattern in variable initialization.
func MustPattern(p Pattern, err error) Pattern {
if err != nil {
grpclog.Fatalf("Pattern initialization failed: %v", err)
}
return p
}
// MatchAndEscape examines components to determine if they match to a Pattern.
// MatchAndEscape will return an error if no Patterns matched or if a pattern
// matched but contained malformed escape sequences. If successful, the function
// returns a mapping from field paths to their captured values.
func (p Pattern) MatchAndEscape(components []string, verb string, unescapingMode UnescapingMode) (map[string]string, error) {
if p.verb != verb {
if p.verb != "" {
return nil, ErrNotMatch
}
if len(components) == 0 {
components = []string{":" + verb}
} else {
components = append([]string{}, components...)
components[len(components)-1] += ":" + verb
}
}
var pos int
stack := make([]string, 0, p.stacksize)
captured := make([]string, len(p.vars))
l := len(components)
for _, op := range p.ops {
var err error
switch op.code {
case utilities.OpNop:
continue
case utilities.OpPush, utilities.OpLitPush:
if pos >= l {
return nil, ErrNotMatch
}
c := components[pos]
if op.code == utilities.OpLitPush {
if lit := p.pool[op.operand]; c != lit {
return nil, ErrNotMatch
}
} else if op.code == utilities.OpPush {
if c, err = unescape(c, unescapingMode, false); err != nil {
return nil, err
}
}
stack = append(stack, c)
pos++
case utilities.OpPushM:
end := len(components)
if end < pos+p.tailLen {
return nil, ErrNotMatch
}
end -= p.tailLen
c := strings.Join(components[pos:end], "/")
if c, err = unescape(c, unescapingMode, true); err != nil {
return nil, err
}
stack = append(stack, c)
pos = end
case utilities.OpConcatN:
n := op.operand
l := len(stack) - n
stack = append(stack[:l], strings.Join(stack[l:], "/"))
case utilities.OpCapture:
n := len(stack) - 1
captured[op.operand] = stack[n]
stack = stack[:n]
}
}
if pos < l {
return nil, ErrNotMatch
}
bindings := make(map[string]string)
for i, val := range captured {
bindings[p.vars[i]] = val
}
return bindings, nil
}
// MatchAndEscape examines components to determine if they match to a Pattern.
// It will never perform per-component unescaping (see: UnescapingModeLegacy).
// MatchAndEscape will return an error if no Patterns matched. If successful,
// the function returns a mapping from field paths to their captured values.
//
// Deprecated: Use MatchAndEscape.
func (p Pattern) Match(components []string, verb string) (map[string]string, error) {
return p.MatchAndEscape(components, verb, UnescapingModeDefault)
}
// Verb returns the verb part of the Pattern.
func (p Pattern) Verb() string { return p.verb }
func (p Pattern) String() string {
var stack []string
for _, op := range p.ops {
switch op.code {
case utilities.OpNop:
continue
case utilities.OpPush:
stack = append(stack, "*")
case utilities.OpLitPush:
stack = append(stack, p.pool[op.operand])
case utilities.OpPushM:
stack = append(stack, "**")
case utilities.OpConcatN:
n := op.operand
l := len(stack) - n
stack = append(stack[:l], strings.Join(stack[l:], "/"))
case utilities.OpCapture:
n := len(stack) - 1
stack[n] = fmt.Sprintf("{%s=%s}", p.vars[op.operand], stack[n])
}
}
segs := strings.Join(stack, "/")
if p.verb != "" {
return fmt.Sprintf("/%s:%s", segs, p.verb)
}
return "/" + segs
}
/*
* The following code is adopted and modified from Go's standard library
* and carries the attached license.
*
* Copyright 2009 The Go Authors. All rights reserved.
* Use of this source code is governed by a BSD-style
* license that can be found in the LICENSE file.
*/
// ishex returns whether or not the given byte is a valid hex character
func ishex(c byte) bool {
switch {
case '0' <= c && c <= '9':
return true
case 'a' <= c && c <= 'f':
return true
case 'A' <= c && c <= 'F':
return true
}
return false
}
func isRFC6570Reserved(c byte) bool {
switch c {
case '!', '#', '$', '&', '\'', '(', ')', '*',
'+', ',', '/', ':', ';', '=', '?', '@', '[', ']':
return true
default:
return false
}
}
// unhex converts a hex point to the bit representation
func unhex(c byte) byte {
switch {
case '0' <= c && c <= '9':
return c - '0'
case 'a' <= c && c <= 'f':
return c - 'a' + 10
case 'A' <= c && c <= 'F':
return c - 'A' + 10
}
return 0
}
// shouldUnescapeWithMode returns true if the character is escapable with the
// given mode
func shouldUnescapeWithMode(c byte, mode UnescapingMode) bool {
switch mode {
case UnescapingModeAllExceptReserved:
if isRFC6570Reserved(c) {
return false
}
case UnescapingModeAllExceptSlash:
if c == '/' {
return false
}
case UnescapingModeAllCharacters:
return true
}
return true
}
// unescape unescapes a path string using the provided mode
func unescape(s string, mode UnescapingMode, multisegment bool) (string, error) {
// TODO(v3): remove UnescapingModeLegacy
if mode == UnescapingModeLegacy {
return s, nil
}
if !multisegment {
mode = UnescapingModeAllCharacters
}
// Count %, check that they're well-formed.
n := 0
for i := 0; i < len(s); {
if s[i] == '%' {
n++
if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) {
s = s[i:]
if len(s) > 3 {
s = s[:3]
}
return "", MalformedSequenceError(s)
}
i += 3
} else {
i++
}
}
if n == 0 {
return s, nil
}
var t strings.Builder
t.Grow(len(s))
for i := 0; i < len(s); i++ {
switch s[i] {
case '%':
c := unhex(s[i+1])<<4 | unhex(s[i+2])
if shouldUnescapeWithMode(c, mode) {
t.WriteByte(c)
i += 2
continue
}
fallthrough
default:
t.WriteByte(s[i])
}
}
return t.String(), nil
}

View File

@ -0,0 +1,80 @@
package runtime
import (
"google.golang.org/protobuf/proto"
)
// StringP returns a pointer to a string whose pointee is same as the given string value.
func StringP(val string) (*string, error) {
return proto.String(val), nil
}
// BoolP parses the given string representation of a boolean value,
// and returns a pointer to a bool whose value is same as the parsed value.
func BoolP(val string) (*bool, error) {
b, err := Bool(val)
if err != nil {
return nil, err
}
return proto.Bool(b), nil
}
// Float64P parses the given string representation of a floating point number,
// and returns a pointer to a float64 whose value is same as the parsed number.
func Float64P(val string) (*float64, error) {
f, err := Float64(val)
if err != nil {
return nil, err
}
return proto.Float64(f), nil
}
// Float32P parses the given string representation of a floating point number,
// and returns a pointer to a float32 whose value is same as the parsed number.
func Float32P(val string) (*float32, error) {
f, err := Float32(val)
if err != nil {
return nil, err
}
return proto.Float32(f), nil
}
// Int64P parses the given string representation of an integer
// and returns a pointer to a int64 whose value is same as the parsed integer.
func Int64P(val string) (*int64, error) {
i, err := Int64(val)
if err != nil {
return nil, err
}
return proto.Int64(i), nil
}
// Int32P parses the given string representation of an integer
// and returns a pointer to a int32 whose value is same as the parsed integer.
func Int32P(val string) (*int32, error) {
i, err := Int32(val)
if err != nil {
return nil, err
}
return proto.Int32(i), err
}
// Uint64P parses the given string representation of an integer
// and returns a pointer to a uint64 whose value is same as the parsed integer.
func Uint64P(val string) (*uint64, error) {
i, err := Uint64(val)
if err != nil {
return nil, err
}
return proto.Uint64(i), err
}
// Uint32P parses the given string representation of an integer
// and returns a pointer to a uint32 whose value is same as the parsed integer.
func Uint32P(val string) (*uint32, error) {
i, err := Uint32(val)
if err != nil {
return nil, err
}
return proto.Uint32(i), err
}

View File

@ -0,0 +1,372 @@
package runtime
import (
"errors"
"fmt"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"github.com/grpc-ecosystem/grpc-gateway/v2/utilities"
"google.golang.org/grpc/grpclog"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/known/durationpb"
field_mask "google.golang.org/protobuf/types/known/fieldmaskpb"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"
"google.golang.org/protobuf/types/known/wrapperspb"
)
var valuesKeyRegexp = regexp.MustCompile(`^(.*)\[(.*)\]$`)
var currentQueryParser QueryParameterParser = &DefaultQueryParser{}
// QueryParameterParser defines interface for all query parameter parsers
type QueryParameterParser interface {
Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error
}
// PopulateQueryParameters parses query parameters
// into "msg" using current query parser
func PopulateQueryParameters(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error {
return currentQueryParser.Parse(msg, values, filter)
}
// DefaultQueryParser is a QueryParameterParser which implements the default
// query parameters parsing behavior.
//
// See https://github.com/grpc-ecosystem/grpc-gateway/issues/2632 for more context.
type DefaultQueryParser struct{}
// Parse populates "values" into "msg".
// A value is ignored if its key starts with one of the elements in "filter".
func (*DefaultQueryParser) Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error {
for key, values := range values {
if match := valuesKeyRegexp.FindStringSubmatch(key); len(match) == 3 {
key = match[1]
values = append([]string{match[2]}, values...)
}
msgValue := msg.ProtoReflect()
fieldPath := normalizeFieldPath(msgValue, strings.Split(key, "."))
if filter.HasCommonPrefix(fieldPath) {
continue
}
if err := populateFieldValueFromPath(msgValue, fieldPath, values); err != nil {
return err
}
}
return nil
}
// PopulateFieldFromPath sets a value in a nested Protobuf structure.
func PopulateFieldFromPath(msg proto.Message, fieldPathString string, value string) error {
fieldPath := strings.Split(fieldPathString, ".")
return populateFieldValueFromPath(msg.ProtoReflect(), fieldPath, []string{value})
}
func normalizeFieldPath(msgValue protoreflect.Message, fieldPath []string) []string {
newFieldPath := make([]string, 0, len(fieldPath))
for i, fieldName := range fieldPath {
fields := msgValue.Descriptor().Fields()
fieldDesc := fields.ByTextName(fieldName)
if fieldDesc == nil {
fieldDesc = fields.ByJSONName(fieldName)
}
if fieldDesc == nil {
// return initial field path values if no matching message field was found
return fieldPath
}
newFieldPath = append(newFieldPath, string(fieldDesc.Name()))
// If this is the last element, we're done
if i == len(fieldPath)-1 {
break
}
// Only singular message fields are allowed
if fieldDesc.Message() == nil || fieldDesc.Cardinality() == protoreflect.Repeated {
return fieldPath
}
// Get the nested message
msgValue = msgValue.Get(fieldDesc).Message()
}
return newFieldPath
}
func populateFieldValueFromPath(msgValue protoreflect.Message, fieldPath []string, values []string) error {
if len(fieldPath) < 1 {
return errors.New("no field path")
}
if len(values) < 1 {
return errors.New("no value provided")
}
var fieldDescriptor protoreflect.FieldDescriptor
for i, fieldName := range fieldPath {
fields := msgValue.Descriptor().Fields()
// Get field by name
fieldDescriptor = fields.ByName(protoreflect.Name(fieldName))
if fieldDescriptor == nil {
fieldDescriptor = fields.ByJSONName(fieldName)
if fieldDescriptor == nil {
// We're not returning an error here because this could just be
// an extra query parameter that isn't part of the request.
grpclog.Infof("field not found in %q: %q", msgValue.Descriptor().FullName(), strings.Join(fieldPath, "."))
return nil
}
}
// If this is the last element, we're done
if i == len(fieldPath)-1 {
break
}
// Only singular message fields are allowed
if fieldDescriptor.Message() == nil || fieldDescriptor.Cardinality() == protoreflect.Repeated {
return fmt.Errorf("invalid path: %q is not a message", fieldName)
}
// Get the nested message
msgValue = msgValue.Mutable(fieldDescriptor).Message()
}
// Check if oneof already set
if of := fieldDescriptor.ContainingOneof(); of != nil {
if f := msgValue.WhichOneof(of); f != nil {
return fmt.Errorf("field already set for oneof %q", of.FullName().Name())
}
}
switch {
case fieldDescriptor.IsList():
return populateRepeatedField(fieldDescriptor, msgValue.Mutable(fieldDescriptor).List(), values)
case fieldDescriptor.IsMap():
return populateMapField(fieldDescriptor, msgValue.Mutable(fieldDescriptor).Map(), values)
}
if len(values) > 1 {
return fmt.Errorf("too many values for field %q: %s", fieldDescriptor.FullName().Name(), strings.Join(values, ", "))
}
return populateField(fieldDescriptor, msgValue, values[0])
}
func populateField(fieldDescriptor protoreflect.FieldDescriptor, msgValue protoreflect.Message, value string) error {
v, err := parseField(fieldDescriptor, value)
if err != nil {
return fmt.Errorf("parsing field %q: %w", fieldDescriptor.FullName().Name(), err)
}
msgValue.Set(fieldDescriptor, v)
return nil
}
func populateRepeatedField(fieldDescriptor protoreflect.FieldDescriptor, list protoreflect.List, values []string) error {
for _, value := range values {
v, err := parseField(fieldDescriptor, value)
if err != nil {
return fmt.Errorf("parsing list %q: %w", fieldDescriptor.FullName().Name(), err)
}
list.Append(v)
}
return nil
}
func populateMapField(fieldDescriptor protoreflect.FieldDescriptor, mp protoreflect.Map, values []string) error {
if len(values) != 2 {
return fmt.Errorf("more than one value provided for key %q in map %q", values[0], fieldDescriptor.FullName())
}
key, err := parseField(fieldDescriptor.MapKey(), values[0])
if err != nil {
return fmt.Errorf("parsing map key %q: %w", fieldDescriptor.FullName().Name(), err)
}
value, err := parseField(fieldDescriptor.MapValue(), values[1])
if err != nil {
return fmt.Errorf("parsing map value %q: %w", fieldDescriptor.FullName().Name(), err)
}
mp.Set(key.MapKey(), value)
return nil
}
func parseField(fieldDescriptor protoreflect.FieldDescriptor, value string) (protoreflect.Value, error) {
switch fieldDescriptor.Kind() {
case protoreflect.BoolKind:
v, err := strconv.ParseBool(value)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfBool(v), nil
case protoreflect.EnumKind:
enum, err := protoregistry.GlobalTypes.FindEnumByName(fieldDescriptor.Enum().FullName())
if err != nil {
if errors.Is(err, protoregistry.NotFound) {
return protoreflect.Value{}, fmt.Errorf("enum %q is not registered", fieldDescriptor.Enum().FullName())
}
return protoreflect.Value{}, fmt.Errorf("failed to look up enum: %w", err)
}
// Look for enum by name
v := enum.Descriptor().Values().ByName(protoreflect.Name(value))
if v == nil {
i, err := strconv.Atoi(value)
if err != nil {
return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value)
}
// Look for enum by number
if v = enum.Descriptor().Values().ByNumber(protoreflect.EnumNumber(i)); v == nil {
return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value)
}
}
return protoreflect.ValueOfEnum(v.Number()), nil
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
v, err := strconv.ParseInt(value, 10, 32)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfInt32(int32(v)), nil
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
v, err := strconv.ParseInt(value, 10, 64)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfInt64(v), nil
case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
v, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfUint32(uint32(v)), nil
case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
v, err := strconv.ParseUint(value, 10, 64)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfUint64(v), nil
case protoreflect.FloatKind:
v, err := strconv.ParseFloat(value, 32)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfFloat32(float32(v)), nil
case protoreflect.DoubleKind:
v, err := strconv.ParseFloat(value, 64)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfFloat64(v), nil
case protoreflect.StringKind:
return protoreflect.ValueOfString(value), nil
case protoreflect.BytesKind:
v, err := Bytes(value)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfBytes(v), nil
case protoreflect.MessageKind, protoreflect.GroupKind:
return parseMessage(fieldDescriptor.Message(), value)
default:
panic(fmt.Sprintf("unknown field kind: %v", fieldDescriptor.Kind()))
}
}
func parseMessage(msgDescriptor protoreflect.MessageDescriptor, value string) (protoreflect.Value, error) {
var msg proto.Message
switch msgDescriptor.FullName() {
case "google.protobuf.Timestamp":
t, err := time.Parse(time.RFC3339Nano, value)
if err != nil {
return protoreflect.Value{}, err
}
msg = timestamppb.New(t)
case "google.protobuf.Duration":
d, err := time.ParseDuration(value)
if err != nil {
return protoreflect.Value{}, err
}
msg = durationpb.New(d)
case "google.protobuf.DoubleValue":
v, err := strconv.ParseFloat(value, 64)
if err != nil {
return protoreflect.Value{}, err
}
msg = wrapperspb.Double(v)
case "google.protobuf.FloatValue":
v, err := strconv.ParseFloat(value, 32)
if err != nil {
return protoreflect.Value{}, err
}
msg = wrapperspb.Float(float32(v))
case "google.protobuf.Int64Value":
v, err := strconv.ParseInt(value, 10, 64)
if err != nil {
return protoreflect.Value{}, err
}
msg = wrapperspb.Int64(v)
case "google.protobuf.Int32Value":
v, err := strconv.ParseInt(value, 10, 32)
if err != nil {
return protoreflect.Value{}, err
}
msg = wrapperspb.Int32(int32(v))
case "google.protobuf.UInt64Value":
v, err := strconv.ParseUint(value, 10, 64)
if err != nil {
return protoreflect.Value{}, err
}
msg = wrapperspb.UInt64(v)
case "google.protobuf.UInt32Value":
v, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return protoreflect.Value{}, err
}
msg = wrapperspb.UInt32(uint32(v))
case "google.protobuf.BoolValue":
v, err := strconv.ParseBool(value)
if err != nil {
return protoreflect.Value{}, err
}
msg = wrapperspb.Bool(v)
case "google.protobuf.StringValue":
msg = wrapperspb.String(value)
case "google.protobuf.BytesValue":
v, err := Bytes(value)
if err != nil {
return protoreflect.Value{}, err
}
msg = wrapperspb.Bytes(v)
case "google.protobuf.FieldMask":
fm := &field_mask.FieldMask{}
fm.Paths = append(fm.Paths, strings.Split(value, ",")...)
msg = fm
case "google.protobuf.Value":
var v structpb.Value
if err := protojson.Unmarshal([]byte(value), &v); err != nil {
return protoreflect.Value{}, err
}
msg = &v
case "google.protobuf.Struct":
var v structpb.Struct
if err := protojson.Unmarshal([]byte(value), &v); err != nil {
return protoreflect.Value{}, err
}
msg = &v
default:
return protoreflect.Value{}, fmt.Errorf("unsupported message type: %q", string(msgDescriptor.FullName()))
}
return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil
}

View File

@ -0,0 +1,31 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
package(default_visibility = ["//visibility:public"])
go_library(
name = "utilities",
srcs = [
"doc.go",
"pattern.go",
"readerfactory.go",
"string_array_flag.go",
"trie.go",
],
importpath = "github.com/grpc-ecosystem/grpc-gateway/v2/utilities",
)
go_test(
name = "utilities_test",
size = "small",
srcs = [
"string_array_flag_test.go",
"trie_test.go",
],
deps = [":utilities"],
)
alias(
name = "go_default_library",
actual = ":utilities",
visibility = ["//visibility:public"],
)

View File

@ -0,0 +1,2 @@
// Package utilities provides members for internal use in grpc-gateway.
package utilities

View File

@ -0,0 +1,22 @@
package utilities
// An OpCode is a opcode of compiled path patterns.
type OpCode int
// These constants are the valid values of OpCode.
const (
// OpNop does nothing
OpNop = OpCode(iota)
// OpPush pushes a component to stack
OpPush
// OpLitPush pushes a component to stack if it matches to the literal
OpLitPush
// OpPushM concatenates the remaining components and pushes it to stack
OpPushM
// OpConcatN pops N items from stack, concatenates them and pushes it back to stack
OpConcatN
// OpCapture pops an item and binds it to the variable
OpCapture
// OpEnd is the least positive invalid opcode.
OpEnd
)

View File

@ -0,0 +1,19 @@
package utilities
import (
"bytes"
"io"
)
// IOReaderFactory takes in an io.Reader and returns a function that will allow you to create a new reader that begins
// at the start of the stream
func IOReaderFactory(r io.Reader) (func() io.Reader, error) {
b, err := io.ReadAll(r)
if err != nil {
return nil, err
}
return func() io.Reader {
return bytes.NewReader(b)
}, nil
}

View File

@ -0,0 +1,33 @@
package utilities
import (
"flag"
"strings"
)
// flagInterface is an cut down interface to `flag`
type flagInterface interface {
Var(value flag.Value, name string, usage string)
}
// StringArrayFlag defines a flag with the specified name and usage string.
// The return value is the address of a `StringArrayFlags` variable that stores the repeated values of the flag.
func StringArrayFlag(f flagInterface, name string, usage string) *StringArrayFlags {
value := &StringArrayFlags{}
f.Var(value, name, usage)
return value
}
// StringArrayFlags is a wrapper of `[]string` to provider an interface for `flag.Var`
type StringArrayFlags []string
// String returns a string representation of `StringArrayFlags`
func (i *StringArrayFlags) String() string {
return strings.Join(*i, ",")
}
// Set appends a value to `StringArrayFlags`
func (i *StringArrayFlags) Set(value string) error {
*i = append(*i, value)
return nil
}

View File

@ -0,0 +1,174 @@
package utilities
import (
"sort"
)
// DoubleArray is a Double Array implementation of trie on sequences of strings.
type DoubleArray struct {
// Encoding keeps an encoding from string to int
Encoding map[string]int
// Base is the base array of Double Array
Base []int
// Check is the check array of Double Array
Check []int
}
// NewDoubleArray builds a DoubleArray from a set of sequences of strings.
func NewDoubleArray(seqs [][]string) *DoubleArray {
da := &DoubleArray{Encoding: make(map[string]int)}
if len(seqs) == 0 {
return da
}
encoded := registerTokens(da, seqs)
sort.Sort(byLex(encoded))
root := node{row: -1, col: -1, left: 0, right: len(encoded)}
addSeqs(da, encoded, 0, root)
for i := len(da.Base); i > 0; i-- {
if da.Check[i-1] != 0 {
da.Base = da.Base[:i]
da.Check = da.Check[:i]
break
}
}
return da
}
func registerTokens(da *DoubleArray, seqs [][]string) [][]int {
var result [][]int
for _, seq := range seqs {
encoded := make([]int, 0, len(seq))
for _, token := range seq {
if _, ok := da.Encoding[token]; !ok {
da.Encoding[token] = len(da.Encoding)
}
encoded = append(encoded, da.Encoding[token])
}
result = append(result, encoded)
}
for i := range result {
result[i] = append(result[i], len(da.Encoding))
}
return result
}
type node struct {
row, col int
left, right int
}
func (n node) value(seqs [][]int) int {
return seqs[n.row][n.col]
}
func (n node) children(seqs [][]int) []*node {
var result []*node
lastVal := int(-1)
last := new(node)
for i := n.left; i < n.right; i++ {
if lastVal == seqs[i][n.col+1] {
continue
}
last.right = i
last = &node{
row: i,
col: n.col + 1,
left: i,
}
result = append(result, last)
}
last.right = n.right
return result
}
func addSeqs(da *DoubleArray, seqs [][]int, pos int, n node) {
ensureSize(da, pos)
children := n.children(seqs)
var i int
for i = 1; ; i++ {
ok := func() bool {
for _, child := range children {
code := child.value(seqs)
j := i + code
ensureSize(da, j)
if da.Check[j] != 0 {
return false
}
}
return true
}()
if ok {
break
}
}
da.Base[pos] = i
for _, child := range children {
code := child.value(seqs)
j := i + code
da.Check[j] = pos + 1
}
terminator := len(da.Encoding)
for _, child := range children {
code := child.value(seqs)
if code == terminator {
continue
}
j := i + code
addSeqs(da, seqs, j, *child)
}
}
func ensureSize(da *DoubleArray, i int) {
for i >= len(da.Base) {
da.Base = append(da.Base, make([]int, len(da.Base)+1)...)
da.Check = append(da.Check, make([]int, len(da.Check)+1)...)
}
}
type byLex [][]int
func (l byLex) Len() int { return len(l) }
func (l byLex) Swap(i, j int) { l[i], l[j] = l[j], l[i] }
func (l byLex) Less(i, j int) bool {
si := l[i]
sj := l[j]
var k int
for k = 0; k < len(si) && k < len(sj); k++ {
if si[k] < sj[k] {
return true
}
if si[k] > sj[k] {
return false
}
}
return k < len(sj)
}
// HasCommonPrefix determines if any sequence in the DoubleArray is a prefix of the given sequence.
func (da *DoubleArray) HasCommonPrefix(seq []string) bool {
if len(da.Base) == 0 {
return false
}
var i int
for _, t := range seq {
code, ok := da.Encoding[t]
if !ok {
break
}
j := da.Base[i] + code
if len(da.Check) <= j || da.Check[j] != i+1 {
break
}
i = j
}
j := da.Base[i] + len(da.Encoding)
if len(da.Check) <= j || da.Check[j] != i+1 {
return false
}
return true
}