clear vendor

Mikaël Cluseau 2019-10-17 14:21:08 +11:00
parent c0fc7bbe3d
commit 49c73be97a
953 changed files with 0 additions and 203404 deletions

View File

@ -1,5 +0,0 @@
*.sublime-*
.DS_Store
*.swp
*.swo
tags

View File

@ -1,7 +0,0 @@
language: go
go:
- 1.4
- 1.5
- 1.6
- tip

View File

@ -1,12 +0,0 @@
Copyright (c) 2012, Martin Angers
All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
* Neither the name of the author nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -1,187 +0,0 @@
# Purell
Purell is a tiny Go library to normalize URLs. It returns a pure URL. Pure-ell. Sanitizer and all. Yeah, I know...
Based on the [Wikipedia URL normalization article][wiki] and [RFC 3986][rfc].
[![build status](https://secure.travis-ci.org/PuerkitoBio/purell.png)](http://travis-ci.org/PuerkitoBio/purell)
## Install
`go get github.com/PuerkitoBio/purell`
## Changelog
* **2016-11-14 (v1.1.0)** : IDN: Conform to RFC 5895: Fold character width (thanks to @beeker1121).
* **2016-07-27 (v1.0.0)** : Normalize IDN to ASCII (thanks to @zenovich).
* **2015-02-08** : Add fix for relative paths issue ([PR #5][pr5]) and add fix for unnecessary encoding of reserved characters ([see issue #7][iss7]).
* **v0.2.0** : Add benchmarks, Attempt IDN support.
* **v0.1.0** : Initial release.
## Examples
From `example_test.go` (note that in your code, you would import "github.com/PuerkitoBio/purell", and would prefix references to its functions and constants with "purell."):
```go
package purell
import (
"fmt"
"net/url"
)
func ExampleNormalizeURLString() {
if normalized, err := NormalizeURLString("hTTp://someWEBsite.com:80/Amazing%3f/url/",
FlagLowercaseScheme|FlagLowercaseHost|FlagUppercaseEscapes); err != nil {
panic(err)
} else {
fmt.Print(normalized)
}
// Output: http://somewebsite.com:80/Amazing%3F/url/
}
func ExampleMustNormalizeURLString() {
normalized := MustNormalizeURLString("hTTpS://someWEBsite.com:443/Amazing%fa/url/",
FlagsUnsafeGreedy)
fmt.Print(normalized)
// Output: http://somewebsite.com/Amazing%FA/url
}
func ExampleNormalizeURL() {
if u, err := url.Parse("Http://SomeUrl.com:8080/a/b/.././c///g?c=3&a=1&b=9&c=0#target"); err != nil {
panic(err)
} else {
normalized := NormalizeURL(u, FlagsUsuallySafeGreedy|FlagRemoveDuplicateSlashes|FlagRemoveFragment)
fmt.Print(normalized)
}
// Output: http://someurl.com:8080/a/c/g?c=3&a=1&b=9&c=0
}
```
## API
As seen in the examples above, purell offers three functions: `NormalizeURLString(string, NormalizationFlags) (string, error)`, `MustNormalizeURLString(string, NormalizationFlags) string` and `NormalizeURL(*url.URL, NormalizationFlags) string`. They all normalize the provided URL based on the specified flags. Here are the available flags:
```go
const (
// Safe normalizations
FlagLowercaseScheme NormalizationFlags = 1 << iota // HTTP://host -> http://host, applied by default in Go1.1
FlagLowercaseHost // http://HOST -> http://host
FlagUppercaseEscapes // http://host/t%ef -> http://host/t%EF
FlagDecodeUnnecessaryEscapes // http://host/t%41 -> http://host/tA
FlagEncodeNecessaryEscapes // http://host/!"#$ -> http://host/%21%22#$
FlagRemoveDefaultPort // http://host:80 -> http://host
FlagRemoveEmptyQuerySeparator // http://host/path? -> http://host/path
// Usually safe normalizations
FlagRemoveTrailingSlash // http://host/path/ -> http://host/path
FlagAddTrailingSlash // http://host/path -> http://host/path/ (should choose only one of these add/remove trailing slash flags)
FlagRemoveDotSegments // http://host/path/./a/b/../c -> http://host/path/a/c
// Unsafe normalizations
FlagRemoveDirectoryIndex // http://host/path/index.html -> http://host/path/
FlagRemoveFragment // http://host/path#fragment -> http://host/path
FlagForceHTTP // https://host -> http://host
FlagRemoveDuplicateSlashes // http://host/path//a///b -> http://host/path/a/b
FlagRemoveWWW // http://www.host/ -> http://host/
FlagAddWWW // http://host/ -> http://www.host/ (should choose only one of these add/remove WWW flags)
FlagSortQuery // http://host/path?c=3&b=2&a=1&b=1 -> http://host/path?a=1&b=1&b=2&c=3
// Normalizations not in the wikipedia article, required to cover test cases
// submitted by jehiah
FlagDecodeDWORDHost // http://1113982867 -> http://66.102.7.147
FlagDecodeOctalHost // http://0102.0146.07.0223 -> http://66.102.7.147
FlagDecodeHexHost // http://0x42660793 -> http://66.102.7.147
FlagRemoveUnnecessaryHostDots // http://.host../path -> http://host/path
FlagRemoveEmptyPortSeparator // http://host:/path -> http://host/path
// Convenience set of safe normalizations
FlagsSafe NormalizationFlags = FlagLowercaseHost | FlagLowercaseScheme | FlagUppercaseEscapes | FlagDecodeUnnecessaryEscapes | FlagEncodeNecessaryEscapes | FlagRemoveDefaultPort | FlagRemoveEmptyQuerySeparator
// For convenience sets, "greedy" uses the "remove trailing slash" and "remove www. prefix" flags,
// while "non-greedy" uses the "add (or keep) the trailing slash" and "add www. prefix".
// Convenience set of usually safe normalizations (includes FlagsSafe)
FlagsUsuallySafeGreedy NormalizationFlags = FlagsSafe | FlagRemoveTrailingSlash | FlagRemoveDotSegments
FlagsUsuallySafeNonGreedy NormalizationFlags = FlagsSafe | FlagAddTrailingSlash | FlagRemoveDotSegments
// Convenience set of unsafe normalizations (includes FlagsUsuallySafe)
FlagsUnsafeGreedy NormalizationFlags = FlagsUsuallySafeGreedy | FlagRemoveDirectoryIndex | FlagRemoveFragment | FlagForceHTTP | FlagRemoveDuplicateSlashes | FlagRemoveWWW | FlagSortQuery
FlagsUnsafeNonGreedy NormalizationFlags = FlagsUsuallySafeNonGreedy | FlagRemoveDirectoryIndex | FlagRemoveFragment | FlagForceHTTP | FlagRemoveDuplicateSlashes | FlagAddWWW | FlagSortQuery
// Convenience set of all available flags
FlagsAllGreedy = FlagsUnsafeGreedy | FlagDecodeDWORDHost | FlagDecodeOctalHost | FlagDecodeHexHost | FlagRemoveUnnecessaryHostDots | FlagRemoveEmptyPortSeparator
FlagsAllNonGreedy = FlagsUnsafeNonGreedy | FlagDecodeDWORDHost | FlagDecodeOctalHost | FlagDecodeHexHost | FlagRemoveUnnecessaryHostDots | FlagRemoveEmptyPortSeparator
)
```
For convenience, the flag sets `FlagsSafe`, `FlagsUsuallySafe[Greedy|NonGreedy]`, `FlagsUnsafe[Greedy|NonGreedy]` and `FlagsAll[Greedy|NonGreedy]` are provided, matching the similarly grouped normalizations on [wikipedia's URL normalization page][wiki]. You can add (using the bitwise OR `|` operator) or remove (using the bitwise AND NOT `&^` operator) individual flags from the sets if required, to build your own custom set.
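For example, here is a minimal sketch of a custom set that keeps the trailing slash from `FlagsUsuallySafeGreedy` but still sorts the query (the input URL below is made up for illustration):
```go
package main

import (
	"fmt"

	"github.com/PuerkitoBio/purell"
)

func main() {
	// Start from the usually-safe greedy set, drop the trailing-slash removal
	// with &^, and add query sorting with |.
	flags := (purell.FlagsUsuallySafeGreedy &^ purell.FlagRemoveTrailingSlash) | purell.FlagSortQuery

	normalized, err := purell.NormalizeURLString("HTTP://Example.com/a/./b/?b=2&a=1", flags)
	if err != nil {
		panic(err)
	}
	fmt.Println(normalized) // expected: http://example.com/a/b/?a=1&b=2
}
```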
The [full godoc reference is available on gopkgdoc][godoc].
Some things to note:
* `FlagDecodeUnnecessaryEscapes`, `FlagEncodeNecessaryEscapes`, `FlagUppercaseEscapes` and `FlagRemoveEmptyQuerySeparator` are always implicitly set, because internally, the URL string is parsed as a URL object, which automatically decodes unnecessary escapes, uppercases and encodes necessary ones, and removes empty query separators (an unnecessary `?` at the end of the URL). So this operation cannot **not** be done. For this reason, `FlagRemoveEmptyQuerySeparator` (as well as the other three) has been included in the `FlagsSafe` convenience set, instead of `FlagsUnsafe`, where Wikipedia puts it.
* The `FlagDecodeUnnecessaryEscapes` decodes the following escapes (*from -> to*):
- %24 -> $
- %26 -> &
- %2B-%3B -> +,-./0123456789:;
- %3D -> =
- %40-%5A -> @ABCDEFGHIJKLMNOPQRSTUVWXYZ
- %5F -> _
- %61-%7A -> abcdefghijklmnopqrstuvwxyz
- %7E -> ~
* When the `NormalizeURL` function is used (passing a URL object), the source URL object is modified in place (that is, after the call, the URL object reflects the normalization).
* The *replace IP with domain name* normalization (`http://208.77.188.166/ → http://www.example.com/`) is obviously not possible for a library without making some network requests. This is not implemented in purell.
* The *remove unused query string parameters* and *remove default query parameters* normalizations are also not implemented, since they are very case-specific, and quite trivial to do with a URL object.
### Safe vs Usually Safe vs Unsafe
Purell allows you to control the level of risk you take while normalizing a URL. You can aggressively normalize, play it totally safe, or anything in between.
Consider the following URL:
`HTTPS://www.RooT.com/toto/t%45%1f///a/./b/../c/?z=3&w=2&a=4&w=1#invalid`
Normalizing with the `FlagsSafe` gives:
`https://www.root.com/toto/tE%1F///a/./b/../c/?z=3&w=2&a=4&w=1#invalid`
With the `FlagsUsuallySafeGreedy`:
`https://www.root.com/toto/tE%1F///a/c?z=3&w=2&a=4&w=1#invalid`
And with `FlagsUnsafeGreedy`:
`http://root.com/toto/tE%1F/a/c?a=4&w=1&w=2&z=3`
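These three results can be reproduced with a short sketch (assuming `purell` is imported and the snippet runs inside `main`, as in the example above):
```go
rawURL := "HTTPS://www.RooT.com/toto/t%45%1f///a/./b/../c/?z=3&w=2&a=4&w=1#invalid"

for _, f := range []purell.NormalizationFlags{
	purell.FlagsSafe,
	purell.FlagsUsuallySafeGreedy,
	purell.FlagsUnsafeGreedy,
} {
	fmt.Println(purell.MustNormalizeURLString(rawURL, f))
}
```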
## TODOs
* Add a class/default instance to allow specifying custom directory index names? At the moment, removing directory index removes `(^|/)((?:default|index)\.\w{1,4})$`.
## Thanks / Contributions
@rogpeppe
@jehiah
@opennota
@pchristopher1275
@zenovich
@beeker1121
## License
The [BSD 3-Clause license][bsd].
[bsd]: http://opensource.org/licenses/BSD-3-Clause
[wiki]: http://en.wikipedia.org/wiki/URL_normalization
[rfc]: http://tools.ietf.org/html/rfc3986#section-6
[godoc]: http://go.pkgdoc.org/github.com/PuerkitoBio/purell
[pr5]: https://github.com/PuerkitoBio/purell/pull/5
[iss7]: https://github.com/PuerkitoBio/purell/issues/7

View File

@ -1,379 +0,0 @@
/*
Package purell offers URL normalization as described on the wikipedia page:
http://en.wikipedia.org/wiki/URL_normalization
*/
package purell
import (
"bytes"
"fmt"
"net/url"
"regexp"
"sort"
"strconv"
"strings"
"github.com/PuerkitoBio/urlesc"
"golang.org/x/net/idna"
"golang.org/x/text/unicode/norm"
"golang.org/x/text/width"
)
// NormalizationFlags is a set of flags that determines how a URL will
// be normalized.
type NormalizationFlags uint
const (
// Safe normalizations
FlagLowercaseScheme NormalizationFlags = 1 << iota // HTTP://host -> http://host, applied by default in Go1.1
FlagLowercaseHost // http://HOST -> http://host
FlagUppercaseEscapes // http://host/t%ef -> http://host/t%EF
FlagDecodeUnnecessaryEscapes // http://host/t%41 -> http://host/tA
FlagEncodeNecessaryEscapes // http://host/!"#$ -> http://host/%21%22#$
FlagRemoveDefaultPort // http://host:80 -> http://host
FlagRemoveEmptyQuerySeparator // http://host/path? -> http://host/path
// Usually safe normalizations
FlagRemoveTrailingSlash // http://host/path/ -> http://host/path
FlagAddTrailingSlash // http://host/path -> http://host/path/ (should choose only one of these add/remove trailing slash flags)
FlagRemoveDotSegments // http://host/path/./a/b/../c -> http://host/path/a/c
// Unsafe normalizations
FlagRemoveDirectoryIndex // http://host/path/index.html -> http://host/path/
FlagRemoveFragment // http://host/path#fragment -> http://host/path
FlagForceHTTP // https://host -> http://host
FlagRemoveDuplicateSlashes // http://host/path//a///b -> http://host/path/a/b
FlagRemoveWWW // http://www.host/ -> http://host/
FlagAddWWW // http://host/ -> http://www.host/ (should choose only one of these add/remove WWW flags)
FlagSortQuery // http://host/path?c=3&b=2&a=1&b=1 -> http://host/path?a=1&b=1&b=2&c=3
// Normalizations not in the wikipedia article, required to cover test cases
// submitted by jehiah
FlagDecodeDWORDHost // http://1113982867 -> http://66.102.7.147
FlagDecodeOctalHost // http://0102.0146.07.0223 -> http://66.102.7.147
FlagDecodeHexHost // http://0x42660793 -> http://66.102.7.147
FlagRemoveUnnecessaryHostDots // http://.host../path -> http://host/path
FlagRemoveEmptyPortSeparator // http://host:/path -> http://host/path
// Convenience set of safe normalizations
FlagsSafe NormalizationFlags = FlagLowercaseHost | FlagLowercaseScheme | FlagUppercaseEscapes | FlagDecodeUnnecessaryEscapes | FlagEncodeNecessaryEscapes | FlagRemoveDefaultPort | FlagRemoveEmptyQuerySeparator
// For convenience sets, "greedy" uses the "remove trailing slash" and "remove www. prefix" flags,
// while "non-greedy" uses the "add (or keep) the trailing slash" and "add www. prefix".
// Convenience set of usually safe normalizations (includes FlagsSafe)
FlagsUsuallySafeGreedy NormalizationFlags = FlagsSafe | FlagRemoveTrailingSlash | FlagRemoveDotSegments
FlagsUsuallySafeNonGreedy NormalizationFlags = FlagsSafe | FlagAddTrailingSlash | FlagRemoveDotSegments
// Convenience set of unsafe normalizations (includes FlagsUsuallySafe)
FlagsUnsafeGreedy NormalizationFlags = FlagsUsuallySafeGreedy | FlagRemoveDirectoryIndex | FlagRemoveFragment | FlagForceHTTP | FlagRemoveDuplicateSlashes | FlagRemoveWWW | FlagSortQuery
FlagsUnsafeNonGreedy NormalizationFlags = FlagsUsuallySafeNonGreedy | FlagRemoveDirectoryIndex | FlagRemoveFragment | FlagForceHTTP | FlagRemoveDuplicateSlashes | FlagAddWWW | FlagSortQuery
// Convenience set of all available flags
FlagsAllGreedy = FlagsUnsafeGreedy | FlagDecodeDWORDHost | FlagDecodeOctalHost | FlagDecodeHexHost | FlagRemoveUnnecessaryHostDots | FlagRemoveEmptyPortSeparator
FlagsAllNonGreedy = FlagsUnsafeNonGreedy | FlagDecodeDWORDHost | FlagDecodeOctalHost | FlagDecodeHexHost | FlagRemoveUnnecessaryHostDots | FlagRemoveEmptyPortSeparator
)
const (
defaultHttpPort = ":80"
defaultHttpsPort = ":443"
)
// Regular expressions used by the normalizations
var rxPort = regexp.MustCompile(`(:\d+)/?$`)
var rxDirIndex = regexp.MustCompile(`(^|/)((?:default|index)\.\w{1,4})$`)
var rxDupSlashes = regexp.MustCompile(`/{2,}`)
var rxDWORDHost = regexp.MustCompile(`^(\d+)((?:\.+)?(?:\:\d*)?)$`)
var rxOctalHost = regexp.MustCompile(`^(0\d*)\.(0\d*)\.(0\d*)\.(0\d*)((?:\.+)?(?:\:\d*)?)$`)
var rxHexHost = regexp.MustCompile(`^0x([0-9A-Fa-f]+)((?:\.+)?(?:\:\d*)?)$`)
var rxHostDots = regexp.MustCompile(`^(.+?)(:\d+)?$`)
var rxEmptyPort = regexp.MustCompile(`:+$`)
// Map of flags to implementation function.
// FlagDecodeUnnecessaryEscapes has no action, since it is done automatically
// by parsing the string as a URL. Same for FlagUppercaseEscapes and FlagRemoveEmptyQuerySeparator.
// Since map iteration order is undefined, keep a slice of ordered keys
var flagsOrder = []NormalizationFlags{
FlagLowercaseScheme,
FlagLowercaseHost,
FlagRemoveDefaultPort,
FlagRemoveDirectoryIndex,
FlagRemoveDotSegments,
FlagRemoveFragment,
FlagForceHTTP, // Must be after remove default port (because https=443/http=80)
FlagRemoveDuplicateSlashes,
FlagRemoveWWW,
FlagAddWWW,
FlagSortQuery,
FlagDecodeDWORDHost,
FlagDecodeOctalHost,
FlagDecodeHexHost,
FlagRemoveUnnecessaryHostDots,
FlagRemoveEmptyPortSeparator,
FlagRemoveTrailingSlash, // These two (add/remove trailing slash) must be last
FlagAddTrailingSlash,
}
// ... and then the map, where order is unimportant
var flags = map[NormalizationFlags]func(*url.URL){
FlagLowercaseScheme: lowercaseScheme,
FlagLowercaseHost: lowercaseHost,
FlagRemoveDefaultPort: removeDefaultPort,
FlagRemoveDirectoryIndex: removeDirectoryIndex,
FlagRemoveDotSegments: removeDotSegments,
FlagRemoveFragment: removeFragment,
FlagForceHTTP: forceHTTP,
FlagRemoveDuplicateSlashes: removeDuplicateSlashes,
FlagRemoveWWW: removeWWW,
FlagAddWWW: addWWW,
FlagSortQuery: sortQuery,
FlagDecodeDWORDHost: decodeDWORDHost,
FlagDecodeOctalHost: decodeOctalHost,
FlagDecodeHexHost: decodeHexHost,
FlagRemoveUnnecessaryHostDots: removeUnncessaryHostDots,
FlagRemoveEmptyPortSeparator: removeEmptyPortSeparator,
FlagRemoveTrailingSlash: removeTrailingSlash,
FlagAddTrailingSlash: addTrailingSlash,
}
// MustNormalizeURLString returns the normalized string, and panics if an error occurs.
// It takes a URL string as input, as well as the normalization flags.
func MustNormalizeURLString(u string, f NormalizationFlags) string {
result, e := NormalizeURLString(u, f)
if e != nil {
panic(e)
}
return result
}
// NormalizeURLString returns the normalized string, or an error if it can't be parsed into a URL object.
// It takes a URL string as input, as well as the normalization flags.
func NormalizeURLString(u string, f NormalizationFlags) (string, error) {
parsed, err := url.Parse(u)
if err != nil {
return "", err
}
if f&FlagLowercaseHost == FlagLowercaseHost {
parsed.Host = strings.ToLower(parsed.Host)
}
// The idna package doesn't fully conform to RFC 5895
// (https://tools.ietf.org/html/rfc5895), so we do it here.
// Taken from Go 1.8 cycle source, courtesy of bradfitz.
// TODO: Remove when (if?) idna package conforms to RFC 5895.
parsed.Host = width.Fold.String(parsed.Host)
parsed.Host = norm.NFC.String(parsed.Host)
if parsed.Host, err = idna.ToASCII(parsed.Host); err != nil {
return "", err
}
return NormalizeURL(parsed, f), nil
}
// NormalizeURL returns the normalized string.
// It takes a parsed URL object as input, as well as the normalization flags.
func NormalizeURL(u *url.URL, f NormalizationFlags) string {
for _, k := range flagsOrder {
if f&k == k {
flags[k](u)
}
}
return urlesc.Escape(u)
}
func lowercaseScheme(u *url.URL) {
if len(u.Scheme) > 0 {
u.Scheme = strings.ToLower(u.Scheme)
}
}
func lowercaseHost(u *url.URL) {
if len(u.Host) > 0 {
u.Host = strings.ToLower(u.Host)
}
}
func removeDefaultPort(u *url.URL) {
if len(u.Host) > 0 {
scheme := strings.ToLower(u.Scheme)
u.Host = rxPort.ReplaceAllStringFunc(u.Host, func(val string) string {
if (scheme == "http" && val == defaultHttpPort) || (scheme == "https" && val == defaultHttpsPort) {
return ""
}
return val
})
}
}
func removeTrailingSlash(u *url.URL) {
if l := len(u.Path); l > 0 {
if strings.HasSuffix(u.Path, "/") {
u.Path = u.Path[:l-1]
}
} else if l = len(u.Host); l > 0 {
if strings.HasSuffix(u.Host, "/") {
u.Host = u.Host[:l-1]
}
}
}
func addTrailingSlash(u *url.URL) {
if l := len(u.Path); l > 0 {
if !strings.HasSuffix(u.Path, "/") {
u.Path += "/"
}
} else if l = len(u.Host); l > 0 {
if !strings.HasSuffix(u.Host, "/") {
u.Host += "/"
}
}
}
func removeDotSegments(u *url.URL) {
if len(u.Path) > 0 {
var dotFree []string
var lastIsDot bool
sections := strings.Split(u.Path, "/")
for _, s := range sections {
if s == ".." {
if len(dotFree) > 0 {
dotFree = dotFree[:len(dotFree)-1]
}
} else if s != "." {
dotFree = append(dotFree, s)
}
lastIsDot = (s == "." || s == "..")
}
// Special case if host does not end with / and new path does not begin with /
u.Path = strings.Join(dotFree, "/")
if u.Host != "" && !strings.HasSuffix(u.Host, "/") && !strings.HasPrefix(u.Path, "/") {
u.Path = "/" + u.Path
}
// Special case if the last segment was a dot, make sure the path ends with a slash
if lastIsDot && !strings.HasSuffix(u.Path, "/") {
u.Path += "/"
}
}
}
func removeDirectoryIndex(u *url.URL) {
if len(u.Path) > 0 {
u.Path = rxDirIndex.ReplaceAllString(u.Path, "$1")
}
}
func removeFragment(u *url.URL) {
u.Fragment = ""
}
func forceHTTP(u *url.URL) {
if strings.ToLower(u.Scheme) == "https" {
u.Scheme = "http"
}
}
func removeDuplicateSlashes(u *url.URL) {
if len(u.Path) > 0 {
u.Path = rxDupSlashes.ReplaceAllString(u.Path, "/")
}
}
func removeWWW(u *url.URL) {
if len(u.Host) > 0 && strings.HasPrefix(strings.ToLower(u.Host), "www.") {
u.Host = u.Host[4:]
}
}
func addWWW(u *url.URL) {
if len(u.Host) > 0 && !strings.HasPrefix(strings.ToLower(u.Host), "www.") {
u.Host = "www." + u.Host
}
}
func sortQuery(u *url.URL) {
q := u.Query()
if len(q) > 0 {
arKeys := make([]string, len(q))
i := 0
for k := range q {
arKeys[i] = k
i++
}
sort.Strings(arKeys)
buf := new(bytes.Buffer)
for _, k := range arKeys {
sort.Strings(q[k])
for _, v := range q[k] {
if buf.Len() > 0 {
buf.WriteRune('&')
}
buf.WriteString(fmt.Sprintf("%s=%s", k, urlesc.QueryEscape(v)))
}
}
// Rebuild the raw query string
u.RawQuery = buf.String()
}
}
func decodeDWORDHost(u *url.URL) {
if len(u.Host) > 0 {
if matches := rxDWORDHost.FindStringSubmatch(u.Host); len(matches) > 2 {
var parts [4]int64
dword, _ := strconv.ParseInt(matches[1], 10, 0)
for i, shift := range []uint{24, 16, 8, 0} {
parts[i] = dword >> shift & 0xFF
}
u.Host = fmt.Sprintf("%d.%d.%d.%d%s", parts[0], parts[1], parts[2], parts[3], matches[2])
}
}
}
func decodeOctalHost(u *url.URL) {
if len(u.Host) > 0 {
if matches := rxOctalHost.FindStringSubmatch(u.Host); len(matches) > 5 {
var parts [4]int64
for i := 1; i <= 4; i++ {
parts[i-1], _ = strconv.ParseInt(matches[i], 8, 0)
}
u.Host = fmt.Sprintf("%d.%d.%d.%d%s", parts[0], parts[1], parts[2], parts[3], matches[5])
}
}
}
func decodeHexHost(u *url.URL) {
if len(u.Host) > 0 {
if matches := rxHexHost.FindStringSubmatch(u.Host); len(matches) > 2 {
// Conversion is safe because of regex validation
parsed, _ := strconv.ParseInt(matches[1], 16, 0)
// Set host as DWORD (base 10) encoded host
u.Host = fmt.Sprintf("%d%s", parsed, matches[2])
// The rest is the same as decoding a DWORD host
decodeDWORDHost(u)
}
}
}
func removeUnncessaryHostDots(u *url.URL) {
if len(u.Host) > 0 {
if matches := rxHostDots.FindStringSubmatch(u.Host); len(matches) > 1 {
// Trim the leading and trailing dots
u.Host = strings.Trim(matches[1], ".")
if len(matches) > 2 {
u.Host += matches[2]
}
}
}
}
func removeEmptyPortSeparator(u *url.URL) {
if len(u.Host) > 0 {
u.Host = rxEmptyPort.ReplaceAllString(u.Host, "")
}
}

View File

@ -1,15 +0,0 @@
language: go
go:
- 1.4.x
- 1.5.x
- 1.6.x
- 1.7.x
- 1.8.x
- tip
install:
- go build .
script:
- go test -v

View File

@ -1,27 +0,0 @@
Copyright (c) 2012 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -1,16 +0,0 @@
urlesc [![Build Status](https://travis-ci.org/PuerkitoBio/urlesc.svg?branch=master)](https://travis-ci.org/PuerkitoBio/urlesc) [![GoDoc](http://godoc.org/github.com/PuerkitoBio/urlesc?status.svg)](http://godoc.org/github.com/PuerkitoBio/urlesc)
======
Package urlesc implements query escaping as per RFC 3986.
It contains some parts of the net/url package, modified to allow
some reserved characters that net/url incorrectly escapes (see [issue 5684](https://github.com/golang/go/issues/5684)).
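For a minimal sketch of the difference (assuming both packages are importable), compare query escaping of the reserved characters `/` and `?`, which RFC 3986 allows unescaped in a query:
```go
package main

import (
	"fmt"
	"net/url"

	"github.com/PuerkitoBio/urlesc"
)

func main() {
	fmt.Println(url.QueryEscape("a/b?c"))    // a%2Fb%3Fc (net/url escapes them)
	fmt.Println(urlesc.QueryEscape("a/b?c")) // a/b?c (urlesc leaves them alone)
}
```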
## Install
go get github.com/PuerkitoBio/urlesc
## License
Go license (BSD-3-Clause)

View File

@ -1,180 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package urlesc implements query escaping as per RFC 3986.
// It contains some parts of the net/url package, modified to allow
// some reserved characters that net/url incorrectly escapes.
// See https://github.com/golang/go/issues/5684
package urlesc
import (
"bytes"
"net/url"
"strings"
)
type encoding int
const (
encodePath encoding = 1 + iota
encodeUserPassword
encodeQueryComponent
encodeFragment
)
// Return true if the specified character should be escaped when
// appearing in a URL string, according to RFC 3986.
func shouldEscape(c byte, mode encoding) bool {
// §2.3 Unreserved characters (alphanum)
if 'A' <= c && c <= 'Z' || 'a' <= c && c <= 'z' || '0' <= c && c <= '9' {
return false
}
switch c {
case '-', '.', '_', '~': // §2.3 Unreserved characters (mark)
return false
// §2.2 Reserved characters (reserved)
case ':', '/', '?', '#', '[', ']', '@', // gen-delims
'!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=': // sub-delims
// Different sections of the URL allow a few of
// the reserved characters to appear unescaped.
switch mode {
case encodePath: // §3.3
// The RFC allows sub-delims and : @.
// '/', '[' and ']' can be used to assign meaning to individual path
// segments. This package only manipulates the path as a whole,
// so we allow those as well. That leaves only ? and # to escape.
return c == '?' || c == '#'
case encodeUserPassword: // §3.2.1
// The RFC allows : and sub-delims in
// userinfo. The parsing of userinfo treats ':' as special so we must escape
// all the gen-delims.
return c == ':' || c == '/' || c == '?' || c == '#' || c == '[' || c == ']' || c == '@'
case encodeQueryComponent: // §3.4
// The RFC allows / and ?.
return c != '/' && c != '?'
case encodeFragment: // §4.1
// The RFC text is silent but the grammar allows
// everything, so escape nothing but #
return c == '#'
}
}
// Everything else must be escaped.
return true
}
// QueryEscape escapes the string so it can be safely placed
// inside a URL query.
func QueryEscape(s string) string {
return escape(s, encodeQueryComponent)
}
func escape(s string, mode encoding) string {
spaceCount, hexCount := 0, 0
for i := 0; i < len(s); i++ {
c := s[i]
if shouldEscape(c, mode) {
if c == ' ' && mode == encodeQueryComponent {
spaceCount++
} else {
hexCount++
}
}
}
if spaceCount == 0 && hexCount == 0 {
return s
}
t := make([]byte, len(s)+2*hexCount)
j := 0
for i := 0; i < len(s); i++ {
switch c := s[i]; {
case c == ' ' && mode == encodeQueryComponent:
t[j] = '+'
j++
case shouldEscape(c, mode):
t[j] = '%'
t[j+1] = "0123456789ABCDEF"[c>>4]
t[j+2] = "0123456789ABCDEF"[c&15]
j += 3
default:
t[j] = s[i]
j++
}
}
return string(t)
}
var uiReplacer = strings.NewReplacer(
"%21", "!",
"%27", "'",
"%28", "(",
"%29", ")",
"%2A", "*",
)
// unescapeUserinfo unescapes some characters that need not be escaped, as per RFC 3986.
func unescapeUserinfo(s string) string {
return uiReplacer.Replace(s)
}
// Escape reassembles the URL into a valid URL string.
// The general form of the result is one of:
//
// scheme:opaque
// scheme://userinfo@host/path?query#fragment
//
// If u.Opaque is non-empty, String uses the first form;
// otherwise it uses the second form.
//
// In the second form, the following rules apply:
// - if u.Scheme is empty, scheme: is omitted.
// - if u.User is nil, userinfo@ is omitted.
// - if u.Host is empty, host/ is omitted.
// - if u.Scheme and u.Host are empty and u.User is nil,
// the entire scheme://userinfo@host/ is omitted.
// - if u.Host is non-empty and u.Path begins with a /,
// the form host/path does not add its own /.
// - if u.RawQuery is empty, ?query is omitted.
// - if u.Fragment is empty, #fragment is omitted.
func Escape(u *url.URL) string {
var buf bytes.Buffer
if u.Scheme != "" {
buf.WriteString(u.Scheme)
buf.WriteByte(':')
}
if u.Opaque != "" {
buf.WriteString(u.Opaque)
} else {
if u.Scheme != "" || u.Host != "" || u.User != nil {
buf.WriteString("//")
if ui := u.User; ui != nil {
buf.WriteString(unescapeUserinfo(ui.String()))
buf.WriteByte('@')
}
if h := u.Host; h != "" {
buf.WriteString(h)
}
}
if u.Path != "" && u.Path[0] != '/' && u.Host != "" {
buf.WriteByte('/')
}
buf.WriteString(escape(u.Path, encodePath))
}
if u.RawQuery != "" {
buf.WriteByte('?')
buf.WriteString(u.RawQuery)
}
if u.Fragment != "" {
buf.WriteByte('#')
buf.WriteString(escape(u.Fragment, encodeFragment))
}
return buf.String()
}

View File

@ -1,3 +0,0 @@
.fuzz/
*.zip

View File

@ -1,10 +0,0 @@
language: go
go:
- 1.4.3
- 1.5.4
- 1.6.4
- 1.7.6
- 1.8.3
script: make check

View File

@ -1,26 +0,0 @@
Copyright (c) 2017 Ryan Armstrong. All rights reserved.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its contributors
may be used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -1,18 +0,0 @@
PACKAGE = github.com/cavaliercoder/go-cpio
all: check
check:
go test -v
cpio-fuzz.zip: *.go
go-fuzz-build $(PACKAGE)
fuzz: cpio-fuzz.zip
go-fuzz -bin=./cpio-fuzz.zip -workdir=.fuzz/
clean-fuzz:
rm -rf cpio-fuzz.zip .fuzz/crashers/* .fuzz/suppressions/*
.PHONY: all check

View File

@ -1,62 +0,0 @@
# go-cpio [![GoDoc](https://godoc.org/github.com/cavaliercoder/go-cpio?status.svg)](https://godoc.org/github.com/cavaliercoder/go-cpio) [![Build Status](https://travis-ci.org/cavaliercoder/go-cpio.svg?branch=master)](https://travis-ci.org/cavaliercoder/go-cpio) [![Go Report Card](https://goreportcard.com/badge/github.com/cavaliercoder/go-cpio)](https://goreportcard.com/report/github.com/cavaliercoder/go-cpio)
This package provides a Go native implementation of the CPIO archive file
format.
Currently, only the SVR4 (New ASCII) format is supported, both with and without
checksums.
```go
// Create a buffer to write our archive to.
buf := new(bytes.Buffer)
// Create a new cpio archive.
w := cpio.NewWriter(buf)
// Add some files to the archive.
var files = []struct {
Name, Body string
}{
{"readme.txt", "This archive contains some text files."},
{"gopher.txt", "Gopher names:\nGeorge\nGeoffrey\nGonzo"},
{"todo.txt", "Get animal handling license."},
}
for _, file := range files {
hdr := &cpio.Header{
Name: file.Name,
Mode: 0600,
Size: int64(len(file.Body)),
}
if err := w.WriteHeader(hdr); err != nil {
log.Fatalln(err)
}
if _, err := w.Write([]byte(file.Body)); err != nil {
log.Fatalln(err)
}
}
// Make sure to check the error on Close.
if err := w.Close(); err != nil {
log.Fatalln(err)
}
// Open the cpio archive for reading.
b := bytes.NewReader(buf.Bytes())
r := cpio.NewReader(b)
// Iterate through the files in the archive.
for {
hdr, err := r.Next()
if err == io.EOF {
// end of cpio archive
break
}
if err != nil {
log.Fatalln(err)
}
fmt.Printf("Contents of %s:\n", hdr.Name)
if _, err := io.Copy(os.Stdout, r); err != nil {
log.Fatalln(err)
}
fmt.Println()
}
```

View File

@ -1,8 +0,0 @@
/*
Package cpio implements access to CPIO archives. Currently, only the SVR4 (New
ASCII) format is supported, both with and without checksums.
References:
https://www.freebsd.org/cgi/man.cgi?query=cpio&sektion=5
*/
package cpio

View File

@ -1,75 +0,0 @@
package cpio
import (
"os"
"path"
"time"
)
// headerFileInfo implements os.FileInfo.
type headerFileInfo struct {
h *Header
}
// Name returns the base name of the file.
func (fi headerFileInfo) Name() string {
if fi.IsDir() {
return path.Base(path.Clean(fi.h.Name))
}
return path.Base(fi.h.Name)
}
func (fi headerFileInfo) Size() int64 { return fi.h.Size }
func (fi headerFileInfo) IsDir() bool { return fi.Mode().IsDir() }
func (fi headerFileInfo) ModTime() time.Time { return fi.h.ModTime }
func (fi headerFileInfo) Sys() interface{} { return fi.h }
func (fi headerFileInfo) Mode() (mode os.FileMode) {
// Set file permission bits.
mode = os.FileMode(fi.h.Mode).Perm()
// Set setuid, setgid and sticky bits.
if fi.h.Mode&ModeSetuid != 0 {
// setuid
mode |= os.ModeSetuid
}
if fi.h.Mode&ModeSetgid != 0 {
// setgid
mode |= os.ModeSetgid
}
if fi.h.Mode&ModeSticky != 0 {
// sticky
mode |= os.ModeSticky
}
// Set file mode bits.
// clear perm, setuid, setgid and sticky bits.
m := os.FileMode(fi.h.Mode) & 0170000
if m == ModeDir {
// directory
mode |= os.ModeDir
}
if m == ModeNamedPipe {
// named pipe (FIFO)
mode |= os.ModeNamedPipe
}
if m == ModeSymlink {
// symbolic link
mode |= os.ModeSymlink
}
if m == ModeDevice {
// device file
mode |= os.ModeDevice
}
if m == ModeCharDevice {
// Unix character device
mode |= os.ModeDevice
mode |= os.ModeCharDevice
}
if m == ModeSocket {
// Unix domain socket
mode |= os.ModeSocket
}
return mode
}

View File

@ -1,35 +0,0 @@
// +build gofuzz
package cpio
import "bytes"
import "io"
// Fuzz tests the parsing and error handling of random byte arrays using
// https://github.com/dvyukov/go-fuzz.
func Fuzz(data []byte) int {
r := NewReader(bytes.NewReader(data))
h := NewHash()
for {
hdr, err := r.Next()
if err != nil {
if hdr != nil {
panic("hdr != nil on error")
}
if err == io.EOF {
// everything worked with random input... interesting
return 1
}
// error returned for random input. Good!
return -1
}
// hash file
h.Reset()
io.CopyN(h, r, hdr.Size)
h.Sum32()
// convert file header
FileInfoHeader(hdr.FileInfo())
}
}

View File

@ -1,45 +0,0 @@
package cpio
import (
"encoding/binary"
"hash"
)
type digest struct {
sum uint32
}
// NewHash returns a new hash.Hash32 computing the SVR4 checksum.
func NewHash() hash.Hash32 {
return &digest{}
}
func (d *digest) Write(p []byte) (n int, err error) {
for _, b := range p {
d.sum += uint32(b & 0xFF)
}
return len(p), nil
}
func (d *digest) Sum(b []byte) []byte {
out := [4]byte{}
binary.LittleEndian.PutUint32(out[:], d.sum)
return append(b, out[:]...)
}
func (d *digest) Sum32() uint32 {
return d.sum
}
func (d *digest) Reset() {
d.sum = 0
}
func (d *digest) Size() int {
return 4
}
func (d *digest) BlockSize() int {
return 1
}
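// Example (illustrative sketch, not part of the original file): verifying an
// entry's checksum while reading it from a Reader r, mirroring what the Fuzz
// function above does:
//
//	h := NewHash()
//	hdr, err := r.Next()
//	if err != nil {
//		// handle err
//	}
//	if _, err := io.CopyN(h, r, hdr.Size); err != nil {
//		// handle err
//	}
//	fmt.Printf("stored=%s computed=%08X\n", hdr.Checksum, h.Sum32())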

View File

@ -1,153 +0,0 @@
package cpio
import (
"errors"
"fmt"
"os"
"time"
)
// Mode constants from the cpio spec.
const (
ModeSetuid = 04000 // Set uid
ModeSetgid = 02000 // Set gid
ModeSticky = 01000 // Save text (sticky bit)
ModeDir = 040000 // Directory
ModeNamedPipe = 010000 // FIFO
ModeRegular = 0100000 // Regular file
ModeSymlink = 0120000 // Symbolic link
ModeDevice = 060000 // Block special file
ModeCharDevice = 020000 // Character special file
ModeSocket = 0140000 // Socket
ModeType = 0170000 // Mask for the type bits
ModePerm = 0777 // Unix permission bits
)
const (
// headerEOF is the value of the filename of the last header in a CPIO archive.
headerEOF = "TRAILER!!!"
)
var (
ErrHeader = errors.New("cpio: invalid cpio header")
)
// A FileMode represents a file's mode and permission bits.
type FileMode int64
func (m FileMode) String() string {
return fmt.Sprintf("%#o", m)
}
// IsDir reports whether m describes a directory. That is, it tests for the
// ModeDir bit being set in m.
func (m FileMode) IsDir() bool {
return m&ModeDir != 0
}
// IsRegular reports whether m describes a regular file. That is, it tests
// that the type bits in m equal ModeRegular.
func (m FileMode) IsRegular() bool {
return m&^ModePerm == ModeRegular
}
// Perm returns the Unix permission bits in m.
func (m FileMode) Perm() FileMode {
return m & ModePerm
}
// Checksum is the sum of all bytes in the file data. This sum is computed
// treating all bytes as unsigned values and using unsigned arithmetic. Only
// the least-significant 32 bits of the sum are stored. Use NewHash to compute
// the actual checksum of an archived file.
type Checksum uint32
func (c Checksum) String() string {
return fmt.Sprintf("%08X", uint32(c))
}
// A Header represents a single header in a CPIO archive.
type Header struct {
DeviceID int
Inode int64 // inode number
Mode FileMode // permission and mode bits
UID int // user id of the owner
GID int // group id of the owner
Links int // number of inbound links
ModTime time.Time // modified time
Size int64 // size in bytes
Name string // filename
Linkname string // target name of link
Checksum Checksum // computed checksum
pad int64 // bytes to pad before next header
}
// FileInfo returns an os.FileInfo for the Header.
func (h *Header) FileInfo() os.FileInfo {
return headerFileInfo{h}
}
// FileInfoHeader creates a partially-populated Header from fi.
// If fi describes a symlink, FileInfoHeader records link as the link target.
// If fi describes a directory, a slash is appended to the name.
// Because os.FileInfo's Name method returns only the base name of
// the file it describes, it may be necessary to modify the Name field
// of the returned header to provide the full path name of the file.
func FileInfoHeader(fi os.FileInfo, link string) (*Header, error) {
if fi == nil {
return nil, errors.New("cpio: FileInfo is nil")
}
if sys, ok := fi.Sys().(*Header); ok {
// This FileInfo came from a Header (not the OS). Return a copy of the
// original Header.
h := &Header{}
*h = *sys
return h, nil
}
fm := fi.Mode()
h := &Header{
Name: fi.Name(),
Mode: FileMode(fi.Mode().Perm()), // or'd with Mode* constants later
ModTime: fi.ModTime(),
Size: fi.Size(),
}
switch {
case fm.IsRegular():
h.Mode |= ModeRegular
case fi.IsDir():
h.Mode |= ModeDir
h.Name += "/"
h.Size = 0
case fm&os.ModeSymlink != 0:
h.Mode |= ModeSymlink
h.Linkname = link
case fm&os.ModeDevice != 0:
if fm&os.ModeCharDevice != 0 {
h.Mode |= ModeCharDevice
} else {
h.Mode |= ModeDevice
}
case fm&os.ModeNamedPipe != 0:
h.Mode |= ModeNamedPipe
case fm&os.ModeSocket != 0:
h.Mode |= ModeSocket
default:
return nil, fmt.Errorf("cpio: unknown file mode %v", fm)
}
if fm&os.ModeSetuid != 0 {
h.Mode |= ModeSetuid
}
if fm&os.ModeSetgid != 0 {
h.Mode |= ModeSetgid
}
if fm&os.ModeSticky != 0 {
h.Mode |= ModeSticky
}
return h, nil
}
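// Example (illustrative sketch, not part of the original file): building a
// Header from an os.FileInfo and restoring the full path, since fi.Name()
// only carries the base name:
//
//	fi, err := os.Stat("/etc/hosts")
//	if err != nil {
//		// handle err
//	}
//	hdr, err := FileInfoHeader(fi, "")
//	if err != nil {
//		// handle err
//	}
//	hdr.Name = "etc/hosts" // store the path you want inside the archive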

View File

@ -1,72 +0,0 @@
package cpio
import (
"io"
"io/ioutil"
)
// A Reader provides sequential access to the contents of a CPIO archive. A CPIO
// archive consists of a sequence of files. The Next method advances to the next
// file in the archive (including the first), and then it can be treated as an
// io.Reader to access the file's data.
type Reader struct {
r io.Reader // underlying file reader
hdr *Header // current Header
eof int64 // bytes until the end of the current file
}
// NewReader creates a new Reader reading from r.
func NewReader(r io.Reader) *Reader {
return &Reader{
r: r,
}
}
// Read reads from the current entry in the CPIO archive. It returns 0, io.EOF
// when it reaches the end of that entry, until Next is called to advance to the
// next entry.
func (r *Reader) Read(p []byte) (n int, err error) {
if r.hdr == nil || r.eof == 0 {
return 0, io.EOF
}
rn := len(p)
if r.eof < int64(rn) {
rn = int(r.eof)
}
n, err = r.r.Read(p[0:rn])
r.eof -= int64(n)
return
}
// Next advances to the next entry in the CPIO archive.
// io.EOF is returned at the end of the input.
func (r *Reader) Next() (*Header, error) {
if r.hdr == nil {
return r.next()
}
skp := r.eof + r.hdr.pad
if skp > 0 {
_, err := io.CopyN(ioutil.Discard, r.r, skp)
if err != nil {
return nil, err
}
}
return r.next()
}
func (r *Reader) next() (*Header, error) {
r.eof = 0
hdr, err := readHeader(r.r)
if err != nil {
return nil, err
}
r.hdr = hdr
r.eof = hdr.Size
return hdr, nil
}
// readHeader creates a new Header, reading from r.
func readHeader(r io.Reader) (*Header, error) {
// currently only SVR4 format is supported
return readSVR4Header(r)
}

View File

@ -1,152 +0,0 @@
package cpio
import (
"bytes"
"fmt"
"io"
"strconv"
"time"
)
const (
svr4MaxNameSize = 4096 // MAX_PATH
svr4MaxFileSize = 4294967295
)
var svr4Magic = []byte{0x30, 0x37, 0x30, 0x37, 0x30, 0x31} // 070701
func readHex(s string) int64 {
// errors are ignored and 0 returned
i, _ := strconv.ParseInt(s, 16, 64)
return i
}
func writeHex(b []byte, i int64) {
// i needs to be in range of uint32
copy(b, fmt.Sprintf("%08X", i))
}
func readSVR4Header(r io.Reader) (*Header, error) {
var buf [110]byte
if _, err := io.ReadFull(r, buf[:]); err != nil {
return nil, err
}
// TODO: check endianness
// check magic
hasCRC := false
if !bytes.HasPrefix(buf[:], svr4Magic[:5]) {
return nil, ErrHeader
}
if buf[5] == 0x32 { // '2'
hasCRC = true
} else if buf[5] != 0x31 { // '1'
return nil, ErrHeader
}
asc := string(buf[:])
hdr := &Header{}
hdr.Inode = readHex(asc[6:14])
hdr.Mode = FileMode(readHex(asc[14:22]))
hdr.UID = int(readHex(asc[22:30]))
hdr.GID = int(readHex(asc[30:38]))
hdr.Links = int(readHex(asc[38:46]))
hdr.ModTime = time.Unix(readHex(asc[46:54]), 0)
hdr.Size = readHex(asc[54:62])
if hdr.Size > svr4MaxFileSize {
return nil, ErrHeader
}
nameSize := readHex(asc[94:102])
if nameSize < 1 || nameSize > svr4MaxNameSize {
return nil, ErrHeader
}
hdr.Checksum = Checksum(readHex(asc[102:110]))
if !hasCRC && hdr.Checksum != 0 {
return nil, ErrHeader
}
name := make([]byte, nameSize)
if _, err := io.ReadFull(r, name); err != nil {
return nil, err
}
hdr.Name = string(name[:nameSize-1])
if hdr.Name == headerEOF {
return nil, io.EOF
}
// store padding between end of file and next header
hdr.pad = (4 - (hdr.Size % 4)) % 4
// skip to end of header/start of file
pad := (4 - (len(buf)+len(name))%4) % 4
if pad > 0 {
if _, err := io.ReadFull(r, buf[:pad]); err != nil {
return nil, err
}
}
// read link name
if hdr.Mode&^ModePerm == ModeSymlink {
if hdr.Size < 1 || hdr.Size > svr4MaxNameSize {
return nil, ErrHeader
}
b := make([]byte, hdr.Size)
if _, err := io.ReadFull(r, b); err != nil {
return nil, err
}
hdr.Linkname = string(b)
hdr.Size = 0
}
return hdr, nil
}
func writeSVR4Header(w io.Writer, hdr *Header) (pad int64, err error) {
var hdrBuf [110]byte
for i := 0; i < len(hdrBuf); i++ {
hdrBuf[i] = '0'
}
magic := svr4Magic
if hdr.Checksum != 0 {
magic[5] = 0x32
}
copy(hdrBuf[:], magic)
writeHex(hdrBuf[6:14], hdr.Inode)
writeHex(hdrBuf[14:22], int64(hdr.Mode))
writeHex(hdrBuf[22:30], int64(hdr.UID))
writeHex(hdrBuf[30:38], int64(hdr.GID))
writeHex(hdrBuf[38:46], int64(hdr.Links))
if !hdr.ModTime.IsZero() {
writeHex(hdrBuf[46:54], hdr.ModTime.Unix())
}
writeHex(hdrBuf[54:62], hdr.Size)
writeHex(hdrBuf[94:102], int64(len(hdr.Name)+1))
if hdr.Checksum != 0 {
writeHex(hdrBuf[102:110], int64(hdr.Checksum))
}
// write header
_, err = w.Write(hdrBuf[:])
if err != nil {
return
}
// write filename
_, err = io.WriteString(w, hdr.Name+"\x00")
if err != nil {
return
}
// pad to end of filename
npad := (4 - ((len(hdrBuf) + len(hdr.Name) + 1) % 4)) % 4
_, err = w.Write(zeroBlock[:npad])
if err != nil {
return
}
// compute padding to end of file
pad = (4 - (hdr.Size % 4)) % 4
return
}

View File

@ -1,128 +0,0 @@
package cpio
import (
"errors"
"fmt"
"io"
)
var (
ErrWriteTooLong = errors.New("cpio: write too long")
ErrWriteAfterClose = errors.New("cpio: write after close")
)
var trailer = &Header{
Name: string(headerEOF),
Links: 1,
}
var zeroBlock [4]byte
// A Writer provides sequential writing of a CPIO archive. A CPIO archive
// consists of a sequence of files. Call WriteHeader to begin a new file, and
// then call Write to supply that file's data, writing at most hdr.Size bytes in
// total.
type Writer struct {
w io.Writer
nb int64 // number of unwritten bytes for current file entry
pad int64 // amount of padding to write after current file entry
inode int64
err error
closed bool
}
// NewWriter creates a new Writer writing to w.
func NewWriter(w io.Writer) *Writer {
return &Writer{w: w}
}
// Flush finishes writing the current file (optional).
func (w *Writer) Flush() error {
if w.nb > 0 {
w.err = fmt.Errorf("cpio: missed writing %d bytes", w.nb)
return w.err
}
_, w.err = w.w.Write(zeroBlock[:w.pad])
if w.err != nil {
return w.err
}
w.nb = 0
w.pad = 0
return w.err
}
// WriteHeader writes hdr and prepares to accept the file's contents.
// WriteHeader calls Flush if it is not the first header. Calling after a Close
// will return ErrWriteAfterClose.
func (w *Writer) WriteHeader(hdr *Header) (err error) {
if w.closed {
return ErrWriteAfterClose
}
if w.err == nil {
w.Flush()
}
if w.err != nil {
return w.err
}
if hdr.Name != headerEOF {
// TODO: should we be mutating hdr here?
// ensure all inodes are unique
w.inode++
if hdr.Inode == 0 {
hdr.Inode = w.inode
}
// ensure file type is set
if hdr.Mode&^ModePerm == 0 {
hdr.Mode |= ModeRegular
}
// ensure regular files have at least 1 inbound link
if hdr.Links < 1 && hdr.Mode.IsRegular() {
hdr.Links = 1
}
}
w.nb = hdr.Size
w.pad, w.err = writeSVR4Header(w.w, hdr)
return
}
// Write writes to the current entry in the CPIO archive. Write returns the
// error ErrWriteTooLong if more than hdr.Size bytes are written after
// WriteHeader.
func (w *Writer) Write(p []byte) (n int, err error) {
if w.closed {
err = ErrWriteAfterClose
return
}
overwrite := false
if int64(len(p)) > w.nb {
p = p[0:w.nb]
overwrite = true
}
n, err = w.w.Write(p)
w.nb -= int64(n)
if err == nil && overwrite {
err = ErrWriteTooLong
return
}
w.err = err
return
}
// Close closes the CPIO archive, flushing any unwritten data to the underlying
// writer.
func (w *Writer) Close() error {
if w.err != nil || w.closed {
return w.err
}
w.err = w.WriteHeader(trailer)
if w.err != nil {
return w.err
}
w.Flush()
w.closed = true
return w.err
}

View File

@ -1,24 +0,0 @@
Copyright (c) 2014 CloudFlare Inc.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -1,94 +0,0 @@
// Package auth implements an interface for providing CFSSL
// authentication. This is meant to authenticate a client CFSSL to a
// remote CFSSL in order to prevent unauthorised use of the signature
// capabilities. This package provides both the interface and a
// standard HMAC-based implementation.
package auth
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"io/ioutil"
"os"
"strings"
)
// An AuthenticatedRequest contains a request and authentication
// token. The Provider may determine whether to validate the timestamp
// and remote address.
type AuthenticatedRequest struct {
// An Authenticator decides whether to use this field.
Timestamp int64 `json:"timestamp,omitempty"`
RemoteAddress []byte `json:"remote_address,omitempty"`
Token []byte `json:"token"`
Request []byte `json:"request"`
}
// A Provider can generate tokens from a request and verify a
// request. The handling of additional authentication data (such as
// the IP address) is handled by the concrete type, as is any
// serialisation and state-keeping.
type Provider interface {
Token(req []byte) (token []byte, err error)
Verify(aReq *AuthenticatedRequest) bool
}
// Standard implements an HMAC-SHA-256 authentication provider. It may
// be supplied additional data at creation time that will be used as
// request || additional-data with the HMAC.
type Standard struct {
key []byte
ad []byte
}
// New generates a new standard authentication provider from the key
// and additional data. The additional data will be used when
// generating a new token.
func New(key string, ad []byte) (*Standard, error) {
if splitKey := strings.SplitN(key, ":", 2); len(splitKey) == 2 {
switch splitKey[0] {
case "env":
key = os.Getenv(splitKey[1])
case "file":
data, err := ioutil.ReadFile(splitKey[1])
if err != nil {
return nil, err
}
key = strings.TrimSpace(string(data))
default:
return nil, fmt.Errorf("unknown key prefix: %s", splitKey[0])
}
}
keyBytes, err := hex.DecodeString(key)
if err != nil {
return nil, err
}
return &Standard{keyBytes, ad}, nil
}
// Token generates a new authentication token from the request.
func (p Standard) Token(req []byte) (token []byte, err error) {
h := hmac.New(sha256.New, p.key)
h.Write(req)
h.Write(p.ad)
return h.Sum(nil), nil
}
// Verify determines whether an authenticated request is valid.
func (p Standard) Verify(ad *AuthenticatedRequest) bool {
if ad == nil {
return false
}
// Standard token generation returns no error.
token, _ := p.Token(ad.Request)
if len(ad.Token) != len(token) {
return false
}
return hmac.Equal(token, ad.Token)
}
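// Example (illustrative sketch, not part of the original file): a client
// wrapping a request with the Standard provider and verifying it. The hex key
// and request bytes below are placeholders.
//
//	provider, err := auth.New("000102030405060708090a0b0c0d0e0f", nil)
//	if err != nil {
//		// handle err
//	}
//	req := []byte(`{"hosts":["example.net"]}`)
//	token, _ := provider.Token(req)
//	aReq := &auth.AuthenticatedRequest{Token: token, Request: req}
//	ok := provider.Verify(aReq) // true when the same key and request are used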

View File

@ -1,75 +0,0 @@
# certdb usage
Using a database enables additional functionality for existing commands when a
db config is provided:
- `sign` and `gencert` add a certificate to the certdb after signing it
- `serve` enables database functionality for the sign and revoke endpoints
A database is required for the following:
- `revoke` marks certificates revoked in the database with an optional reason
- `ocsprefresh` refreshes the table of cached OCSP responses
- `ocspdump` outputs cached OCSP responses in a concatenated base64-encoded format
## Setup/Migration
This directory stores [goose](https://bitbucket.org/liamstask/goose/) db migration scripts for various DB backends.
Currently supported:
- MySQL in mysql
- PostgreSQL in pg
- SQLite in sqlite
### Get goose
go get bitbucket.org/liamstask/goose/cmd/goose
### Use goose to start and terminate a MySQL DB
To start a MySQL using goose:
goose -path $GOPATH/src/github.com/cloudflare/cfssl/certdb/mysql up
To tear down a MySQL DB using goose
goose -path $GOPATH/src/github.com/cloudflare/cfssl/certdb/mysql down
Note: the administration of MySQL DB is not included. We assume
the databases being connected to are already created and access control
is properly handled.
### Use goose to start and terminate a PostgreSQL DB
To start a PostgreSQL using goose:
goose -path $GOPATH/src/github.com/cloudflare/cfssl/certdb/pg up
To tear down a PostgreSQL DB using goose
goose -path $GOPATH/src/github.com/cloudflare/cfssl/certdb/pg down
Note: the administration of PostgreSQL DB is not included. We assume
the databases being connected to are already created and access control
is properly handled.
### Use goose to start and terminate a SQLite DB
To start a SQLite DB using goose:
goose -path $GOPATH/src/github.com/cloudflare/cfssl/certdb/sqlite up
To tear down a SQLite DB using goose
goose -path $GOPATH/src/github.com/cloudflare/cfssl/certdb/sqlite down
## CFSSL Configuration
Several cfssl commands take a -db-config flag. Create a file with a
JSON dictionary:
{"driver":"sqlite3","data_source":"certs.db"}
or
{"driver":"postgres","data_source":"postgres://user:password@host/db"}
or
{"driver":"mysql","data_source":"user:password@tcp(hostname:3306)/db?parseTime=true"}

View File

@ -1,42 +0,0 @@
package certdb
import (
"time"
)
// CertificateRecord encodes a certificate and its metadata
// that will be recorded in a database.
type CertificateRecord struct {
Serial string `db:"serial_number"`
AKI string `db:"authority_key_identifier"`
CALabel string `db:"ca_label"`
Status string `db:"status"`
Reason int `db:"reason"`
Expiry time.Time `db:"expiry"`
RevokedAt time.Time `db:"revoked_at"`
PEM string `db:"pem"`
}
// OCSPRecord encodes an OCSP response body and its metadata
// that will be recorded in a database.
type OCSPRecord struct {
Serial string `db:"serial_number"`
AKI string `db:"authority_key_identifier"`
Body string `db:"body"`
Expiry time.Time `db:"expiry"`
}
// Accessor abstracts the CRUD of certdb objects from a DB.
type Accessor interface {
InsertCertificate(cr CertificateRecord) error
GetCertificate(serial, aki string) ([]CertificateRecord, error)
GetUnexpiredCertificates() ([]CertificateRecord, error)
GetRevokedAndUnexpiredCertificates() ([]CertificateRecord, error)
GetRevokedAndUnexpiredCertificatesByLabel(label string) ([]CertificateRecord, error)
RevokeCertificate(serial, aki string, reasonCode int) error
InsertOCSP(rr OCSPRecord) error
GetOCSP(serial, aki string) ([]OCSPRecord, error)
GetUnexpiredOCSPs() ([]OCSPRecord, error)
UpdateOCSP(serial, aki, body string, expiry time.Time) error
UpsertOCSP(serial, aki, body string, expiry time.Time) error
}

View File

@ -1,659 +0,0 @@
// Package config contains the configuration logic for CFSSL.
package config
import (
"crypto/tls"
"crypto/x509"
"encoding/asn1"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"regexp"
"strconv"
"strings"
"time"
"github.com/cloudflare/cfssl/auth"
cferr "github.com/cloudflare/cfssl/errors"
"github.com/cloudflare/cfssl/helpers"
"github.com/cloudflare/cfssl/log"
ocspConfig "github.com/cloudflare/cfssl/ocsp/config"
)
// A CSRWhitelist stores booleans for fields in the CSR. If a CSRWhitelist is
// not present in a SigningProfile, all of these fields may be copied from the
// CSR into the signed certificate. If a CSRWhitelist *is* present in a
// SigningProfile, only those fields with a `true` value in the CSRWhitelist may
// be copied from the CSR to the signed certificate. Note that some of these
// fields, like Subject, can be provided or partially provided through the API.
// Since API clients are expected to be trusted, but CSRs are not, fields
// provided through the API are not subject to whitelisting through this
// mechanism.
type CSRWhitelist struct {
Subject, PublicKeyAlgorithm, PublicKey, SignatureAlgorithm bool
DNSNames, IPAddresses, EmailAddresses, URIs bool
}
// OID is our own version of asn1's ObjectIdentifier, so we can define a custom
// JSON marshal / unmarshal.
type OID asn1.ObjectIdentifier
// CertificatePolicy represents the ASN.1 PolicyInformation structure from
// https://tools.ietf.org/html/rfc3280.html#page-106.
// Valid values of Type are "id-qt-unotice" and "id-qt-cps"
type CertificatePolicy struct {
ID OID
Qualifiers []CertificatePolicyQualifier
}
// CertificatePolicyQualifier represents a single qualifier from an ASN.1
// PolicyInformation structure.
type CertificatePolicyQualifier struct {
Type string
Value string
}
// AuthRemote is an authenticated remote signer.
type AuthRemote struct {
RemoteName string `json:"remote"`
AuthKeyName string `json:"auth_key"`
}
// CAConstraint specifies various CA constraints on the signed certificate.
// CAConstraint would verify against (and override) the CA
// extensions in the given CSR.
type CAConstraint struct {
IsCA bool `json:"is_ca"`
MaxPathLen int `json:"max_path_len"`
MaxPathLenZero bool `json:"max_path_len_zero"`
}
// A SigningProfile stores information that the CA needs to store
// signature policy.
type SigningProfile struct {
Usage []string `json:"usages"`
IssuerURL []string `json:"issuer_urls"`
OCSP string `json:"ocsp_url"`
CRL string `json:"crl_url"`
CAConstraint CAConstraint `json:"ca_constraint"`
OCSPNoCheck bool `json:"ocsp_no_check"`
ExpiryString string `json:"expiry"`
BackdateString string `json:"backdate"`
AuthKeyName string `json:"auth_key"`
RemoteName string `json:"remote"`
NotBefore time.Time `json:"not_before"`
NotAfter time.Time `json:"not_after"`
NameWhitelistString string `json:"name_whitelist"`
AuthRemote AuthRemote `json:"auth_remote"`
CTLogServers []string `json:"ct_log_servers"`
AllowedExtensions []OID `json:"allowed_extensions"`
CertStore string `json:"cert_store"`
Policies []CertificatePolicy
Expiry time.Duration
Backdate time.Duration
Provider auth.Provider
RemoteProvider auth.Provider
RemoteServer string
RemoteCAs *x509.CertPool
ClientCert *tls.Certificate
CSRWhitelist *CSRWhitelist
NameWhitelist *regexp.Regexp
ExtensionWhitelist map[string]bool
ClientProvidesSerialNumbers bool
}
// UnmarshalJSON unmarshals a JSON string into an OID.
func (oid *OID) UnmarshalJSON(data []byte) (err error) {
if data[0] != '"' || data[len(data)-1] != '"' {
return errors.New("OID JSON string not wrapped in quotes." + string(data))
}
data = data[1 : len(data)-1]
parsedOid, err := parseObjectIdentifier(string(data))
if err != nil {
return err
}
*oid = OID(parsedOid)
return
}
// MarshalJSON marshals an oid into a JSON string.
func (oid OID) MarshalJSON() ([]byte, error) {
return []byte(fmt.Sprintf(`"%v"`, asn1.ObjectIdentifier(oid))), nil
}
func parseObjectIdentifier(oidString string) (oid asn1.ObjectIdentifier, err error) {
validOID, err := regexp.MatchString("\\d(\\.\\d+)*", oidString)
if err != nil {
return
}
if !validOID {
err = errors.New("Invalid OID")
return
}
segments := strings.Split(oidString, ".")
oid = make(asn1.ObjectIdentifier, len(segments))
for i, intString := range segments {
oid[i], err = strconv.Atoi(intString)
if err != nil {
return
}
}
return
}
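// Illustrative sketch (not part of the original vendored source): the JSON
// round trip implemented by UnmarshalJSON and MarshalJSON above. The
// identifier "2.5.29.35" is an arbitrary example.
func exampleOIDRoundTrip() error {
	var oid OID
	if err := oid.UnmarshalJSON([]byte(`"2.5.29.35"`)); err != nil {
		return err
	}
	// MarshalJSON renders the OID back as a quoted dotted string: "2.5.29.35".
	_, err := oid.MarshalJSON()
	return err
}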
const timeFormat = "2006-01-02T15:04:05"
// populate is used to fill in the fields that are not in JSON
//
// First, the ExpiryString parameter is needed to parse
// expiration timestamps from JSON. The JSON decoder is not able to
// decode a string time duration to a time.Duration, so this is called
// when loading the configuration to properly parse and fill out the
// Expiry parameter.
// This function is also used to create references to the auth key
// and default remote for the profile.
// It returns nil if ExpiryString is a valid representation of a
// time.Duration and the AuthKeyName and RemoteName point to
// valid objects; otherwise it returns an error.
func (p *SigningProfile) populate(cfg *Config) error {
if p == nil {
return cferr.Wrap(cferr.PolicyError, cferr.InvalidPolicy, errors.New("can't parse nil profile"))
}
var err error
if p.RemoteName == "" && p.AuthRemote.RemoteName == "" {
log.Debugf("parse expiry in profile")
if p.ExpiryString == "" {
return cferr.Wrap(cferr.PolicyError, cferr.InvalidPolicy, errors.New("empty expiry string"))
}
dur, err := time.ParseDuration(p.ExpiryString)
if err != nil {
return cferr.Wrap(cferr.PolicyError, cferr.InvalidPolicy, err)
}
log.Debugf("expiry is valid")
p.Expiry = dur
if p.BackdateString != "" {
dur, err = time.ParseDuration(p.BackdateString)
if err != nil {
return cferr.Wrap(cferr.PolicyError, cferr.InvalidPolicy, err)
}
p.Backdate = dur
}
if !p.NotBefore.IsZero() && !p.NotAfter.IsZero() && p.NotAfter.Before(p.NotBefore) {
return cferr.Wrap(cferr.PolicyError, cferr.InvalidPolicy, errors.New("not_after cannot be before not_before"))
}
if len(p.Policies) > 0 {
for _, policy := range p.Policies {
for _, qualifier := range policy.Qualifiers {
if qualifier.Type != "" && qualifier.Type != "id-qt-unotice" && qualifier.Type != "id-qt-cps" {
return cferr.Wrap(cferr.PolicyError, cferr.InvalidPolicy,
errors.New("invalid policy qualifier type"))
}
}
}
}
} else if p.RemoteName != "" {
log.Debug("match remote in profile to remotes section")
if p.AuthRemote.RemoteName != "" {
log.Error("profile has both a remote and an auth remote specified")
return cferr.New(cferr.PolicyError, cferr.InvalidPolicy)
}
if remote := cfg.Remotes[p.RemoteName]; remote != "" {
if err := p.updateRemote(remote); err != nil {
return err
}
} else {
return cferr.Wrap(cferr.PolicyError, cferr.InvalidPolicy,
errors.New("failed to find remote in remotes section"))
}
} else {
log.Debug("match auth remote in profile to remotes section")
if remote := cfg.Remotes[p.AuthRemote.RemoteName]; remote != "" {
if err := p.updateRemote(remote); err != nil {
return err
}
} else {
return cferr.Wrap(cferr.PolicyError, cferr.InvalidPolicy,
errors.New("failed to find remote in remotes section"))
}
}
if p.AuthKeyName != "" {
log.Debug("match auth key in profile to auth_keys section")
if key, ok := cfg.AuthKeys[p.AuthKeyName]; ok {
if key.Type == "standard" {
p.Provider, err = auth.New(key.Key, nil)
if err != nil {
log.Debugf("failed to create new standard auth provider: %v", err)
return cferr.Wrap(cferr.PolicyError, cferr.InvalidPolicy,
errors.New("failed to create new standard auth provider"))
}
} else {
log.Debugf("unknown authentication type %v", key.Type)
return cferr.Wrap(cferr.PolicyError, cferr.InvalidPolicy,
errors.New("unknown authentication type"))
}
} else {
return cferr.Wrap(cferr.PolicyError, cferr.InvalidPolicy,
errors.New("failed to find auth_key in auth_keys section"))
}
}
if p.AuthRemote.AuthKeyName != "" {
log.Debug("match auth remote key in profile to auth_keys section")
if key, ok := cfg.AuthKeys[p.AuthRemote.AuthKeyName]; ok {
if key.Type == "standard" {
p.RemoteProvider, err = auth.New(key.Key, nil)
if err != nil {
log.Debugf("failed to create new standard auth provider: %v", err)
return cferr.Wrap(cferr.PolicyError, cferr.InvalidPolicy,
errors.New("failed to create new standard auth provider"))
}
} else {
log.Debugf("unknown authentication type %v", key.Type)
return cferr.Wrap(cferr.PolicyError, cferr.InvalidPolicy,
errors.New("unknown authentication type"))
}
} else {
return cferr.Wrap(cferr.PolicyError, cferr.InvalidPolicy,
errors.New("failed to find auth_remote's auth_key in auth_keys section"))
}
}
if p.NameWhitelistString != "" {
log.Debug("compiling whitelist regular expression")
rule, err := regexp.Compile(p.NameWhitelistString)
if err != nil {
return cferr.Wrap(cferr.PolicyError, cferr.InvalidPolicy,
errors.New("failed to compile name whitelist section"))
}
p.NameWhitelist = rule
}
p.ExtensionWhitelist = map[string]bool{}
for _, oid := range p.AllowedExtensions {
p.ExtensionWhitelist[asn1.ObjectIdentifier(oid).String()] = true
}
return nil
}
// updateRemote takes a signing profile and initializes the remote server object
// to the hostname:port combination sent by remote.
func (p *SigningProfile) updateRemote(remote string) error {
if remote != "" {
p.RemoteServer = remote
}
return nil
}
// OverrideRemotes takes a signing configuration and updates the remote server object
// to the hostname:port combination sent by remote
func (p *Signing) OverrideRemotes(remote string) error {
if remote != "" {
var err error
for _, profile := range p.Profiles {
err = profile.updateRemote(remote)
if err != nil {
return err
}
}
err = p.Default.updateRemote(remote)
if err != nil {
return err
}
}
return nil
}
// SetClientCertKeyPairFromFile updates the properties to set client certificates for mutual
// authenticated TLS remote requests
func (p *Signing) SetClientCertKeyPairFromFile(certFile string, keyFile string) error {
if certFile != "" && keyFile != "" {
cert, err := helpers.LoadClientCertificate(certFile, keyFile)
if err != nil {
return err
}
for _, profile := range p.Profiles {
profile.ClientCert = cert
}
p.Default.ClientCert = cert
}
return nil
}
// SetRemoteCAsFromFile reads root CAs from file and updates the properties to set remote CAs for TLS
// remote requests
func (p *Signing) SetRemoteCAsFromFile(caFile string) error {
if caFile != "" {
remoteCAs, err := helpers.LoadPEMCertPool(caFile)
if err != nil {
return err
}
p.SetRemoteCAs(remoteCAs)
}
return nil
}
// SetRemoteCAs updates the properties to set remote CAs for TLS
// remote requests
func (p *Signing) SetRemoteCAs(remoteCAs *x509.CertPool) {
for _, profile := range p.Profiles {
profile.RemoteCAs = remoteCAs
}
p.Default.RemoteCAs = remoteCAs
}
// NeedsRemoteSigner returns true if one of the profiles has a remote set
func (p *Signing) NeedsRemoteSigner() bool {
for _, profile := range p.Profiles {
if profile.RemoteServer != "" {
return true
}
}
if p.Default.RemoteServer != "" {
return true
}
return false
}
// NeedsLocalSigner returns true if one of the profiles does not have a remote set
func (p *Signing) NeedsLocalSigner() bool {
for _, profile := range p.Profiles {
if profile.RemoteServer == "" {
return true
}
}
if p.Default.RemoteServer == "" {
return true
}
return false
}
// Usages parses the list of key uses in the profile, translating them
// to a list of X.509 key usages and extended key usages. The unknown
// uses are collected into a slice that is also returned.
func (p *SigningProfile) Usages() (ku x509.KeyUsage, eku []x509.ExtKeyUsage, unk []string) {
for _, keyUse := range p.Usage {
if kuse, ok := KeyUsage[keyUse]; ok {
ku |= kuse
} else if ekuse, ok := ExtKeyUsage[keyUse]; ok {
eku = append(eku, ekuse)
} else {
unk = append(unk, keyUse)
}
}
return
}
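// Illustrative sketch (not part of the original vendored source): how Usages
// splits a profile's usage strings into key usages, extended key usages and
// unknown entries. The strings used here are arbitrary examples.
func exampleUsages() {
	p := &SigningProfile{Usage: []string{"digital signature", "server auth", "bogus"}}
	ku, eku, unknown := p.Usages()
	// ku contains x509.KeyUsageDigitalSignature, eku contains
	// x509.ExtKeyUsageServerAuth, and unknown is []string{"bogus"}.
	_, _, _ = ku, eku, unknown
}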
// A valid profile must be a valid local profile or a valid remote profile.
// A valid local profile has defined at least key usages to be used, and a
// valid local default profile has defined at least a default expiration.
// A valid remote profile (default or not) has remote signer initialized.
// In addition, a remote profile must have a valid auth provider if an
// auth key is defined.
func (p *SigningProfile) validProfile(isDefault bool) bool {
if p == nil {
return false
}
if p.AuthRemote.RemoteName == "" && p.AuthRemote.AuthKeyName != "" {
log.Debugf("invalid auth remote profile: no remote signer specified")
return false
}
if p.RemoteName != "" {
log.Debugf("validate remote profile")
if p.RemoteServer == "" {
log.Debugf("invalid remote profile: no remote signer specified")
return false
}
if p.AuthKeyName != "" && p.Provider == nil {
log.Debugf("invalid remote profile: auth key name is defined but no auth provider is set")
return false
}
if p.AuthRemote.RemoteName != "" {
log.Debugf("invalid remote profile: auth remote is also specified")
return false
}
} else if p.AuthRemote.RemoteName != "" {
log.Debugf("validate auth remote profile")
if p.RemoteServer == "" {
log.Debugf("invalid auth remote profile: no remote signer specified")
return false
}
if p.AuthRemote.AuthKeyName == "" || p.RemoteProvider == nil {
log.Debugf("invalid auth remote profile: no auth key is defined")
return false
}
} else {
log.Debugf("validate local profile")
if !isDefault {
if len(p.Usage) == 0 {
log.Debugf("invalid local profile: no usages specified")
return false
} else if _, _, unk := p.Usages(); len(unk) == len(p.Usage) {
log.Debugf("invalid local profile: no valid usages")
return false
}
} else {
if p.Expiry == 0 {
log.Debugf("invalid local profile: no expiry set")
return false
}
}
}
log.Debugf("profile is valid")
return true
}
// hasLocalConfig returns true if the SigningProfile contains settings that are
// only effective with a local signer, which has access to the CA private key.
func (p *SigningProfile) hasLocalConfig() bool {
if p.Usage != nil ||
p.IssuerURL != nil ||
p.OCSP != "" ||
p.ExpiryString != "" ||
p.BackdateString != "" ||
p.CAConstraint.IsCA != false ||
!p.NotBefore.IsZero() ||
!p.NotAfter.IsZero() ||
p.NameWhitelistString != "" ||
len(p.CTLogServers) != 0 {
return true
}
return false
}
// warnSkippedSettings prints a log warning message about skipped settings
// in a SigningProfile, usually due to remote signer.
func (p *Signing) warnSkippedSettings() {
const warningMessage = `The configuration values of "usages", "issuer_urls", "ocsp_url", "crl_url", "ca_constraint", "expiry", "backdate", "not_before", "not_after", "cert_store" and "ct_log_servers" are skipped`
if p == nil {
return
}
if (p.Default.RemoteName != "" || p.Default.AuthRemote.RemoteName != "") && p.Default.hasLocalConfig() {
log.Warning("default profile points to a remote signer: ", warningMessage)
}
for name, profile := range p.Profiles {
if (profile.RemoteName != "" || profile.AuthRemote.RemoteName != "") && profile.hasLocalConfig() {
log.Warningf("Profiles[%s] points to a remote signer: %s", name, warningMessage)
}
}
}
// Signing codifies the signature configuration policy for a CA.
type Signing struct {
Profiles map[string]*SigningProfile `json:"profiles"`
Default *SigningProfile `json:"default"`
}
// Config stores configuration information for the CA.
type Config struct {
Signing *Signing `json:"signing"`
OCSP *ocspConfig.Config `json:"ocsp"`
AuthKeys map[string]AuthKey `json:"auth_keys,omitempty"`
Remotes map[string]string `json:"remotes,omitempty"`
}
// Valid ensures that Config is a valid configuration. It should be
// called immediately after parsing a configuration file.
func (c *Config) Valid() bool {
return c.Signing.Valid()
}
// Valid checks the signature policies, ensuring they are valid
// policies. A policy is valid if it has defined at least key usages
// to be used, and a valid default profile has defined at least a
// default expiration.
func (p *Signing) Valid() bool {
if p == nil {
return false
}
log.Debugf("validating configuration")
if !p.Default.validProfile(true) {
log.Debugf("default profile is invalid")
return false
}
for _, sp := range p.Profiles {
if !sp.validProfile(false) {
log.Debugf("invalid profile")
return false
}
}
p.warnSkippedSettings()
return true
}
// KeyUsage contains a mapping of string names to key usages.
var KeyUsage = map[string]x509.KeyUsage{
"signing": x509.KeyUsageDigitalSignature,
"digital signature": x509.KeyUsageDigitalSignature,
"content commitment": x509.KeyUsageContentCommitment,
"key encipherment": x509.KeyUsageKeyEncipherment,
"key agreement": x509.KeyUsageKeyAgreement,
"data encipherment": x509.KeyUsageDataEncipherment,
"cert sign": x509.KeyUsageCertSign,
"crl sign": x509.KeyUsageCRLSign,
"encipher only": x509.KeyUsageEncipherOnly,
"decipher only": x509.KeyUsageDecipherOnly,
}
// ExtKeyUsage contains a mapping of string names to extended key
// usages.
var ExtKeyUsage = map[string]x509.ExtKeyUsage{
"any": x509.ExtKeyUsageAny,
"server auth": x509.ExtKeyUsageServerAuth,
"client auth": x509.ExtKeyUsageClientAuth,
"code signing": x509.ExtKeyUsageCodeSigning,
"email protection": x509.ExtKeyUsageEmailProtection,
"s/mime": x509.ExtKeyUsageEmailProtection,
"ipsec end system": x509.ExtKeyUsageIPSECEndSystem,
"ipsec tunnel": x509.ExtKeyUsageIPSECTunnel,
"ipsec user": x509.ExtKeyUsageIPSECUser,
"timestamping": x509.ExtKeyUsageTimeStamping,
"ocsp signing": x509.ExtKeyUsageOCSPSigning,
"microsoft sgc": x509.ExtKeyUsageMicrosoftServerGatedCrypto,
"netscape sgc": x509.ExtKeyUsageNetscapeServerGatedCrypto,
}
// An AuthKey contains an entry for a key used for authentication.
type AuthKey struct {
// Type contains information needed to select the appropriate
// constructor. For example, "standard" for HMAC-SHA-256,
// "standard-ip" for HMAC-SHA-256 incorporating the client's
// IP.
Type string `json:"type"`
// Key contains the key information, such as a hex-encoded
// HMAC key.
Key string `json:"key"`
}
// DefaultConfig returns a default configuration specifying basic key
// usage and a 1 year expiration time. The key usages chosen are
// signing, key encipherment, client auth and server auth.
func DefaultConfig() *SigningProfile {
d := helpers.OneYear
return &SigningProfile{
Usage: []string{"signing", "key encipherment", "server auth", "client auth"},
Expiry: d,
ExpiryString: "8760h",
}
}
// LoadFile attempts to load the configuration file stored at the path
// and returns the configuration. On error, it returns nil.
func LoadFile(path string) (*Config, error) {
log.Debugf("loading configuration file from %s", path)
if path == "" {
return nil, cferr.Wrap(cferr.PolicyError, cferr.InvalidPolicy, errors.New("invalid path"))
}
body, err := ioutil.ReadFile(path)
if err != nil {
return nil, cferr.Wrap(cferr.PolicyError, cferr.InvalidPolicy, errors.New("could not read configuration file"))
}
return LoadConfig(body)
}
// LoadConfig attempts to load the configuration from a byte slice.
// On error, it returns nil.
func LoadConfig(config []byte) (*Config, error) {
var cfg = &Config{}
err := json.Unmarshal(config, &cfg)
if err != nil {
return nil, cferr.Wrap(cferr.PolicyError, cferr.InvalidPolicy,
errors.New("failed to unmarshal configuration: "+err.Error()))
}
if cfg.Signing == nil {
return nil, errors.New("No \"signing\" field present")
}
if cfg.Signing.Default == nil {
log.Debugf("no default given: using default config")
cfg.Signing.Default = DefaultConfig()
} else {
if err := cfg.Signing.Default.populate(cfg); err != nil {
return nil, err
}
}
for k := range cfg.Signing.Profiles {
if err := cfg.Signing.Profiles[k].populate(cfg); err != nil {
return nil, err
}
}
if !cfg.Valid() {
return nil, cferr.Wrap(cferr.PolicyError, cferr.InvalidPolicy, errors.New("invalid configuration"))
}
log.Debugf("configuration ok")
return cfg, nil
}
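// Illustrative sketch (not part of the original vendored source): loading a
// minimal signing configuration with LoadConfig. The usages and expiry shown
// are arbitrary example values for a local default profile.
func exampleLoadConfig() (*Config, error) {
	raw := []byte(`{
		"signing": {
			"default": {
				"usages": ["digital signature", "server auth"],
				"expiry": "8760h"
			}
		}
	}`)
	return LoadConfig(raw)
}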

View File

@ -1,188 +0,0 @@
// Package pkcs7 implements the subset of the CMS PKCS #7 datatype that is typically
// used to package certificates and CRLs. Using openssl, every certificate converted
// to PKCS #7 format from another encoding such as PEM conforms to this implementation.
// reference: https://www.openssl.org/docs/man1.1.0/apps/crl2pkcs7.html
//
// PKCS #7 Data type, reference: https://tools.ietf.org/html/rfc2315
//
// The full pkcs#7 cryptographic message syntax allows for cryptographic enhancements,
// for example data can be encrypted and signed and then packaged through pkcs#7 to be
// sent over a network and then verified and decrypted. It is asn1, and the type of
// PKCS #7 ContentInfo, which comprises the PKCS #7 structure, is:
//
// ContentInfo ::= SEQUENCE {
// contentType ContentType,
// content [0] EXPLICIT ANY DEFINED BY contentType OPTIONAL
// }
//
// There are 6 possible ContentTypes, data, signedData, envelopedData,
// signedAndEnvelopedData, digestedData, and encryptedData. Here signedData, Data, and
// encryptedData are implemented, as the degenerate case of signedData without a signature is the
// typical format for transferring certificates and CRLs, and Data and encryptedData are used in
// PKCS #12 formats.
// The ContentType signedData has the form:
//
//
// signedData ::= SEQUENCE {
// version Version,
// digestAlgorithms DigestAlgorithmIdentifiers,
// contentInfo ContentInfo,
// certificates [0] IMPLICIT ExtendedCertificatesAndCertificates OPTIONAL,
// crls [1] IMPLICIT CertificateRevocationLists OPTIONAL,
// signerInfos SignerInfos
// }
//
// As of yet, signerInfos and digestAlgorithms are not parsed, as they are not relevant to
// this system's use of PKCS #7 data. Version is an integer type. Note that PKCS #7 is
// recursive; this second layer of ContentInfo is similarly ignored for our degenerate
// usage. The ExtendedCertificatesAndCertificates type consists of a sequence of choices
// between PKCS #6 extended certificates and x509 certificates. Any sequence containing
// extended certificates is not yet supported in this implementation.
//
// The ContentType Data is simply a raw octet string and is parsed directly into a Go []byte slice.
//
// The ContentType encryptedData is the most complicated and its form can be gathered by
// the go type below. It essentially contains a raw octet string of encrypted data and an
// algorithm identifier for use in decrypting this data.
package pkcs7
import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"errors"
cferr "github.com/cloudflare/cfssl/errors"
)
// Types used for asn1 Unmarshaling.
type signedData struct {
Version int
DigestAlgorithms asn1.RawValue
ContentInfo asn1.RawValue
Certificates asn1.RawValue `asn1:"optional" asn1:"tag:0"`
Crls asn1.RawValue `asn1:"optional"`
SignerInfos asn1.RawValue
}
type initPKCS7 struct {
Raw asn1.RawContent
ContentType asn1.ObjectIdentifier
Content asn1.RawValue `asn1:"tag:0,explicit,optional"`
}
// Object identifier strings of the three implemented PKCS7 types.
const (
ObjIDData = "1.2.840.113549.1.7.1"
ObjIDSignedData = "1.2.840.113549.1.7.2"
ObjIDEncryptedData = "1.2.840.113549.1.7.6"
)
// PKCS7 represents the ASN1 PKCS #7 Content type. It contains one of three
// possible types of Content objects, as denoted by the object identifier in
// the ContentInfo field, the other two being nil. SignedData
// is the degenerate SignedData Content info without signature used
// to hold certificates and crls. Data is raw bytes, and EncryptedData
// is as defined in PKCS #7 standard.
type PKCS7 struct {
Raw asn1.RawContent
ContentInfo string
Content Content
}
// Content implements three of the six possible PKCS7 data types. Only one is non-nil.
type Content struct {
Data []byte
SignedData SignedData
EncryptedData EncryptedData
}
// SignedData defines the typical carrier of certificates and crls.
type SignedData struct {
Raw asn1.RawContent
Version int
Certificates []*x509.Certificate
Crl *pkix.CertificateList
}
// Data contains raw bytes. Used as a subtype in PKCS12.
type Data struct {
Bytes []byte
}
// EncryptedData contains encrypted data. Used as a subtype in PKCS12.
type EncryptedData struct {
Raw asn1.RawContent
Version int
EncryptedContentInfo EncryptedContentInfo
}
// EncryptedContentInfo is a subtype of PKCS7EncryptedData.
type EncryptedContentInfo struct {
Raw asn1.RawContent
ContentType asn1.ObjectIdentifier
ContentEncryptionAlgorithm pkix.AlgorithmIdentifier
EncryptedContent []byte `asn1:"tag:0,optional"`
}
// ParsePKCS7 attempts to parse the DER encoded bytes of a
// PKCS7 structure.
func ParsePKCS7(raw []byte) (msg *PKCS7, err error) {
var pkcs7 initPKCS7
_, err = asn1.Unmarshal(raw, &pkcs7)
if err != nil {
return nil, cferr.Wrap(cferr.CertificateError, cferr.ParseFailed, err)
}
msg = new(PKCS7)
msg.Raw = pkcs7.Raw
msg.ContentInfo = pkcs7.ContentType.String()
switch {
case msg.ContentInfo == ObjIDData:
msg.ContentInfo = "Data"
_, err = asn1.Unmarshal(pkcs7.Content.Bytes, &msg.Content.Data)
if err != nil {
return nil, cferr.Wrap(cferr.CertificateError, cferr.ParseFailed, err)
}
case msg.ContentInfo == ObjIDSignedData:
msg.ContentInfo = "SignedData"
var signedData signedData
_, err = asn1.Unmarshal(pkcs7.Content.Bytes, &signedData)
if err != nil {
return nil, cferr.Wrap(cferr.CertificateError, cferr.ParseFailed, err)
}
if len(signedData.Certificates.Bytes) != 0 {
msg.Content.SignedData.Certificates, err = x509.ParseCertificates(signedData.Certificates.Bytes)
if err != nil {
return nil, cferr.Wrap(cferr.CertificateError, cferr.ParseFailed, err)
}
}
if len(signedData.Crls.Bytes) != 0 {
msg.Content.SignedData.Crl, err = x509.ParseDERCRL(signedData.Crls.Bytes)
if err != nil {
return nil, cferr.Wrap(cferr.CertificateError, cferr.ParseFailed, err)
}
}
msg.Content.SignedData.Version = signedData.Version
msg.Content.SignedData.Raw = pkcs7.Content.Bytes
case msg.ContentInfo == ObjIDEncryptedData:
msg.ContentInfo = "EncryptedData"
var encryptedData EncryptedData
_, err = asn1.Unmarshal(pkcs7.Content.Bytes, &encryptedData)
if err != nil {
return nil, cferr.Wrap(cferr.CertificateError, cferr.ParseFailed, err)
}
if encryptedData.Version != 0 {
return nil, cferr.Wrap(cferr.CertificateError, cferr.ParseFailed, errors.New("Only support for PKCS #7 encryptedData version 0"))
}
msg.Content.EncryptedData = encryptedData
default:
return nil, cferr.Wrap(cferr.CertificateError, cferr.ParseFailed, errors.New("Attempt to parse PKCS# 7 Content not of type data, signed data or encrypted data"))
}
return msg, nil
}
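// Illustrative sketch (not part of the original vendored source): extracting
// certificates from a degenerate signedData blob, e.g. one produced by
// `openssl crl2pkcs7`. The raw argument is assumed to hold DER bytes.
func examplePKCS7Certificates(raw []byte) ([]*x509.Certificate, error) {
	msg, err := ParsePKCS7(raw)
	if err != nil {
		return nil, err
	}
	if msg.ContentInfo != "SignedData" {
		return nil, errors.New("not a signedData certificate bundle")
	}
	return msg.Content.SignedData.Certificates, nil
}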

View File

@ -1,438 +0,0 @@
// Package csr implements certificate requests for CFSSL.
package csr
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/pem"
"errors"
"net"
"net/mail"
"net/url"
"strings"
cferr "github.com/cloudflare/cfssl/errors"
"github.com/cloudflare/cfssl/helpers"
"github.com/cloudflare/cfssl/log"
)
const (
curveP256 = 256
curveP384 = 384
curveP521 = 521
)
// A Name contains the SubjectInfo fields.
type Name struct {
C string // Country
ST string // State
L string // Locality
O string // OrganisationName
OU string // OrganisationalUnitName
SerialNumber string
}
// A KeyRequest is a generic request for a new key.
type KeyRequest interface {
Algo() string
Size() int
Generate() (crypto.PrivateKey, error)
SigAlgo() x509.SignatureAlgorithm
}
// A BasicKeyRequest contains the algorithm and key size for a new private key.
type BasicKeyRequest struct {
A string `json:"algo" yaml:"algo"`
S int `json:"size" yaml:"size"`
}
// NewBasicKeyRequest returns a default BasicKeyRequest.
func NewBasicKeyRequest() *BasicKeyRequest {
return &BasicKeyRequest{"ecdsa", curveP256}
}
// Algo returns the requested key algorithm represented as a string.
func (kr *BasicKeyRequest) Algo() string {
return kr.A
}
// Size returns the requested key size.
func (kr *BasicKeyRequest) Size() int {
return kr.S
}
// Generate generates a key as specified in the request. Currently,
// only ECDSA and RSA are supported.
func (kr *BasicKeyRequest) Generate() (crypto.PrivateKey, error) {
log.Debugf("generate key from request: algo=%s, size=%d", kr.Algo(), kr.Size())
switch kr.Algo() {
case "rsa":
if kr.Size() < 2048 {
return nil, errors.New("RSA key is too weak")
}
if kr.Size() > 8192 {
return nil, errors.New("RSA key size too large")
}
return rsa.GenerateKey(rand.Reader, kr.Size())
case "ecdsa":
var curve elliptic.Curve
switch kr.Size() {
case curveP256:
curve = elliptic.P256()
case curveP384:
curve = elliptic.P384()
case curveP521:
curve = elliptic.P521()
default:
return nil, errors.New("invalid curve")
}
return ecdsa.GenerateKey(curve, rand.Reader)
default:
return nil, errors.New("invalid algorithm")
}
}
// SigAlgo returns an appropriate X.509 signature algorithm given the
// key request's type and size.
func (kr *BasicKeyRequest) SigAlgo() x509.SignatureAlgorithm {
switch kr.Algo() {
case "rsa":
switch {
case kr.Size() >= 4096:
return x509.SHA512WithRSA
case kr.Size() >= 3072:
return x509.SHA384WithRSA
case kr.Size() >= 2048:
return x509.SHA256WithRSA
default:
return x509.SHA1WithRSA
}
case "ecdsa":
switch kr.Size() {
case curveP521:
return x509.ECDSAWithSHA512
case curveP384:
return x509.ECDSAWithSHA384
case curveP256:
return x509.ECDSAWithSHA256
default:
return x509.ECDSAWithSHA1
}
default:
return x509.UnknownSignatureAlgorithm
}
}
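// Illustrative sketch (not part of the original vendored source): generating a
// key from the default request and picking the matching signature algorithm.
func exampleBasicKeyRequest() (crypto.PrivateKey, x509.SignatureAlgorithm, error) {
	kr := NewBasicKeyRequest() // ECDSA over P-256 by default
	priv, err := kr.Generate()
	if err != nil {
		return nil, x509.UnknownSignatureAlgorithm, err
	}
	// For the default request, SigAlgo returns x509.ECDSAWithSHA256.
	return priv, kr.SigAlgo(), nil
}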
// CAConfig is a section used in the requests initialising a new CA.
type CAConfig struct {
PathLength int `json:"pathlen" yaml:"pathlen"`
PathLenZero bool `json:"pathlenzero" yaml:"pathlenzero"`
Expiry string `json:"expiry" yaml:"expiry"`
Backdate string `json:"backdate" yaml:"backdate"`
}
// A CertificateRequest encapsulates the API interface to the
// certificate request functionality.
type CertificateRequest struct {
CN string
Names []Name `json:"names" yaml:"names"`
Hosts []string `json:"hosts" yaml:"hosts"`
KeyRequest KeyRequest `json:"key,omitempty" yaml:"key,omitempty"`
CA *CAConfig `json:"ca,omitempty" yaml:"ca,omitempty"`
SerialNumber string `json:"serialnumber,omitempty" yaml:"serialnumber,omitempty"`
}
// New returns a new, empty CertificateRequest with a
// BasicKeyRequest.
func New() *CertificateRequest {
return &CertificateRequest{
KeyRequest: NewBasicKeyRequest(),
}
}
// appendIf appends to a if s is not an empty string.
func appendIf(s string, a *[]string) {
if s != "" {
*a = append(*a, s)
}
}
// Name returns the PKIX name for the request.
func (cr *CertificateRequest) Name() pkix.Name {
var name pkix.Name
name.CommonName = cr.CN
for _, n := range cr.Names {
appendIf(n.C, &name.Country)
appendIf(n.ST, &name.Province)
appendIf(n.L, &name.Locality)
appendIf(n.O, &name.Organization)
appendIf(n.OU, &name.OrganizationalUnit)
}
name.SerialNumber = cr.SerialNumber
return name
}
// BasicConstraints holds the basic constraints extension for a CSR (RFC 5280, 4.2.1.9).
type BasicConstraints struct {
IsCA bool `asn1:"optional"`
MaxPathLen int `asn1:"optional,default:-1"`
}
// ParseRequest takes a certificate request and generates a key and
// CSR from it. It does no validation -- caveat emptor. It will,
// however, fail if the key request is not valid (i.e., an unsupported
// curve or RSA key size). The lack of validation was specifically
// chosen to allow the end user to define a policy and validate the
// request appropriately before calling this function.
func ParseRequest(req *CertificateRequest) (csr, key []byte, err error) {
log.Info("received CSR")
if req.KeyRequest == nil {
req.KeyRequest = NewBasicKeyRequest()
}
log.Infof("generating key: %s-%d", req.KeyRequest.Algo(), req.KeyRequest.Size())
priv, err := req.KeyRequest.Generate()
if err != nil {
err = cferr.Wrap(cferr.PrivateKeyError, cferr.GenerationFailed, err)
return
}
switch priv := priv.(type) {
case *rsa.PrivateKey:
key = x509.MarshalPKCS1PrivateKey(priv)
block := pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: key,
}
key = pem.EncodeToMemory(&block)
case *ecdsa.PrivateKey:
key, err = x509.MarshalECPrivateKey(priv)
if err != nil {
err = cferr.Wrap(cferr.PrivateKeyError, cferr.Unknown, err)
return
}
block := pem.Block{
Type: "EC PRIVATE KEY",
Bytes: key,
}
key = pem.EncodeToMemory(&block)
default:
panic("Generate should have failed to produce a valid key.")
}
csr, err = Generate(priv.(crypto.Signer), req)
if err != nil {
log.Errorf("failed to generate a CSR: %v", err)
err = cferr.Wrap(cferr.CSRError, cferr.BadRequest, err)
}
return
}
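// Illustrative sketch (not part of the original vendored source): building a
// request and generating a key plus PEM-encoded CSR via ParseRequest. The
// subject and host values are arbitrary examples.
func exampleParseRequest() (csrPEM, keyPEM []byte, err error) {
	req := &CertificateRequest{
		CN:    "example.com",
		Hosts: []string{"example.com", "192.0.2.1"},
		Names: []Name{{C: "US", O: "Example Org"}},
	}
	// KeyRequest is nil, so ParseRequest falls back to the ECDSA P-256 default.
	return ParseRequest(req)
}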
// ExtractCertificateRequest extracts a CertificateRequest from
// x509.Certificate. It is intended to be used for generating a new certificate
// from an existing certificate. For a root certificate, the CA expiry
// length is calculated as the duration between cert.NotAfter and cert.NotBefore.
func ExtractCertificateRequest(cert *x509.Certificate) *CertificateRequest {
req := New()
req.CN = cert.Subject.CommonName
req.Names = getNames(cert.Subject)
req.Hosts = getHosts(cert)
req.SerialNumber = cert.Subject.SerialNumber
if cert.IsCA {
req.CA = new(CAConfig)
// CA expiry length is calculated based on the input cert
// issue date and expiry date.
req.CA.Expiry = cert.NotAfter.Sub(cert.NotBefore).String()
req.CA.PathLength = cert.MaxPathLen
req.CA.PathLenZero = cert.MaxPathLenZero
}
return req
}
func getHosts(cert *x509.Certificate) []string {
var hosts []string
for _, ip := range cert.IPAddresses {
hosts = append(hosts, ip.String())
}
for _, dns := range cert.DNSNames {
hosts = append(hosts, dns)
}
for _, email := range cert.EmailAddresses {
hosts = append(hosts, email)
}
for _, uri := range cert.URIs {
hosts = append(hosts, uri.String())
}
return hosts
}
// getNames returns an array of Names from the certificate
// It only cares about Country, Organization, OrganizationalUnit, Locality and Province.
func getNames(sub pkix.Name) []Name {
// anonymous func for finding the max of a list of integers
max := func(v1 int, vn ...int) (max int) {
max = v1
for i := 0; i < len(vn); i++ {
if vn[i] > max {
max = vn[i]
}
}
return max
}
nc := len(sub.Country)
norg := len(sub.Organization)
nou := len(sub.OrganizationalUnit)
nl := len(sub.Locality)
np := len(sub.Province)
n := max(nc, norg, nou, nl, np)
names := make([]Name, n)
for i := range names {
if i < nc {
names[i].C = sub.Country[i]
}
if i < norg {
names[i].O = sub.Organization[i]
}
if i < nou {
names[i].OU = sub.OrganizationalUnit[i]
}
if i < nl {
names[i].L = sub.Locality[i]
}
if i < np {
names[i].ST = sub.Province[i]
}
}
return names
}
// A Generator is responsible for validating certificate requests.
type Generator struct {
Validator func(*CertificateRequest) error
}
// ProcessRequest validates and processes the incoming request. It is
// a wrapper around a validator and the ParseRequest function.
func (g *Generator) ProcessRequest(req *CertificateRequest) (csr, key []byte, err error) {
log.Info("generate received request")
err = g.Validator(req)
if err != nil {
log.Warningf("invalid request: %v", err)
return nil, nil, err
}
csr, key, err = ParseRequest(req)
if err != nil {
return nil, nil, err
}
return
}
// IsNameEmpty returns true if the name has no identifying information in it.
func IsNameEmpty(n Name) bool {
empty := func(s string) bool { return strings.TrimSpace(s) == "" }
if empty(n.C) && empty(n.ST) && empty(n.L) && empty(n.O) && empty(n.OU) {
return true
}
return false
}
// Regenerate uses the provided CSR as a template for signing a new
// CSR using priv.
func Regenerate(priv crypto.Signer, csr []byte) ([]byte, error) {
req, extra, err := helpers.ParseCSR(csr)
if err != nil {
return nil, err
} else if len(extra) > 0 {
return nil, errors.New("csr: trailing data in certificate request")
}
return x509.CreateCertificateRequest(rand.Reader, req, priv)
}
// Generate creates a new CSR from a CertificateRequest structure and
// an existing key. The KeyRequest field is ignored.
func Generate(priv crypto.Signer, req *CertificateRequest) (csr []byte, err error) {
sigAlgo := helpers.SignerAlgo(priv)
if sigAlgo == x509.UnknownSignatureAlgorithm {
return nil, cferr.New(cferr.PrivateKeyError, cferr.Unavailable)
}
var tpl = x509.CertificateRequest{
Subject: req.Name(),
SignatureAlgorithm: sigAlgo,
}
for i := range req.Hosts {
if ip := net.ParseIP(req.Hosts[i]); ip != nil {
tpl.IPAddresses = append(tpl.IPAddresses, ip)
} else if email, err := mail.ParseAddress(req.Hosts[i]); err == nil && email != nil {
tpl.EmailAddresses = append(tpl.EmailAddresses, email.Address)
} else if uri, err := url.ParseRequestURI(req.Hosts[i]); err == nil && uri != nil {
tpl.URIs = append(tpl.URIs, uri)
} else {
tpl.DNSNames = append(tpl.DNSNames, req.Hosts[i])
}
}
if req.CA != nil {
err = appendCAInfoToCSR(req.CA, &tpl)
if err != nil {
err = cferr.Wrap(cferr.CSRError, cferr.GenerationFailed, err)
return
}
}
csr, err = x509.CreateCertificateRequest(rand.Reader, &tpl, priv)
if err != nil {
log.Errorf("failed to generate a CSR: %v", err)
err = cferr.Wrap(cferr.CSRError, cferr.BadRequest, err)
return
}
block := pem.Block{
Type: "CERTIFICATE REQUEST",
Bytes: csr,
}
log.Info("encoded CSR")
csr = pem.EncodeToMemory(&block)
return
}
// appendCAInfoToCSR appends the CAConfig BasicConstraints extension to a CSR
func appendCAInfoToCSR(reqConf *CAConfig, csr *x509.CertificateRequest) error {
pathlen := reqConf.PathLength
if pathlen == 0 && !reqConf.PathLenZero {
pathlen = -1
}
val, err := asn1.Marshal(BasicConstraints{true, pathlen})
if err != nil {
return err
}
csr.ExtraExtensions = []pkix.Extension{
{
Id: asn1.ObjectIdentifier{2, 5, 29, 19},
Value: val,
Critical: true,
},
}
return nil
}

View File

@ -1,46 +0,0 @@
/*
Package errors provides error types returned in CF SSL.
1. Type Error is intended for errors produced by CF SSL packages.
It formats to a JSON object that consists of an error message and a 4-digit code for error reasoning.
Example: {"code":1002, "message": "Failed to decode certificate"}
The index of codes is listed below:
1XXX: CertificateError
1000: Unknown
1001: ReadFailed
1002: DecodeFailed
1003: ParseFailed
1100: SelfSigned
12XX: VerifyFailed
121X: CertificateInvalid
1210: NotAuthorizedToSign
1211: Expired
1212: CANotAuthorizedForThisName
1213: TooManyIntermediates
1214: IncompatibleUsage
1220: UnknownAuthority
2XXX: PrivatekeyError
2000: Unknown
2001: ReadFailed
2002: DecodeFailed
2003: ParseFailed
2100: Encrypted
2200: NotRSA
2300: KeyMismatch
2400: GenerationFailed
2500: Unavailable
3XXX: IntermediatesError
4XXX: RootError
5XXX: PolicyError
5100: NoKeyUsages
5200: InvalidPolicy
5300: InvalidRequest
5400: UnknownProfile
6XXX: DialError
2. Type HTTPError is intended for the CF SSL API to consume. It contains an HTTP status code that will be read and returned
by the API server.
*/
package errors

View File

@ -1,438 +0,0 @@
package errors
import (
"crypto/x509"
"encoding/json"
"fmt"
)
// Error is the error type usually returned by functions in CF SSL package.
// It contains a 4-digit error code where the most significant digit
// describes the category in which the error occurred and the remaining 3 digits
// describe the specific error reason.
type Error struct {
ErrorCode int `json:"code"`
Message string `json:"message"`
}
// Category is the most significant digit of the error code.
type Category int
// Reason is the last 3 digits of the error code.
type Reason int
const (
// Success indicates no error occurred.
Success Category = 1000 * iota // 0XXX
// CertificateError indicates a fault in a certificate.
CertificateError // 1XXX
// PrivateKeyError indicates a fault in a private key.
PrivateKeyError // 2XXX
// IntermediatesError indicates a fault in an intermediate.
IntermediatesError // 3XXX
// RootError indicates a fault in a root.
RootError // 4XXX
// PolicyError indicates an error arising from a malformed or
// non-existent policy, or a breach of policy.
PolicyError // 5XXX
// DialError indicates a network fault.
DialError // 6XXX
// APIClientError indicates a problem with the API client.
APIClientError // 7XXX
// OCSPError indicates a problem with OCSP signing
OCSPError // 8XXX
// CSRError indicates a problem with CSR parsing
CSRError // 9XXX
// CTError indicates a problem with the certificate transparency process
CTError // 10XXX
// CertStoreError indicates a problem with the certificate store
CertStoreError // 11XXX
)
// None is a non-specified error.
const (
None Reason = iota
)
// Warning code for a success
const (
BundleExpiringBit int = 1 << iota // 0x01
BundleNotUbiquitousBit // 0x02
)
// Parsing errors
const (
Unknown Reason = iota // X000
ReadFailed // X001
DecodeFailed // X002
ParseFailed // X003
)
// The following represent certificate non-parsing errors, and must be
// specified along with CertificateError.
const (
// SelfSigned indicates that a certificate is self-signed and
// cannot be used in the manner being attempted.
SelfSigned Reason = 100 * (iota + 1) // Code 11XX
// VerifyFailed is an X.509 verification failure. The least two
// significant digits of 12XX is determined as the actual x509
// error is examined.
VerifyFailed // Code 12XX
// BadRequest indicates that the certificate request is invalid.
BadRequest // Code 13XX
// MissingSerial indicates that the profile specified
// 'ClientProvidesSerialNumbers', but the SignRequest did not include a serial
// number.
MissingSerial // Code 14XX
)
const (
certificateInvalid = 10 * (iota + 1) //121X
unknownAuthority //122x
)
// The following represent private-key non-parsing errors, and must be
// specified with PrivateKeyError.
const (
// Encrypted indicates that the private key is a PKCS #8 encrypted
// private key. At this time, CFSSL does not support decrypting
// these keys.
Encrypted Reason = 100 * (iota + 1) //21XX
// NotRSAOrECC indicates that they key is not an RSA or ECC
// private key; these are the only two private key types supported
// at this time by CFSSL.
NotRSAOrECC //22XX
// KeyMismatch indicates that the private key does not match
// the public key or certificate being presented with the key.
KeyMismatch //23XX
// GenerationFailed indicates that a private key could not
// be generated.
GenerationFailed //24XX
// Unavailable indicates that a private key mechanism (such as
// PKCS #11) was requested but support for that mechanism is
// not available.
Unavailable
)
// The following are policy-related non-parsing errors, and must be
// specified along with PolicyError.
const (
// NoKeyUsages indicates that the profile does not permit any
// key usages for the certificate.
NoKeyUsages Reason = 100 * (iota + 1) // 51XX
// InvalidPolicy indicates that policy being requested is not
// a valid policy or does not exist.
InvalidPolicy // 52XX
// InvalidRequest indicates a certificate request violated the
// constraints of the policy being applied to the request.
InvalidRequest // 53XX
// UnknownProfile indicates that the profile does not exist.
UnknownProfile // 54XX
UnmatchedWhitelist // 55xx
)
// The following are API client related errors, and should be
// specified with APIClientError.
const (
// AuthenticationFailure occurs when the client is unable
// to obtain an authentication token for the request.
AuthenticationFailure Reason = 100 * (iota + 1)
// JSONError wraps an encoding/json error.
JSONError
// IOError wraps an io/ioutil error.
IOError
// ClientHTTPError wraps a net/http error.
ClientHTTPError
// ServerRequestFailed covers any other failures from the API
// client.
ServerRequestFailed
)
// The following are OCSP related errors, and should be
// specified with OCSPError
const (
// IssuerMismatch occurs when the certificate in the OCSP signing
// request was not issued by the CA that this responder responds for.
IssuerMismatch Reason = 100 * (iota + 1) // 81XX
// InvalidStatus occurs when the OCSP signing requests includes an
// invalid value for the certificate status.
InvalidStatus
)
// Certificate transparency related errors specified with CTError
const (
// PrecertSubmissionFailed occurs when submitting a precertificate to
// a log server fails
PrecertSubmissionFailed = 100 * (iota + 1)
// CTClientConstructionFailed occurs when the construction of a new
// github.com/google/certificate-transparency client fails.
CTClientConstructionFailed
// PrecertMissingPoison occurs when a precert is passed to SignFromPrecert
// and is missing the CT poison extension.
PrecertMissingPoison
// PrecertInvalidPoison occurs when a precert is passed to SignFromPrecert
// and has an invalid CT poison extension value or the extension is not
// critical.
PrecertInvalidPoison
)
// Certificate persistence related errors specified with CertStoreError
const (
// InsertionFailed occurs when a SQL insert query fails to complete.
InsertionFailed = 100 * (iota + 1)
// RecordNotFound occurs when a SQL query targeting one unique
// record fails to update the specified row in the table.
RecordNotFound
)
// Error implements the error interface. It formats the error as a JSON object string.
func (e *Error) Error() string {
marshaled, err := json.Marshal(e)
if err != nil {
panic(err)
}
return string(marshaled)
}
// New returns an error that contains an error code and message derived from
// the given category, reason. Currently, to avoid confusion, it is not
// allowed to create an error of category Success
func New(category Category, reason Reason) *Error {
errorCode := int(category) + int(reason)
var msg string
switch category {
case OCSPError:
switch reason {
case ReadFailed:
msg = "No certificate provided"
case IssuerMismatch:
msg = "Certificate not issued by this issuer"
case InvalidStatus:
msg = "Invalid revocation status"
}
case CertificateError:
switch reason {
case Unknown:
msg = "Unknown certificate error"
case ReadFailed:
msg = "Failed to read certificate"
case DecodeFailed:
msg = "Failed to decode certificate"
case ParseFailed:
msg = "Failed to parse certificate"
case SelfSigned:
msg = "Certificate is self signed"
case VerifyFailed:
msg = "Unable to verify certificate"
case BadRequest:
msg = "Invalid certificate request"
case MissingSerial:
msg = "Missing serial number in request"
default:
panic(fmt.Sprintf("Unsupported CFSSL error reason %d under category CertificateError.",
reason))
}
case PrivateKeyError:
switch reason {
case Unknown:
msg = "Unknown private key error"
case ReadFailed:
msg = "Failed to read private key"
case DecodeFailed:
msg = "Failed to decode private key"
case ParseFailed:
msg = "Failed to parse private key"
case Encrypted:
msg = "Private key is encrypted."
case NotRSAOrECC:
msg = "Private key algorithm is not RSA or ECC"
case KeyMismatch:
msg = "Private key does not match public key"
case GenerationFailed:
msg = "Failed to new private key"
case Unavailable:
msg = "Private key is unavailable"
default:
panic(fmt.Sprintf("Unsupported CFSSL error reason %d under category PrivateKeyError.",
reason))
}
case IntermediatesError:
switch reason {
case Unknown:
msg = "Unknown intermediate certificate error"
case ReadFailed:
msg = "Failed to read intermediate certificate"
case DecodeFailed:
msg = "Failed to decode intermediate certificate"
case ParseFailed:
msg = "Failed to parse intermediate certificate"
default:
panic(fmt.Sprintf("Unsupported CFSSL error reason %d under category IntermediatesError.",
reason))
}
case RootError:
switch reason {
case Unknown:
msg = "Unknown root certificate error"
case ReadFailed:
msg = "Failed to read root certificate"
case DecodeFailed:
msg = "Failed to decode root certificate"
case ParseFailed:
msg = "Failed to parse root certificate"
default:
panic(fmt.Sprintf("Unsupported CFSSL error reason %d under category RootError.",
reason))
}
case PolicyError:
switch reason {
case Unknown:
msg = "Unknown policy error"
case NoKeyUsages:
msg = "Invalid policy: no key usage available"
case InvalidPolicy:
msg = "Invalid or unknown policy"
case InvalidRequest:
msg = "Policy violation request"
case UnknownProfile:
msg = "Unknown policy profile"
case UnmatchedWhitelist:
msg = "Request does not match policy whitelist"
default:
panic(fmt.Sprintf("Unsupported CFSSL error reason %d under category PolicyError.",
reason))
}
case DialError:
switch reason {
case Unknown:
msg = "Failed to dial remote server"
default:
panic(fmt.Sprintf("Unsupported CFSSL error reason %d under category DialError.",
reason))
}
case APIClientError:
switch reason {
case AuthenticationFailure:
msg = "API client authentication failure"
case JSONError:
msg = "API client JSON config error"
case ClientHTTPError:
msg = "API client HTTP error"
case IOError:
msg = "API client IO error"
case ServerRequestFailed:
msg = "API client error: Server request failed"
default:
panic(fmt.Sprintf("Unsupported CFSSL error reason %d under category APIClientError.",
reason))
}
case CSRError:
switch reason {
case Unknown:
msg = "CSR parsing failed due to unknown error"
case ReadFailed:
msg = "CSR file read failed"
case ParseFailed:
msg = "CSR Parsing failed"
case DecodeFailed:
msg = "CSR Decode failed"
case BadRequest:
msg = "CSR Bad request"
default:
panic(fmt.Sprintf("Unsupported CF-SSL error reason %d under category APIClientError.", reason))
}
case CTError:
switch reason {
case Unknown:
msg = "Certificate transparency parsing failed due to unknown error"
case PrecertSubmissionFailed:
msg = "Certificate transparency precertificate submission failed"
case PrecertMissingPoison:
msg = "Precertificate is missing CT poison extension"
case PrecertInvalidPoison:
msg = "Precertificate contains an invalid CT poison extension"
default:
panic(fmt.Sprintf("Unsupported CF-SSL error reason %d under category CTError.", reason))
}
case CertStoreError:
switch reason {
case Unknown:
msg = "Certificate store action failed due to unknown error"
default:
panic(fmt.Sprintf("Unsupported CF-SSL error reason %d under category CertStoreError.", reason))
}
default:
panic(fmt.Sprintf("Unsupported CFSSL error type: %d.",
category))
}
return &Error{ErrorCode: errorCode, Message: msg}
}
// Wrap returns an error that contains the given error and an error code derived from
// the given category, reason and the error. Currently, to avoid confusion, it is not
// allowed to create an error of category Success
func Wrap(category Category, reason Reason, err error) *Error {
errorCode := int(category) + int(reason)
if err == nil {
panic("Wrap needs a supplied error to initialize.")
}
// do not double wrap an error
switch err.(type) {
case *Error:
panic("Unable to wrap a wrapped error.")
}
switch category {
case CertificateError:
// Given VerifyFailed, report the status with a more detailed status code
// for certain certificate errors we care about.
if reason == VerifyFailed {
switch errorType := err.(type) {
case x509.CertificateInvalidError:
errorCode += certificateInvalid + int(errorType.Reason)
case x509.UnknownAuthorityError:
errorCode += unknownAuthority
}
}
case PrivateKeyError, IntermediatesError, RootError, PolicyError, DialError,
APIClientError, CSRError, CTError, CertStoreError, OCSPError:
// no-op, just use the error
default:
panic(fmt.Sprintf("Unsupported CFSSL error type: %d.",
category))
}
return &Error{ErrorCode: errorCode, Message: err.Error()}
}
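// Illustrative sketch (not part of the original vendored source): the error
// codes and JSON shape produced by New and Wrap. The wrapped message is an
// arbitrary example.
func exampleErrors() {
	e := New(CertificateError, DecodeFailed)
	// e.Error() renders as {"code":1002,"message":"Failed to decode certificate"}.
	w := Wrap(PrivateKeyError, ReadFailed, fmt.Errorf("no such file"))
	// w.ErrorCode is 2001 and w.Message is "no such file".
	_, _ = e, w
}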

View File

@ -1,47 +0,0 @@
package errors
import (
"errors"
"net/http"
)
// HTTPError is an augmented error with an HTTP status code.
type HTTPError struct {
StatusCode int
error
}
// Error implements the error interface.
func (e *HTTPError) Error() string {
return e.error.Error()
}
// NewMethodNotAllowed returns an appropriate error in the case that
// an HTTP client uses an invalid method (e.g. a GET in place of a POST)
// on an API endpoint.
func NewMethodNotAllowed(method string) *HTTPError {
return &HTTPError{http.StatusMethodNotAllowed, errors.New(`Method is not allowed:"` + method + `"`)}
}
// NewBadRequest creates an HTTPError with the given error and error code 400.
func NewBadRequest(err error) *HTTPError {
return &HTTPError{http.StatusBadRequest, err}
}
// NewBadRequestString returns an HTTPError with the supplied message
// and error code 400.
func NewBadRequestString(s string) *HTTPError {
return NewBadRequest(errors.New(s))
}
// NewBadRequestMissingParameter returns a 400 HTTPError because a required
// parameter is missing from the HTTP request.
func NewBadRequestMissingParameter(s string) *HTTPError {
return NewBadRequestString(`Missing parameter "` + s + `"`)
}
// NewBadRequestUnwantedParameter returns a 400 HTTPError because an unnecessary
// parameter is present in the HTTP request.
func NewBadRequestUnwantedParameter(s string) *HTTPError {
return NewBadRequestString(`Unwanted parameter "` + s + `"`)
}
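// Illustrative sketch (not part of the original vendored source): how an API
// handler might surface these errors. The parameter name "csr" is an arbitrary
// example.
func exampleHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		e := NewMethodNotAllowed(r.Method)
		http.Error(w, e.Error(), e.StatusCode)
		return
	}
	if r.URL.Query().Get("csr") == "" {
		e := NewBadRequestMissingParameter("csr")
		http.Error(w, e.Error(), e.StatusCode)
	}
}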

View File

@ -1,48 +0,0 @@
// Package derhelpers implements common functionality
// on DER encoded data
package derhelpers
import (
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"crypto/x509"
cferr "github.com/cloudflare/cfssl/errors"
"golang.org/x/crypto/ed25519"
)
// ParsePrivateKeyDER parses a PKCS #1, PKCS #8, ECDSA, or Ed25519 DER-encoded
// private key. The key must not be in PEM format.
func ParsePrivateKeyDER(keyDER []byte) (key crypto.Signer, err error) {
generalKey, err := x509.ParsePKCS8PrivateKey(keyDER)
if err != nil {
generalKey, err = x509.ParsePKCS1PrivateKey(keyDER)
if err != nil {
generalKey, err = x509.ParseECPrivateKey(keyDER)
if err != nil {
generalKey, err = ParseEd25519PrivateKey(keyDER)
if err != nil {
// We intentionally don't include the underlying
// error in the final error, to avoid leaking
// any information about the private key.
return nil, cferr.New(cferr.PrivateKeyError,
cferr.ParseFailed)
}
}
}
}
switch generalKey.(type) {
case *rsa.PrivateKey:
return generalKey.(*rsa.PrivateKey), nil
case *ecdsa.PrivateKey:
return generalKey.(*ecdsa.PrivateKey), nil
case ed25519.PrivateKey:
return generalKey.(ed25519.PrivateKey), nil
}
// should never reach here
return nil, cferr.New(cferr.PrivateKeyError, cferr.ParseFailed)
}
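// Illustrative sketch (not part of the original vendored source): decoding a
// PEM block first and handing the raw DER bytes to ParsePrivateKeyDER. The
// encoding/pem import is an assumption; it is not used elsewhere in this file.
func exampleParsePrivateKeyPEM(keyPEM []byte) (crypto.Signer, error) {
	block, _ := pem.Decode(keyPEM)
	if block == nil {
		return nil, cferr.New(cferr.PrivateKeyError, cferr.DecodeFailed)
	}
	return ParsePrivateKeyDER(block.Bytes)
}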

View File

@ -1,133 +0,0 @@
package derhelpers
import (
"crypto"
"crypto/x509/pkix"
"encoding/asn1"
"errors"
"golang.org/x/crypto/ed25519"
)
var errEd25519WrongID = errors.New("incorrect object identifier")
var errEd25519WrongKeyType = errors.New("incorrect key type")
// ed25519OID is the OID for the Ed25519 signature scheme: see
// https://datatracker.ietf.org/doc/draft-ietf-curdle-pkix-04.
var ed25519OID = asn1.ObjectIdentifier{1, 3, 101, 112}
// subjectPublicKeyInfo reflects the ASN.1 object defined in the X.509 standard.
//
// This is defined in crypto/x509 as "publicKeyInfo".
type subjectPublicKeyInfo struct {
Algorithm pkix.AlgorithmIdentifier
PublicKey asn1.BitString
}
// MarshalEd25519PublicKey creates a DER-encoded SubjectPublicKeyInfo for an
// ed25519 public key, as defined in
// https://tools.ietf.org/html/draft-ietf-curdle-pkix-04. This is analogous to
// MarshalPKIXPublicKey in crypto/x509, which doesn't currently support Ed25519.
func MarshalEd25519PublicKey(pk crypto.PublicKey) ([]byte, error) {
pub, ok := pk.(ed25519.PublicKey)
if !ok {
return nil, errEd25519WrongKeyType
}
spki := subjectPublicKeyInfo{
Algorithm: pkix.AlgorithmIdentifier{
Algorithm: ed25519OID,
},
PublicKey: asn1.BitString{
BitLength: len(pub) * 8,
Bytes: pub,
},
}
return asn1.Marshal(spki)
}
// ParseEd25519PublicKey returns the Ed25519 public key encoded by the input.
func ParseEd25519PublicKey(der []byte) (crypto.PublicKey, error) {
var spki subjectPublicKeyInfo
if rest, err := asn1.Unmarshal(der, &spki); err != nil {
return nil, err
} else if len(rest) > 0 {
return nil, errors.New("SubjectPublicKeyInfo too long")
}
if !spki.Algorithm.Algorithm.Equal(ed25519OID) {
return nil, errEd25519WrongID
}
if spki.PublicKey.BitLength != ed25519.PublicKeySize*8 {
return nil, errors.New("SubjectPublicKeyInfo PublicKey length mismatch")
}
return ed25519.PublicKey(spki.PublicKey.Bytes), nil
}
// oneAsymmetricKey reflects the ASN.1 structure for storing private keys in
// https://tools.ietf.org/html/draft-ietf-curdle-pkix-04, excluding the optional
// fields, which we don't use here.
//
// This is identical to pkcs8 in crypto/x509.
type oneAsymmetricKey struct {
Version int
Algorithm pkix.AlgorithmIdentifier
PrivateKey []byte
}
// curvePrivateKey is the inner type of the PrivateKey field of
// oneAsymmetricKey.
type curvePrivateKey []byte
// MarshalEd25519PrivateKey returns a DER encoding of the input private key as
// specified in https://tools.ietf.org/html/draft-ietf-curdle-pkix-04.
func MarshalEd25519PrivateKey(sk crypto.PrivateKey) ([]byte, error) {
priv, ok := sk.(ed25519.PrivateKey)
if !ok {
return nil, errEd25519WrongKeyType
}
// Marshal the inner CurvePrivateKey.
curvePrivateKey, err := asn1.Marshal(priv.Seed())
if err != nil {
return nil, err
}
// Marshal the OneAsymmetricKey.
asym := oneAsymmetricKey{
Version: 0,
Algorithm: pkix.AlgorithmIdentifier{
Algorithm: ed25519OID,
},
PrivateKey: curvePrivateKey,
}
return asn1.Marshal(asym)
}
// ParseEd25519PrivateKey returns the Ed25519 private key encoded by the input.
func ParseEd25519PrivateKey(der []byte) (crypto.PrivateKey, error) {
asym := new(oneAsymmetricKey)
if rest, err := asn1.Unmarshal(der, asym); err != nil {
return nil, err
} else if len(rest) > 0 {
return nil, errors.New("OneAsymmetricKey too long")
}
// Check that the key type is correct.
if !asym.Algorithm.Algorithm.Equal(ed25519OID) {
return nil, errEd25519WrongID
}
// Unmarshal the inner CurvePrivateKey.
seed := new(curvePrivateKey)
if rest, err := asn1.Unmarshal(asym.PrivateKey, seed); err != nil {
return nil, err
} else if len(rest) > 0 {
return nil, errors.New("CurvePrivateKey too long")
}
return ed25519.NewKeyFromSeed(*seed), nil
}
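// Illustrative sketch (not part of the original vendored source): a marshal and
// parse round trip for an Ed25519 key pair using the helpers above. The
// crypto/rand import for key generation is an assumption.
func exampleEd25519RoundTrip() error {
	pub, priv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		return err
	}
	privDER, err := MarshalEd25519PrivateKey(priv)
	if err != nil {
		return err
	}
	if _, err := ParseEd25519PrivateKey(privDER); err != nil {
		return err
	}
	pubDER, err := MarshalEd25519PublicKey(pub)
	if err != nil {
		return err
	}
	_, err = ParseEd25519PublicKey(pubDER)
	return err
}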

View File

@ -1,590 +0,0 @@
// Package helpers implements utility functionality common to many
// CFSSL packages.
package helpers
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/pem"
"errors"
"fmt"
"io/ioutil"
"os"
"github.com/google/certificate-transparency-go"
cttls "github.com/google/certificate-transparency-go/tls"
ctx509 "github.com/google/certificate-transparency-go/x509"
"golang.org/x/crypto/ocsp"
"strings"
"time"
"github.com/cloudflare/cfssl/crypto/pkcs7"
cferr "github.com/cloudflare/cfssl/errors"
"github.com/cloudflare/cfssl/helpers/derhelpers"
"github.com/cloudflare/cfssl/log"
"golang.org/x/crypto/pkcs12"
)
// OneYear is a time.Duration representing one year (8760 hours).
const OneYear = 8760 * time.Hour
// OneDay is a time.Duration representing one day (24 hours).
const OneDay = 24 * time.Hour
// InclusiveDate returns the time.Time representation of a date - 1
// nanosecond. This allows time.After to be used inclusively.
func InclusiveDate(year int, month time.Month, day int) time.Time {
return time.Date(year, month, day, 0, 0, 0, 0, time.UTC).Add(-1 * time.Nanosecond)
}
// Jul2012 is the July 2012 CAB Forum deadline for when CAs must stop
// issuing certificates valid for more than 5 years.
var Jul2012 = InclusiveDate(2012, time.July, 01)
// Apr2015 is the April 2015 CAB Forum deadline for when CAs must stop
// issuing certificates valid for more than 39 months.
var Apr2015 = InclusiveDate(2015, time.April, 01)
// KeyLength returns the bit size of ECDSA or RSA PublicKey
func KeyLength(key interface{}) int {
if key == nil {
return 0
}
if ecdsaKey, ok := key.(*ecdsa.PublicKey); ok {
return ecdsaKey.Curve.Params().BitSize
} else if rsaKey, ok := key.(*rsa.PublicKey); ok {
return rsaKey.N.BitLen()
}
return 0
}
// ExpiryTime returns the time when the certificate chain is expired.
func ExpiryTime(chain []*x509.Certificate) (notAfter time.Time) {
if len(chain) == 0 {
return
}
notAfter = chain[0].NotAfter
for _, cert := range chain {
if notAfter.After(cert.NotAfter) {
notAfter = cert.NotAfter
}
}
return
}
// MonthsValid returns the number of months for which a certificate is valid.
func MonthsValid(c *x509.Certificate) int {
issued := c.NotBefore
expiry := c.NotAfter
years := (expiry.Year() - issued.Year())
months := years*12 + int(expiry.Month()) - int(issued.Month())
// Round up if valid for less than a full month
if expiry.Day() > issued.Day() {
months++
}
return months
}
// ValidExpiry determines if a certificate is valid for an acceptable
// length of time per the CA/Browser Forum baseline requirements.
// See https://cabforum.org/wp-content/uploads/CAB-Forum-BR-1.3.0.pdf
func ValidExpiry(c *x509.Certificate) bool {
issued := c.NotBefore
var maxMonths int
switch {
case issued.After(Apr2015):
maxMonths = 39
case issued.After(Jul2012):
maxMonths = 60
case issued.Before(Jul2012):
maxMonths = 120
}
if MonthsValid(c) > maxMonths {
return false
}
return true
}
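As a worked example of the month arithmetic above (dates are illustrative, not from CFSSL): a certificate issued 2016-03-15 and expiring 2019-06-20 spans 39 whole months plus a partial month, so MonthsValid reports 40 and the post-April-2015 cap of 39 is exceeded. A small standalone sketch:

```go
package main

import (
	"fmt"
	"time"
)

// monthsBetween mirrors the MonthsValid calculation above for two dates.
func monthsBetween(issued, expiry time.Time) int {
	months := (expiry.Year()-issued.Year())*12 + int(expiry.Month()) - int(issued.Month())
	if expiry.Day() > issued.Day() {
		months++ // count a partial trailing month
	}
	return months
}

func main() {
	issued := time.Date(2016, time.March, 15, 0, 0, 0, 0, time.UTC)
	expiry := time.Date(2019, time.June, 20, 0, 0, 0, 0, time.UTC)
	m := monthsBetween(issued, expiry)
	// Issued after the April 2015 deadline, so the cap is 39 months.
	fmt.Printf("valid for %d months; within CA/B limit: %v\n", m, m <= 39)
}
```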
// SignatureString returns the TLS signature string corresponding to
// an X509 signature algorithm.
func SignatureString(alg x509.SignatureAlgorithm) string {
switch alg {
case x509.MD2WithRSA:
return "MD2WithRSA"
case x509.MD5WithRSA:
return "MD5WithRSA"
case x509.SHA1WithRSA:
return "SHA1WithRSA"
case x509.SHA256WithRSA:
return "SHA256WithRSA"
case x509.SHA384WithRSA:
return "SHA384WithRSA"
case x509.SHA512WithRSA:
return "SHA512WithRSA"
case x509.DSAWithSHA1:
return "DSAWithSHA1"
case x509.DSAWithSHA256:
return "DSAWithSHA256"
case x509.ECDSAWithSHA1:
return "ECDSAWithSHA1"
case x509.ECDSAWithSHA256:
return "ECDSAWithSHA256"
case x509.ECDSAWithSHA384:
return "ECDSAWithSHA384"
case x509.ECDSAWithSHA512:
return "ECDSAWithSHA512"
default:
return "Unknown Signature"
}
}
// HashAlgoString returns the hash algorithm name contained in the signature
// method.
func HashAlgoString(alg x509.SignatureAlgorithm) string {
switch alg {
case x509.MD2WithRSA:
return "MD2"
case x509.MD5WithRSA:
return "MD5"
case x509.SHA1WithRSA:
return "SHA1"
case x509.SHA256WithRSA:
return "SHA256"
case x509.SHA384WithRSA:
return "SHA384"
case x509.SHA512WithRSA:
return "SHA512"
case x509.DSAWithSHA1:
return "SHA1"
case x509.DSAWithSHA256:
return "SHA256"
case x509.ECDSAWithSHA1:
return "SHA1"
case x509.ECDSAWithSHA256:
return "SHA256"
case x509.ECDSAWithSHA384:
return "SHA384"
case x509.ECDSAWithSHA512:
return "SHA512"
default:
return "Unknown Hash Algorithm"
}
}
// StringTLSVersion returns the TLS version constant corresponding to a
// human-readable version string, defaulting to Go's default of TLS 1.0.
func StringTLSVersion(version string) uint16 {
switch version {
case "1.2":
return tls.VersionTLS12
case "1.1":
return tls.VersionTLS11
default:
return tls.VersionTLS10
}
}
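A short usage sketch, assuming the import path used elsewhere in this vendor tree (github.com/cloudflare/cfssl/helpers):

```go
package main

import (
	"crypto/tls"
	"fmt"

	"github.com/cloudflare/cfssl/helpers"
)

func main() {
	// "1.2" would typically come from configuration; unknown strings fall back to TLS 1.0.
	cfg := &tls.Config{
		MinVersion: helpers.StringTLSVersion("1.2"),
	}
	fmt.Println(cfg.MinVersion == tls.VersionTLS12) // true
}
```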
// EncodeCertificatesPEM encodes a number of x509 certificates to PEM
func EncodeCertificatesPEM(certs []*x509.Certificate) []byte {
var buffer bytes.Buffer
for _, cert := range certs {
pem.Encode(&buffer, &pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Raw,
})
}
return buffer.Bytes()
}
// EncodeCertificatePEM encodes a single x509 certificate to PEM
func EncodeCertificatePEM(cert *x509.Certificate) []byte {
return EncodeCertificatesPEM([]*x509.Certificate{cert})
}
// ParseCertificatesPEM parses a sequence of PEM-encoded certificates and returns them;
// it can also handle PEM-encoded PKCS #7 structures.
func ParseCertificatesPEM(certsPEM []byte) ([]*x509.Certificate, error) {
var certs []*x509.Certificate
var err error
certsPEM = bytes.TrimSpace(certsPEM)
for len(certsPEM) > 0 {
var cert []*x509.Certificate
cert, certsPEM, err = ParseOneCertificateFromPEM(certsPEM)
if err != nil {
return nil, cferr.New(cferr.CertificateError, cferr.ParseFailed)
} else if cert == nil {
break
}
certs = append(certs, cert...)
}
if len(certsPEM) > 0 {
return nil, cferr.New(cferr.CertificateError, cferr.DecodeFailed)
}
return certs, nil
}
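For comparison, a minimal standard-library sketch of the same decode loop for a plain certificate bundle (no PKCS #7 handling; the bundle path is hypothetical):

```go
package main

import (
	"crypto/x509"
	"encoding/pem"
	"fmt"
	"io/ioutil"
)

func main() {
	// Hypothetical bundle containing one or more CERTIFICATE blocks.
	data, err := ioutil.ReadFile("ca-bundle.pem")
	if err != nil {
		panic(err)
	}
	for len(data) > 0 {
		var block *pem.Block
		block, data = pem.Decode(data)
		if block == nil {
			break // no more PEM objects
		}
		if block.Type != "CERTIFICATE" {
			continue
		}
		cert, err := x509.ParseCertificate(block.Bytes)
		if err != nil {
			panic(err)
		}
		fmt.Println("parsed:", cert.Subject.CommonName)
	}
}
```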
// ParseCertificatesDER parses a DER encoding of a certificate object and possibly private key,
// either PKCS #7, PKCS #12, or raw x509.
func ParseCertificatesDER(certsDER []byte, password string) (certs []*x509.Certificate, key crypto.Signer, err error) {
certsDER = bytes.TrimSpace(certsDER)
pkcs7data, err := pkcs7.ParsePKCS7(certsDER)
if err != nil {
var pkcs12data interface{}
certs = make([]*x509.Certificate, 1)
pkcs12data, certs[0], err = pkcs12.Decode(certsDER, password)
if err != nil {
certs, err = x509.ParseCertificates(certsDER)
if err != nil {
return nil, nil, cferr.New(cferr.CertificateError, cferr.DecodeFailed)
}
} else {
key = pkcs12data.(crypto.Signer)
}
} else {
if pkcs7data.ContentInfo != "SignedData" {
return nil, nil, cferr.Wrap(cferr.CertificateError, cferr.DecodeFailed, errors.New("can only extract certificates from signed data content info"))
}
certs = pkcs7data.Content.SignedData.Certificates
}
if certs == nil {
return nil, key, cferr.New(cferr.CertificateError, cferr.DecodeFailed)
}
return certs, key, nil
}
// ParseSelfSignedCertificatePEM parses a PEM-encoded certificate and checks whether it is self-signed.
func ParseSelfSignedCertificatePEM(certPEM []byte) (*x509.Certificate, error) {
cert, err := ParseCertificatePEM(certPEM)
if err != nil {
return nil, err
}
if err := cert.CheckSignature(cert.SignatureAlgorithm, cert.RawTBSCertificate, cert.Signature); err != nil {
return nil, cferr.Wrap(cferr.CertificateError, cferr.VerifyFailed, err)
}
return cert, nil
}
// ParseCertificatePEM parses and returns a PEM-encoded certificate;
// it can also handle PEM-encoded PKCS #7 structures.
func ParseCertificatePEM(certPEM []byte) (*x509.Certificate, error) {
certPEM = bytes.TrimSpace(certPEM)
cert, rest, err := ParseOneCertificateFromPEM(certPEM)
if err != nil {
// Log the actual parsing error but throw a default parse error message.
log.Debugf("Certificate parsing error: %v", err)
return nil, cferr.New(cferr.CertificateError, cferr.ParseFailed)
} else if cert == nil {
return nil, cferr.New(cferr.CertificateError, cferr.DecodeFailed)
} else if len(rest) > 0 {
return nil, cferr.Wrap(cferr.CertificateError, cferr.ParseFailed, errors.New("the PEM file should contain only one object"))
} else if len(cert) > 1 {
return nil, cferr.Wrap(cferr.CertificateError, cferr.ParseFailed, errors.New("the PKCS7 object in the PEM file should contain only one certificate"))
}
return cert[0], nil
}
// ParseOneCertificateFromPEM attempts to parse one PEM encoded certificate object,
// either a raw x509 certificate or a PKCS #7 structure possibly containing
// multiple certificates, from the top of certsPEM, which itself may
// contain multiple PEM encoded certificate objects.
func ParseOneCertificateFromPEM(certsPEM []byte) ([]*x509.Certificate, []byte, error) {
block, rest := pem.Decode(certsPEM)
if block == nil {
return nil, rest, nil
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
pkcs7data, err := pkcs7.ParsePKCS7(block.Bytes)
if err != nil {
return nil, rest, err
}
if pkcs7data.ContentInfo != "SignedData" {
return nil, rest, errors.New("only PKCS #7 Signed Data Content Info supported for certificate parsing")
}
certs := pkcs7data.Content.SignedData.Certificates
if certs == nil {
return nil, rest, errors.New("PKCS #7 structure contains no certificates")
}
return certs, rest, nil
}
var certs = []*x509.Certificate{cert}
return certs, rest, nil
}
// LoadPEMCertPool loads a pool of PEM certificates from file.
func LoadPEMCertPool(certsFile string) (*x509.CertPool, error) {
if certsFile == "" {
return nil, nil
}
pemCerts, err := ioutil.ReadFile(certsFile)
if err != nil {
return nil, err
}
return PEMToCertPool(pemCerts)
}
// PEMToCertPool converts PEM certificates to a CertPool.
func PEMToCertPool(pemCerts []byte) (*x509.CertPool, error) {
if len(pemCerts) == 0 {
return nil, nil
}
certPool := x509.NewCertPool()
if !certPool.AppendCertsFromPEM(pemCerts) {
return nil, errors.New("failed to load cert pool")
}
return certPool, nil
}
// ParsePrivateKeyPEM parses and returns a PEM-encoded private
// key. The private key may be either an unencrypted PKCS#8, PKCS#1,
// or elliptic private key.
func ParsePrivateKeyPEM(keyPEM []byte) (key crypto.Signer, err error) {
return ParsePrivateKeyPEMWithPassword(keyPEM, nil)
}
// ParsePrivateKeyPEMWithPassword parses and returns a PEM-encoded private
// key. The private key may be a potentially encrypted PKCS#8, PKCS#1,
// or elliptic private key.
func ParsePrivateKeyPEMWithPassword(keyPEM []byte, password []byte) (key crypto.Signer, err error) {
keyDER, err := GetKeyDERFromPEM(keyPEM, password)
if err != nil {
return nil, err
}
return derhelpers.ParsePrivateKeyDER(keyDER)
}
// GetKeyDERFromPEM parses a PEM-encoded private key and returns DER-format key bytes.
func GetKeyDERFromPEM(in []byte, password []byte) ([]byte, error) {
keyDER, _ := pem.Decode(in)
if keyDER != nil {
if procType, ok := keyDER.Headers["Proc-Type"]; ok {
if strings.Contains(procType, "ENCRYPTED") {
if password != nil {
return x509.DecryptPEMBlock(keyDER, password)
}
return nil, cferr.New(cferr.PrivateKeyError, cferr.Encrypted)
}
}
return keyDER.Bytes, nil
}
return nil, cferr.New(cferr.PrivateKeyError, cferr.DecodeFailed)
}
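A hedged standard-library sketch of the common unencrypted case (PKCS #8 here; the actual DER parsing in this package is delegated to derhelpers.ParsePrivateKeyDER, and the key path is hypothetical):

```go
package main

import (
	"crypto/x509"
	"encoding/pem"
	"fmt"
	"io/ioutil"
)

func main() {
	// Hypothetical path to an unencrypted PKCS #8 private key.
	keyPEM, err := ioutil.ReadFile("key.pem")
	if err != nil {
		panic(err)
	}
	block, _ := pem.Decode(keyPEM)
	if block == nil {
		panic("no PEM block found")
	}
	if _, ok := block.Headers["Proc-Type"]; ok {
		panic("encrypted PEM keys need a password (see GetKeyDERFromPEM above)")
	}
	key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if err != nil {
		panic(err)
	}
	fmt.Printf("parsed key of type %T\n", key)
}
```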
// ParseCSR parses a PEM- or DER-encoded PKCS #10 certificate signing request.
func ParseCSR(in []byte) (csr *x509.CertificateRequest, rest []byte, err error) {
in = bytes.TrimSpace(in)
p, rest := pem.Decode(in)
if p != nil {
if p.Type != "NEW CERTIFICATE REQUEST" && p.Type != "CERTIFICATE REQUEST" {
return nil, rest, cferr.New(cferr.CSRError, cferr.BadRequest)
}
csr, err = x509.ParseCertificateRequest(p.Bytes)
} else {
csr, err = x509.ParseCertificateRequest(in)
}
if err != nil {
return nil, rest, err
}
err = csr.CheckSignature()
if err != nil {
return nil, rest, err
}
return csr, rest, nil
}
// ParseCSRPEM parses a PEM-encoded certificate signing request.
// It does not check the signature. This is useful for dumping data from a CSR
// locally.
func ParseCSRPEM(csrPEM []byte) (*x509.CertificateRequest, error) {
block, _ := pem.Decode([]byte(csrPEM))
if block == nil {
return nil, cferr.New(cferr.CSRError, cferr.DecodeFailed)
}
csrObject, err := x509.ParseCertificateRequest(block.Bytes)
if err != nil {
return nil, err
}
return csrObject, nil
}
// SignerAlgo returns an X.509 signature algorithm from a crypto.Signer.
func SignerAlgo(priv crypto.Signer) x509.SignatureAlgorithm {
switch pub := priv.Public().(type) {
case *rsa.PublicKey:
bitLength := pub.N.BitLen()
switch {
case bitLength >= 4096:
return x509.SHA512WithRSA
case bitLength >= 3072:
return x509.SHA384WithRSA
case bitLength >= 2048:
return x509.SHA256WithRSA
default:
return x509.SHA1WithRSA
}
case *ecdsa.PublicKey:
switch pub.Curve {
case elliptic.P521():
return x509.ECDSAWithSHA512
case elliptic.P384():
return x509.ECDSAWithSHA384
case elliptic.P256():
return x509.ECDSAWithSHA256
default:
return x509.ECDSAWithSHA1
}
default:
return x509.UnknownSignatureAlgorithm
}
}
// LoadClientCertificate loads a key/certificate pair from PEM files.
func LoadClientCertificate(certFile string, keyFile string) (*tls.Certificate, error) {
if certFile != "" && keyFile != "" {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
log.Critical("Unable to read client certificate from file: %s or key from file: %s", certFile, keyFile)
return nil, err
}
log.Debug("Client certificate loaded ")
return &cert, nil
}
return nil, nil
}
// CreateTLSConfig creates a tls.Config object from certs and roots
func CreateTLSConfig(remoteCAs *x509.CertPool, cert *tls.Certificate) *tls.Config {
var certs []tls.Certificate
if cert != nil {
certs = []tls.Certificate{*cert}
}
return &tls.Config{
Certificates: certs,
RootCAs: remoteCAs,
}
}
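Putting the three helpers above together, a usage sketch for a mutual-TLS HTTP client (file paths are hypothetical; import path as used elsewhere in this vendor tree):

```go
package main

import (
	"net/http"

	"github.com/cloudflare/cfssl/helpers"
)

func main() {
	// Hypothetical paths to the client key pair and the remote CA bundle.
	cert, err := helpers.LoadClientCertificate("client.pem", "client-key.pem")
	if err != nil {
		panic(err)
	}
	roots, err := helpers.LoadPEMCertPool("remote-ca.pem")
	if err != nil {
		panic(err)
	}
	tlsConfig := helpers.CreateTLSConfig(roots, cert)
	client := &http.Client{
		Transport: &http.Transport{TLSClientConfig: tlsConfig},
	}
	_ = client // use client.Get(...) against the mTLS endpoint
}
```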
// SerializeSCTList serializes a list of SCTs.
func SerializeSCTList(sctList []ct.SignedCertificateTimestamp) ([]byte, error) {
list := ctx509.SignedCertificateTimestampList{}
for _, sct := range sctList {
sctBytes, err := cttls.Marshal(sct)
if err != nil {
return nil, err
}
list.SCTList = append(list.SCTList, ctx509.SerializedSCT{Val: sctBytes})
}
return cttls.Marshal(list)
}
// DeserializeSCTList deserializes a list of SCTs.
func DeserializeSCTList(serializedSCTList []byte) ([]ct.SignedCertificateTimestamp, error) {
var sctList ctx509.SignedCertificateTimestampList
rest, err := cttls.Unmarshal(serializedSCTList, &sctList)
if err != nil {
return nil, err
}
if len(rest) != 0 {
return nil, cferr.Wrap(cferr.CTError, cferr.Unknown, errors.New("serialized SCT list contained trailing garbage"))
}
list := make([]ct.SignedCertificateTimestamp, len(sctList.SCTList))
for i, serializedSCT := range sctList.SCTList {
var sct ct.SignedCertificateTimestamp
rest, err := cttls.Unmarshal(serializedSCT.Val, &sct)
if err != nil {
return nil, err
}
if len(rest) != 0 {
return nil, cferr.Wrap(cferr.CTError, cferr.Unknown, errors.New("serialized SCT contained trailing garbage"))
}
list[i] = sct
}
return list, nil
}
// SCTListFromOCSPResponse extracts the SCTList from an ocsp.Response,
// returning an empty list if the SCT extension was not found or could not be
// unmarshalled.
func SCTListFromOCSPResponse(response *ocsp.Response) ([]ct.SignedCertificateTimestamp, error) {
// This loop finds the SCTListExtension in the OCSP response.
var SCTListExtension, ext pkix.Extension
for _, ext = range response.Extensions {
// sctExtOid is the ObjectIdentifier of a Signed Certificate Timestamp.
sctExtOid := asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 11129, 2, 4, 5}
if ext.Id.Equal(sctExtOid) {
SCTListExtension = ext
break
}
}
// This code block extracts the sctList from the SCT extension.
var sctList []ct.SignedCertificateTimestamp
var err error
if numBytes := len(SCTListExtension.Value); numBytes != 0 {
var serializedSCTList []byte
rest := make([]byte, numBytes)
copy(rest, SCTListExtension.Value)
for len(rest) != 0 {
rest, err = asn1.Unmarshal(rest, &serializedSCTList)
if err != nil {
return nil, cferr.Wrap(cferr.CTError, cferr.Unknown, err)
}
}
sctList, err = DeserializeSCTList(serializedSCTList)
}
return sctList, err
}
// ReadBytes reads a []byte either from a file or an environment variable.
// If valFile has a prefix of 'env:', the []byte is read from the environment
// using the subsequent name. If the prefix is 'file:' the []byte is read from
// the subsequent file. If no prefix is provided, valFile is assumed to be a
// file path.
func ReadBytes(valFile string) ([]byte, error) {
switch splitVal := strings.SplitN(valFile, ":", 2); len(splitVal) {
case 1:
return ioutil.ReadFile(valFile)
case 2:
switch splitVal[0] {
case "env":
return []byte(os.Getenv(splitVal[1])), nil
case "file":
return ioutil.ReadFile(splitVal[1])
default:
return nil, fmt.Errorf("unknown prefix: %s", splitVal[0])
}
default:
return nil, fmt.Errorf("multiple prefixes: %s",
strings.Join(splitVal[:len(splitVal)-1], ", "))
}
}
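A usage sketch of the three accepted forms (the environment variable name and file paths are illustrative only):

```go
package main

import (
	"fmt"
	"os"

	"github.com/cloudflare/cfssl/helpers"
)

func main() {
	// Illustrative only: stash a value in the environment and read it back.
	os.Setenv("CA_KEY_PEM", "-----BEGIN EC PRIVATE KEY-----\n...")

	fromEnv, err := helpers.ReadBytes("env:CA_KEY_PEM")
	if err != nil {
		panic(err)
	}
	fmt.Printf("%d bytes from the environment\n", len(fromEnv))

	// The other accepted forms (paths are hypothetical):
	//   helpers.ReadBytes("file:/etc/cfssl/ca-key.pem")  // explicit file: prefix
	//   helpers.ReadBytes("/etc/cfssl/ca-key.pem")       // bare path
}
```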

View File

@ -1,15 +0,0 @@
// Package info contains the definitions for the info endpoint
package info
// Req is the request struct for an info API request.
type Req struct {
Label string `json:"label"`
Profile string `json:"profile"`
}
// Resp is the response for an Info API request.
type Resp struct {
Certificate string `json:"certificate"`
Usage []string `json:"usages"`
ExpiryString string `json:"expiry"`
}

View File

@ -1,249 +0,0 @@
// Package initca contains code to initialise a certificate authority,
// generating a new root key and certificate.
package initca
import (
"crypto"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"time"
"github.com/cloudflare/cfssl/config"
"github.com/cloudflare/cfssl/csr"
cferr "github.com/cloudflare/cfssl/errors"
"github.com/cloudflare/cfssl/helpers"
"github.com/cloudflare/cfssl/log"
"github.com/cloudflare/cfssl/signer"
"github.com/cloudflare/cfssl/signer/local"
)
// validator contains the default validation logic for certificate
// authority certificates. The only requirement here is that the
// certificate have a non-empty subject field.
func validator(req *csr.CertificateRequest) error {
if req.CN != "" {
return nil
}
if len(req.Names) == 0 {
return cferr.Wrap(cferr.PolicyError, cferr.InvalidRequest, errors.New("missing subject information"))
}
for i := range req.Names {
if csr.IsNameEmpty(req.Names[i]) {
return cferr.Wrap(cferr.PolicyError, cferr.InvalidRequest, errors.New("missing subject information"))
}
}
return nil
}
// New creates a new root certificate from the certificate request.
func New(req *csr.CertificateRequest) (cert, csrPEM, key []byte, err error) {
policy := CAPolicy()
if req.CA != nil {
if req.CA.Expiry != "" {
policy.Default.ExpiryString = req.CA.Expiry
policy.Default.Expiry, err = time.ParseDuration(req.CA.Expiry)
if err != nil {
return
}
}
if req.CA.Backdate != "" {
policy.Default.Backdate, err = time.ParseDuration(req.CA.Backdate)
if err != nil {
return
}
}
policy.Default.CAConstraint.MaxPathLen = req.CA.PathLength
if req.CA.PathLength != 0 && req.CA.PathLenZero {
log.Infof("ignore invalid 'pathlenzero' value")
} else {
policy.Default.CAConstraint.MaxPathLenZero = req.CA.PathLenZero
}
}
g := &csr.Generator{Validator: validator}
csrPEM, key, err = g.ProcessRequest(req)
if err != nil {
log.Errorf("failed to process request: %v", err)
key = nil
return
}
priv, err := helpers.ParsePrivateKeyPEM(key)
if err != nil {
log.Errorf("failed to parse private key: %v", err)
return
}
s, err := local.NewSigner(priv, nil, signer.DefaultSigAlgo(priv), policy)
if err != nil {
log.Errorf("failed to create signer: %v", err)
return
}
signReq := signer.SignRequest{Hosts: req.Hosts, Request: string(csrPEM)}
cert, err = s.Sign(signReq)
return
}
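For reference, a hedged standard-library sketch of the kind of self-signed root that New produces, without the CFSSL plumbing (subject, curve and lifetime are illustrative; the ~5-year expiry mirrors the CAPolicy default further down):

```go
package main

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"math/big"
	"os"
	"time"
)

func main() {
	// Illustrative root key; CFSSL would derive this from the CSR's key request.
	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		panic(err)
	}
	tmpl := &x509.Certificate{
		SerialNumber:          big.NewInt(1),
		Subject:               pkix.Name{CommonName: "Example Root CA"},
		NotBefore:             time.Now().Add(-5 * time.Minute),
		NotAfter:              time.Now().Add(5 * 8760 * time.Hour), // 43800h, matching the CAPolicy default
		KeyUsage:              x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
		BasicConstraintsValid: true,
		IsCA:                  true,
	}
	// Self-signed: the template is both subject and issuer.
	der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &priv.PublicKey, priv)
	if err != nil {
		panic(err)
	}
	pem.Encode(os.Stdout, &pem.Block{Type: "CERTIFICATE", Bytes: der})
}
```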
// NewFromPEM creates a new root certificate from the key file passed in.
func NewFromPEM(req *csr.CertificateRequest, keyFile string) (cert, csrPEM []byte, err error) {
privData, err := helpers.ReadBytes(keyFile)
if err != nil {
return nil, nil, err
}
priv, err := helpers.ParsePrivateKeyPEM(privData)
if err != nil {
return nil, nil, err
}
return NewFromSigner(req, priv)
}
// RenewFromPEM re-creates a root certificate from the CA cert and key
// files. The resulting root certificate will have the input CA certificate
// as the template and the same expiry length. E.g. if the existing CA
// is valid for a year from Jan 01 2015 to Jan 01 2016, the renewed certificate
// will be valid from now and will expire in one year as well.
func RenewFromPEM(caFile, keyFile string) ([]byte, error) {
caBytes, err := helpers.ReadBytes(caFile)
if err != nil {
return nil, err
}
ca, err := helpers.ParseCertificatePEM(caBytes)
if err != nil {
return nil, err
}
keyBytes, err := helpers.ReadBytes(keyFile)
if err != nil {
return nil, err
}
key, err := helpers.ParsePrivateKeyPEM(keyBytes)
if err != nil {
return nil, err
}
return RenewFromSigner(ca, key)
}
// NewFromSigner creates a new root certificate from a crypto.Signer.
func NewFromSigner(req *csr.CertificateRequest, priv crypto.Signer) (cert, csrPEM []byte, err error) {
policy := CAPolicy()
if req.CA != nil {
if req.CA.Expiry != "" {
policy.Default.ExpiryString = req.CA.Expiry
policy.Default.Expiry, err = time.ParseDuration(req.CA.Expiry)
if err != nil {
return nil, nil, err
}
}
policy.Default.CAConstraint.MaxPathLen = req.CA.PathLength
if req.CA.PathLength != 0 && req.CA.PathLenZero {
log.Infof("ignore invalid 'pathlenzero' value")
} else {
policy.Default.CAConstraint.MaxPathLenZero = req.CA.PathLenZero
}
}
csrPEM, err = csr.Generate(priv, req)
if err != nil {
return nil, nil, err
}
s, err := local.NewSigner(priv, nil, signer.DefaultSigAlgo(priv), policy)
if err != nil {
log.Errorf("failed to create signer: %v", err)
return
}
signReq := signer.SignRequest{Request: string(csrPEM)}
cert, err = s.Sign(signReq)
return
}
// RenewFromSigner re-creates a root certificate from the CA cert and crypto.Signer.
// The resulting root certificate will have the CA certificate
// as the template and the same expiry length. E.g. if the existing CA
// is valid for a year from Jan 01 2015 to Jan 01 2016, the renewed certificate
// will be valid from now and will expire in one year as well.
func RenewFromSigner(ca *x509.Certificate, priv crypto.Signer) ([]byte, error) {
if !ca.IsCA {
return nil, errors.New("input certificate is not a CA cert")
}
// matching certificate public key vs private key
switch {
case ca.PublicKeyAlgorithm == x509.RSA:
var rsaPublicKey *rsa.PublicKey
var ok bool
if rsaPublicKey, ok = priv.Public().(*rsa.PublicKey); !ok {
return nil, cferr.New(cferr.PrivateKeyError, cferr.KeyMismatch)
}
if ca.PublicKey.(*rsa.PublicKey).N.Cmp(rsaPublicKey.N) != 0 {
return nil, cferr.New(cferr.PrivateKeyError, cferr.KeyMismatch)
}
case ca.PublicKeyAlgorithm == x509.ECDSA:
var ecdsaPublicKey *ecdsa.PublicKey
var ok bool
if ecdsaPublicKey, ok = priv.Public().(*ecdsa.PublicKey); !ok {
return nil, cferr.New(cferr.PrivateKeyError, cferr.KeyMismatch)
}
if ca.PublicKey.(*ecdsa.PublicKey).X.Cmp(ecdsaPublicKey.X) != 0 {
return nil, cferr.New(cferr.PrivateKeyError, cferr.KeyMismatch)
}
default:
return nil, cferr.New(cferr.PrivateKeyError, cferr.NotRSAOrECC)
}
req := csr.ExtractCertificateRequest(ca)
cert, _, err := NewFromSigner(req, priv)
return cert, err
}
// CAPolicy contains the CA issuing policy as default policy.
var CAPolicy = func() *config.Signing {
return &config.Signing{
Default: &config.SigningProfile{
Usage: []string{"cert sign", "crl sign"},
ExpiryString: "43800h",
Expiry: 5 * helpers.OneYear,
CAConstraint: config.CAConstraint{IsCA: true},
},
}
}
// Update copies the CA certificate, updates the NotBefore and
// NotAfter fields, and then re-signs the certificate.
func Update(ca *x509.Certificate, priv crypto.Signer) (cert []byte, err error) {
copy, err := x509.ParseCertificate(ca.Raw)
if err != nil {
return
}
validity := ca.NotAfter.Sub(ca.NotBefore)
copy.NotBefore = time.Now().Round(time.Minute).Add(-5 * time.Minute)
copy.NotAfter = copy.NotBefore.Add(validity)
cert, err = x509.CreateCertificate(rand.Reader, copy, copy, priv.Public(), priv)
if err != nil {
return
}
cert = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert})
return
}

View File

@ -1,162 +0,0 @@
// Package log implements a wrapper around the Go standard library's
// logging package. Clients should set the current log level; only
// messages below that level will actually be logged. For example, if
// Level is set to LevelWarning, only log messages at the Warning,
// Error, and Critical levels will be logged.
package log
import (
"fmt"
"log"
"os"
)
// The following constants represent logging levels in increasing levels of seriousness.
const (
// LevelDebug is the log level for Debug statements.
LevelDebug = iota
// LevelInfo is the log level for Info statements.
LevelInfo
// LevelWarning is the log level for Warning statements.
LevelWarning
// LevelError is the log level for Error statements.
LevelError
// LevelCritical is the log level for Critical statements.
LevelCritical
// LevelFatal is the log level for Fatal statements.
LevelFatal
)
var levelPrefix = [...]string{
LevelDebug: "DEBUG",
LevelInfo: "INFO",
LevelWarning: "WARNING",
LevelError: "ERROR",
LevelCritical: "CRITICAL",
LevelFatal: "FATAL",
}
// Level stores the current logging level.
var Level = LevelInfo
// SyslogWriter specifies the necessary methods for an alternate output
// destination passed in via SetLogger.
//
// SyslogWriter is satisfied by *syslog.Writer.
type SyslogWriter interface {
Debug(string)
Info(string)
Warning(string)
Err(string)
Crit(string)
Emerg(string)
}
// syslogWriter stores the SetLogger() parameter.
var syslogWriter SyslogWriter
// SetLogger sets the logger used for output by this package.
// A *syslog.Writer is a good choice for the logger parameter.
// Call with a nil parameter to revert to default behavior.
func SetLogger(logger SyslogWriter) {
syslogWriter = logger
}
func print(l int, msg string) {
if l >= Level {
if syslogWriter != nil {
switch l {
case LevelDebug:
syslogWriter.Debug(msg)
case LevelInfo:
syslogWriter.Info(msg)
case LevelWarning:
syslogWriter.Warning(msg)
case LevelError:
syslogWriter.Err(msg)
case LevelCritical:
syslogWriter.Crit(msg)
case LevelFatal:
syslogWriter.Emerg(msg)
}
} else {
log.Printf("[%s] %s", levelPrefix[l], msg)
}
}
}
func outputf(l int, format string, v []interface{}) {
print(l, fmt.Sprintf(format, v...))
}
func output(l int, v []interface{}) {
print(l, fmt.Sprint(v...))
}
// Fatalf logs a formatted message at the "fatal" level and then exits. The
// arguments are handled in the same manner as fmt.Printf.
func Fatalf(format string, v ...interface{}) {
outputf(LevelFatal, format, v)
os.Exit(1)
}
// Fatal logs its arguments at the "fatal" level and then exits.
func Fatal(v ...interface{}) {
output(LevelFatal, v)
os.Exit(1)
}
// Criticalf logs a formatted message at the "critical" level. The
// arguments are handled in the same manner as fmt.Printf.
func Criticalf(format string, v ...interface{}) {
outputf(LevelCritical, format, v)
}
// Critical logs its arguments at the "critical" level.
func Critical(v ...interface{}) {
output(LevelCritical, v)
}
// Errorf logs a formatted message at the "error" level. The arguments
// are handled in the same manner as fmt.Printf.
func Errorf(format string, v ...interface{}) {
outputf(LevelError, format, v)
}
// Error logs its arguments at the "error" level.
func Error(v ...interface{}) {
output(LevelError, v)
}
// Warningf logs a formatted message at the "warning" level. The
// arguments are handled in the same manner as fmt.Printf.
func Warningf(format string, v ...interface{}) {
outputf(LevelWarning, format, v)
}
// Warning logs its arguments at the "warning" level.
func Warning(v ...interface{}) {
output(LevelWarning, v)
}
// Infof logs a formatted message at the "info" level. The arguments
// are handled in the same manner as fmt.Printf.
func Infof(format string, v ...interface{}) {
outputf(LevelInfo, format, v)
}
// Info logs its arguments at the "info" level.
func Info(v ...interface{}) {
output(LevelInfo, v)
}
// Debugf logs a formatted message at the "debug" level. The arguments
// are handled in the same manner as fmt.Printf.
func Debugf(format string, v ...interface{}) {
outputf(LevelDebug, format, v)
}
// Debug logs its arguments at the "debug" level.
func Debug(v ...interface{}) {
output(LevelDebug, v)
}
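A usage sketch of the package above (import path as seen in the other vendored files):

```go
package main

import "github.com/cloudflare/cfssl/log"

func main() {
	// Only messages at or above the configured level are emitted.
	log.Level = log.LevelDebug

	log.Debugf("loading CA from %s", "ca.pem") // printed, since Level is LevelDebug
	log.Info("signer ready")
	log.Warningf("certificate expires in %d days", 7)

	// Raising the level silences the debug output again.
	log.Level = log.LevelWarning
	log.Debug("this message is now filtered out")
}
```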

View File

@ -1,13 +0,0 @@
// Package config in the ocsp directory provides configuration data for an OCSP
// signer.
package config
import "time"
// Config contains configuration information required to set up an OCSP signer.
type Config struct {
CACertFile string
ResponderCertFile string
KeyFile string
Interval time.Duration
}

View File

@ -1,569 +0,0 @@
// Package local implements certificate signature functionality for CFSSL.
package local
import (
"bytes"
"crypto"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/hex"
"encoding/pem"
"errors"
"io"
"math/big"
"net"
"net/http"
"net/mail"
"net/url"
"os"
"github.com/cloudflare/cfssl/certdb"
"github.com/cloudflare/cfssl/config"
cferr "github.com/cloudflare/cfssl/errors"
"github.com/cloudflare/cfssl/helpers"
"github.com/cloudflare/cfssl/info"
"github.com/cloudflare/cfssl/log"
"github.com/cloudflare/cfssl/signer"
"github.com/google/certificate-transparency-go"
"github.com/google/certificate-transparency-go/client"
"github.com/google/certificate-transparency-go/jsonclient"
"golang.org/x/net/context"
)
// Signer contains a signer that uses the standard library to
// support both ECDSA and RSA CA keys.
type Signer struct {
ca *x509.Certificate
priv crypto.Signer
policy *config.Signing
sigAlgo x509.SignatureAlgorithm
dbAccessor certdb.Accessor
}
// NewSigner creates a new Signer directly from a
// private key and certificate, with optional policy.
func NewSigner(priv crypto.Signer, cert *x509.Certificate, sigAlgo x509.SignatureAlgorithm, policy *config.Signing) (*Signer, error) {
if policy == nil {
policy = &config.Signing{
Profiles: map[string]*config.SigningProfile{},
Default: config.DefaultConfig()}
}
if !policy.Valid() {
return nil, cferr.New(cferr.PolicyError, cferr.InvalidPolicy)
}
return &Signer{
ca: cert,
priv: priv,
sigAlgo: sigAlgo,
policy: policy,
}, nil
}
// NewSignerFromFile generates a new local signer from a caFile
// and a caKey file, both PEM encoded.
func NewSignerFromFile(caFile, caKeyFile string, policy *config.Signing) (*Signer, error) {
log.Debug("Loading CA: ", caFile)
ca, err := helpers.ReadBytes(caFile)
if err != nil {
return nil, err
}
log.Debug("Loading CA key: ", caKeyFile)
cakey, err := helpers.ReadBytes(caKeyFile)
if err != nil {
return nil, cferr.Wrap(cferr.CertificateError, cferr.ReadFailed, err)
}
parsedCa, err := helpers.ParseCertificatePEM(ca)
if err != nil {
return nil, err
}
strPassword := os.Getenv("CFSSL_CA_PK_PASSWORD")
password := []byte(strPassword)
if strPassword == "" {
password = nil
}
priv, err := helpers.ParsePrivateKeyPEMWithPassword(cakey, password)
if err != nil {
log.Debug("Malformed private key %v", err)
return nil, err
}
return NewSigner(priv, parsedCa, signer.DefaultSigAlgo(priv), policy)
}
func (s *Signer) sign(template *x509.Certificate) (cert []byte, err error) {
var initRoot bool
if s.ca == nil {
if !template.IsCA {
err = cferr.New(cferr.PolicyError, cferr.InvalidRequest)
return
}
template.DNSNames = nil
template.EmailAddresses = nil
template.URIs = nil
s.ca = template
initRoot = true
}
derBytes, err := x509.CreateCertificate(rand.Reader, template, s.ca, template.PublicKey, s.priv)
if err != nil {
return nil, cferr.Wrap(cferr.CertificateError, cferr.Unknown, err)
}
if initRoot {
s.ca, err = x509.ParseCertificate(derBytes)
if err != nil {
return nil, cferr.Wrap(cferr.CertificateError, cferr.ParseFailed, err)
}
}
cert = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
log.Infof("signed certificate with serial number %d", template.SerialNumber)
return
}
// replaceSliceIfEmpty replaces the contents of replaced with newContents if
// the slice referenced by replaced is empty
func replaceSliceIfEmpty(replaced, newContents *[]string) {
if len(*replaced) == 0 {
*replaced = *newContents
}
}
// PopulateSubjectFromCSR has functionality similar to Name, except
// it fills the fields of the resulting pkix.Name with req's if the
// subject's corresponding fields are empty
func PopulateSubjectFromCSR(s *signer.Subject, req pkix.Name) pkix.Name {
// if no subject, use req
if s == nil {
return req
}
name := s.Name()
if name.CommonName == "" {
name.CommonName = req.CommonName
}
replaceSliceIfEmpty(&name.Country, &req.Country)
replaceSliceIfEmpty(&name.Province, &req.Province)
replaceSliceIfEmpty(&name.Locality, &req.Locality)
replaceSliceIfEmpty(&name.Organization, &req.Organization)
replaceSliceIfEmpty(&name.OrganizationalUnit, &req.OrganizationalUnit)
if name.SerialNumber == "" {
name.SerialNumber = req.SerialNumber
}
return name
}
// OverrideHosts fills template's IPAddresses, EmailAddresses, DNSNames, and URIs with the
// content of hosts, if it is not nil.
func OverrideHosts(template *x509.Certificate, hosts []string) {
if hosts != nil {
template.IPAddresses = []net.IP{}
template.EmailAddresses = []string{}
template.DNSNames = []string{}
template.URIs = []*url.URL{}
}
for i := range hosts {
if ip := net.ParseIP(hosts[i]); ip != nil {
template.IPAddresses = append(template.IPAddresses, ip)
} else if email, err := mail.ParseAddress(hosts[i]); err == nil && email != nil {
template.EmailAddresses = append(template.EmailAddresses, email.Address)
} else if uri, err := url.ParseRequestURI(hosts[i]); err == nil && uri != nil {
template.URIs = append(template.URIs, uri)
} else {
template.DNSNames = append(template.DNSNames, hosts[i])
}
}
}
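The same classification can be exercised standalone; a standard-library sketch showing how sample SAN entries fall into the four buckets:

```go
package main

import (
	"fmt"
	"net"
	"net/mail"
	"net/url"
)

// classify mirrors the ordering used by OverrideHosts above.
func classify(host string) string {
	if ip := net.ParseIP(host); ip != nil {
		return "IP address"
	}
	if email, err := mail.ParseAddress(host); err == nil && email != nil {
		return "email address"
	}
	if uri, err := url.ParseRequestURI(host); err == nil && uri != nil {
		return "URI"
	}
	return "DNS name"
}

func main() {
	for _, h := range []string{"10.0.0.1", "ops@example.com", "spiffe://cluster/ns/default", "www.example.com"} {
		fmt.Printf("%-30s -> %s\n", h, classify(h))
	}
}
```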
// Sign signs a new certificate based on the PEM-encoded client
// certificate or certificate request with the signing profile,
// specified by profileName.
func (s *Signer) Sign(req signer.SignRequest) (cert []byte, err error) {
profile, err := signer.Profile(s, req.Profile)
if err != nil {
return
}
block, _ := pem.Decode([]byte(req.Request))
if block == nil {
return nil, cferr.New(cferr.CSRError, cferr.DecodeFailed)
}
if block.Type != "NEW CERTIFICATE REQUEST" && block.Type != "CERTIFICATE REQUEST" {
return nil, cferr.Wrap(cferr.CSRError,
cferr.BadRequest, errors.New("not a csr"))
}
csrTemplate, err := signer.ParseCertificateRequest(s, block.Bytes)
if err != nil {
return nil, err
}
// Copy out only the fields from the CSR authorized by policy.
safeTemplate := x509.Certificate{}
// If the profile contains no explicit whitelist, assume that all fields
// should be copied from the CSR.
if profile.CSRWhitelist == nil {
safeTemplate = *csrTemplate
} else {
if profile.CSRWhitelist.Subject {
safeTemplate.Subject = csrTemplate.Subject
}
if profile.CSRWhitelist.PublicKeyAlgorithm {
safeTemplate.PublicKeyAlgorithm = csrTemplate.PublicKeyAlgorithm
}
if profile.CSRWhitelist.PublicKey {
safeTemplate.PublicKey = csrTemplate.PublicKey
}
if profile.CSRWhitelist.SignatureAlgorithm {
safeTemplate.SignatureAlgorithm = csrTemplate.SignatureAlgorithm
}
if profile.CSRWhitelist.DNSNames {
safeTemplate.DNSNames = csrTemplate.DNSNames
}
if profile.CSRWhitelist.IPAddresses {
safeTemplate.IPAddresses = csrTemplate.IPAddresses
}
if profile.CSRWhitelist.EmailAddresses {
safeTemplate.EmailAddresses = csrTemplate.EmailAddresses
}
if profile.CSRWhitelist.URIs {
safeTemplate.URIs = csrTemplate.URIs
}
}
if req.CRLOverride != "" {
safeTemplate.CRLDistributionPoints = []string{req.CRLOverride}
}
if safeTemplate.IsCA {
if !profile.CAConstraint.IsCA {
log.Error("local signer policy disallows issuing CA certificate")
return nil, cferr.New(cferr.PolicyError, cferr.InvalidRequest)
}
if s.ca != nil && s.ca.MaxPathLen > 0 {
if safeTemplate.MaxPathLen >= s.ca.MaxPathLen {
log.Error("local signer certificate disallows CA MaxPathLen extending")
// do not sign a cert with pathlen > current
return nil, cferr.New(cferr.PolicyError, cferr.InvalidRequest)
}
} else if s.ca != nil && s.ca.MaxPathLen == 0 && s.ca.MaxPathLenZero {
log.Error("local signer certificate disallows issuing CA certificate")
// signer has pathlen of 0, do not sign more intermediate CAs
return nil, cferr.New(cferr.PolicyError, cferr.InvalidRequest)
}
}
OverrideHosts(&safeTemplate, req.Hosts)
safeTemplate.Subject = PopulateSubjectFromCSR(req.Subject, safeTemplate.Subject)
// If there is a whitelist, ensure that both the Common Name and SAN DNSNames match
if profile.NameWhitelist != nil {
if safeTemplate.Subject.CommonName != "" {
if profile.NameWhitelist.Find([]byte(safeTemplate.Subject.CommonName)) == nil {
return nil, cferr.New(cferr.PolicyError, cferr.UnmatchedWhitelist)
}
}
for _, name := range safeTemplate.DNSNames {
if profile.NameWhitelist.Find([]byte(name)) == nil {
return nil, cferr.New(cferr.PolicyError, cferr.UnmatchedWhitelist)
}
}
for _, name := range safeTemplate.EmailAddresses {
if profile.NameWhitelist.Find([]byte(name)) == nil {
return nil, cferr.New(cferr.PolicyError, cferr.UnmatchedWhitelist)
}
}
for _, name := range safeTemplate.URIs {
if profile.NameWhitelist.Find([]byte(name.String())) == nil {
return nil, cferr.New(cferr.PolicyError, cferr.UnmatchedWhitelist)
}
}
}
if profile.ClientProvidesSerialNumbers {
if req.Serial == nil {
return nil, cferr.New(cferr.CertificateError, cferr.MissingSerial)
}
safeTemplate.SerialNumber = req.Serial
} else {
// RFC 5280 4.1.2.2:
// Certificate users MUST be able to handle serialNumber
// values up to 20 octets. Conforming CAs MUST NOT use
// serialNumber values longer than 20 octets.
//
// If CFSSL is providing the serial numbers, it makes
// sense to use the max supported size.
serialNumber := make([]byte, 20)
_, err = io.ReadFull(rand.Reader, serialNumber)
if err != nil {
return nil, cferr.Wrap(cferr.CertificateError, cferr.Unknown, err)
}
// SetBytes interprets buf as the bytes of a big-endian
// unsigned integer. The leading byte should be masked
// off to ensure it isn't negative.
serialNumber[0] &= 0x7F
safeTemplate.SerialNumber = new(big.Int).SetBytes(serialNumber)
}
if len(req.Extensions) > 0 {
for _, ext := range req.Extensions {
oid := asn1.ObjectIdentifier(ext.ID)
if !profile.ExtensionWhitelist[oid.String()] {
return nil, cferr.New(cferr.CertificateError, cferr.InvalidRequest)
}
rawValue, err := hex.DecodeString(ext.Value)
if err != nil {
return nil, cferr.Wrap(cferr.CertificateError, cferr.InvalidRequest, err)
}
safeTemplate.ExtraExtensions = append(safeTemplate.ExtraExtensions, pkix.Extension{
Id: oid,
Critical: ext.Critical,
Value: rawValue,
})
}
}
var distPoints = safeTemplate.CRLDistributionPoints
err = signer.FillTemplate(&safeTemplate, s.policy.Default, profile, req.NotBefore, req.NotAfter)
if err != nil {
return nil, err
}
if distPoints != nil && len(distPoints) > 0 {
safeTemplate.CRLDistributionPoints = distPoints
}
var certTBS = safeTemplate
if len(profile.CTLogServers) > 0 || req.ReturnPrecert {
// Add a poison extension which prevents validation
var poisonExtension = pkix.Extension{Id: signer.CTPoisonOID, Critical: true, Value: []byte{0x05, 0x00}}
var poisonedPreCert = certTBS
poisonedPreCert.ExtraExtensions = append(safeTemplate.ExtraExtensions, poisonExtension)
cert, err = s.sign(&poisonedPreCert)
if err != nil {
return
}
if req.ReturnPrecert {
return cert, nil
}
derCert, _ := pem.Decode(cert)
prechain := []ct.ASN1Cert{{Data: derCert.Bytes}, {Data: s.ca.Raw}}
var sctList []ct.SignedCertificateTimestamp
for _, server := range profile.CTLogServers {
log.Infof("submitting poisoned precertificate to %s", server)
ctclient, err := client.New(server, nil, jsonclient.Options{})
if err != nil {
return nil, cferr.Wrap(cferr.CTError, cferr.PrecertSubmissionFailed, err)
}
var resp *ct.SignedCertificateTimestamp
ctx := context.Background()
resp, err = ctclient.AddPreChain(ctx, prechain)
if err != nil {
return nil, cferr.Wrap(cferr.CTError, cferr.PrecertSubmissionFailed, err)
}
sctList = append(sctList, *resp)
}
var serializedSCTList []byte
serializedSCTList, err = helpers.SerializeSCTList(sctList)
if err != nil {
return nil, cferr.Wrap(cferr.CTError, cferr.Unknown, err)
}
// Serialize again as an octet string before embedding
serializedSCTList, err = asn1.Marshal(serializedSCTList)
if err != nil {
return nil, cferr.Wrap(cferr.CTError, cferr.Unknown, err)
}
var SCTListExtension = pkix.Extension{Id: signer.SCTListOID, Critical: false, Value: serializedSCTList}
certTBS.ExtraExtensions = append(certTBS.ExtraExtensions, SCTListExtension)
}
var signedCert []byte
signedCert, err = s.sign(&certTBS)
if err != nil {
return nil, err
}
// Get the AKI from signedCert. This is required to support Go 1.9+.
// In prior versions of Go, x509.CreateCertificate updated the
// AuthorityKeyId of certTBS.
parsedCert, _ := helpers.ParseCertificatePEM(signedCert)
if s.dbAccessor != nil {
var certRecord = certdb.CertificateRecord{
Serial: certTBS.SerialNumber.String(),
// this relies on the specific behavior of x509.CreateCertificate
// which sets the AuthorityKeyId from the signer's SubjectKeyId
AKI: hex.EncodeToString(parsedCert.AuthorityKeyId),
CALabel: req.Label,
Status: "good",
Expiry: certTBS.NotAfter,
PEM: string(signedCert),
}
err = s.dbAccessor.InsertCertificate(certRecord)
if err != nil {
return nil, err
}
log.Debug("saved certificate with serial number ", certTBS.SerialNumber)
}
return signedCert, nil
}
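A standalone sketch of the serial-number generation described in the RFC 5280 comment inside Sign above: 20 random octets with the top bit cleared so the resulting integer is non-negative:

```go
package main

import (
	"crypto/rand"
	"fmt"
	"io"
	"math/big"
)

func main() {
	serial := make([]byte, 20) // RFC 5280 allows serial numbers up to 20 octets
	if _, err := io.ReadFull(rand.Reader, serial); err != nil {
		panic(err)
	}
	serial[0] &= 0x7F // clear the top bit so SetBytes yields a non-negative value
	n := new(big.Int).SetBytes(serial)
	fmt.Printf("serial: %x (%d bits)\n", n, n.BitLen())
}
```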
// SignFromPrecert creates and signs a certificate from an existing precertificate
// that was previously signed by Signer.ca and inserts the provided SCTs into the
// new certificate. The resulting certificate will be an exact copy of the precert
// except for the removal of the poison extension and the addition of the SCT list
// extension. SignFromPrecert does not verify that the contents of the certificate
// still match the signing profile of the signer; it only requires that the precert
// was previously signed by the Signer's CA.
func (s *Signer) SignFromPrecert(precert *x509.Certificate, scts []ct.SignedCertificateTimestamp) ([]byte, error) {
// Verify certificate was signed by s.ca
if err := precert.CheckSignatureFrom(s.ca); err != nil {
return nil, err
}
// Verify certificate is a precert
isPrecert := false
poisonIndex := 0
for i, ext := range precert.Extensions {
if ext.Id.Equal(signer.CTPoisonOID) {
if !ext.Critical {
return nil, cferr.New(cferr.CTError, cferr.PrecertInvalidPoison)
}
// Check extension contains ASN.1 NULL
if bytes.Compare(ext.Value, []byte{0x05, 0x00}) != 0 {
return nil, cferr.New(cferr.CTError, cferr.PrecertInvalidPoison)
}
isPrecert = true
poisonIndex = i
break
}
}
if !isPrecert {
return nil, cferr.New(cferr.CTError, cferr.PrecertMissingPoison)
}
// Serialize SCTs into list format and create extension
serializedList, err := helpers.SerializeSCTList(scts)
if err != nil {
return nil, err
}
// Serialize again as an octet string before embedding
serializedList, err = asn1.Marshal(serializedList)
if err != nil {
return nil, cferr.Wrap(cferr.CTError, cferr.Unknown, err)
}
sctExt := pkix.Extension{Id: signer.SCTListOID, Critical: false, Value: serializedList}
// Create the new tbsCert from precert. Do explicit copies of any slices so that we don't
// use memory that may be altered by us or the caller at a later stage.
tbsCert := x509.Certificate{
SignatureAlgorithm: precert.SignatureAlgorithm,
PublicKeyAlgorithm: precert.PublicKeyAlgorithm,
PublicKey: precert.PublicKey,
Version: precert.Version,
SerialNumber: precert.SerialNumber,
Issuer: precert.Issuer,
Subject: precert.Subject,
NotBefore: precert.NotBefore,
NotAfter: precert.NotAfter,
KeyUsage: precert.KeyUsage,
BasicConstraintsValid: precert.BasicConstraintsValid,
IsCA: precert.IsCA,
MaxPathLen: precert.MaxPathLen,
MaxPathLenZero: precert.MaxPathLenZero,
PermittedDNSDomainsCritical: precert.PermittedDNSDomainsCritical,
}
if len(precert.Extensions) > 0 {
tbsCert.ExtraExtensions = make([]pkix.Extension, len(precert.Extensions))
copy(tbsCert.ExtraExtensions, precert.Extensions)
}
// Remove the poison extension from ExtraExtensions
tbsCert.ExtraExtensions = append(tbsCert.ExtraExtensions[:poisonIndex], tbsCert.ExtraExtensions[poisonIndex+1:]...)
// Insert the SCT list extension
tbsCert.ExtraExtensions = append(tbsCert.ExtraExtensions, sctExt)
// Sign the tbsCert
return s.sign(&tbsCert)
}
// Info returns a populated info.Resp struct or an error.
func (s *Signer) Info(req info.Req) (resp *info.Resp, err error) {
cert, err := s.Certificate(req.Label, req.Profile)
if err != nil {
return
}
profile, err := signer.Profile(s, req.Profile)
if err != nil {
return
}
resp = new(info.Resp)
if cert.Raw != nil {
resp.Certificate = string(bytes.TrimSpace(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})))
}
resp.Usage = profile.Usage
resp.ExpiryString = profile.ExpiryString
return
}
// SigAlgo returns the signer's signature algorithm.
func (s *Signer) SigAlgo() x509.SignatureAlgorithm {
return s.sigAlgo
}
// Certificate returns the signer's certificate.
func (s *Signer) Certificate(label, profile string) (*x509.Certificate, error) {
cert := *s.ca
return &cert, nil
}
// SetPolicy sets the signer's signature policy.
func (s *Signer) SetPolicy(policy *config.Signing) {
s.policy = policy
}
// SetDBAccessor sets the signer's cert db accessor
func (s *Signer) SetDBAccessor(dba certdb.Accessor) {
s.dbAccessor = dba
}
// GetDBAccessor returns the signer's cert db accessor
func (s *Signer) GetDBAccessor() certdb.Accessor {
return s.dbAccessor
}
// SetReqModifier does nothing for local
func (s *Signer) SetReqModifier(func(*http.Request, []byte)) {
// noop
}
// Policy returns the signer's policy.
func (s *Signer) Policy() *config.Signing {
return s.policy
}

View File

@ -1,438 +0,0 @@
// Package signer implements certificate signature functionality for CFSSL.
package signer
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"crypto/sha1"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"errors"
"math/big"
"net/http"
"strings"
"time"
"github.com/cloudflare/cfssl/certdb"
"github.com/cloudflare/cfssl/config"
"github.com/cloudflare/cfssl/csr"
cferr "github.com/cloudflare/cfssl/errors"
"github.com/cloudflare/cfssl/info"
)
// Subject contains the information that should be used to override the
// subject information when signing a certificate.
type Subject struct {
CN string
Names []csr.Name `json:"names"`
SerialNumber string
}
// Extension represents a raw extension to be included in the certificate. The
// "value" field must be hex encoded.
type Extension struct {
ID config.OID `json:"id"`
Critical bool `json:"critical"`
Value string `json:"value"`
}
// SignRequest stores a signature request, which contains the hostname,
// the CSR, optional subject information, and the signature profile.
//
// Extensions provided in the signRequest are copied into the certificate, as
// long as they are in the ExtensionWhitelist for the signer's policy.
// Extensions requested in the CSR are ignored, except for those processed by
// ParseCertificateRequest (mainly subjectAltName).
type SignRequest struct {
Hosts []string `json:"hosts"`
Request string `json:"certificate_request"`
Subject *Subject `json:"subject,omitempty"`
Profile string `json:"profile"`
CRLOverride string `json:"crl_override"`
Label string `json:"label"`
Serial *big.Int `json:"serial,omitempty"`
Extensions []Extension `json:"extensions,omitempty"`
// If provided, NotBefore will be used without modification (except
// for canonicalization) as the value of the notBefore field of the
// certificate. In particular no backdating adjustment will be made
// when NotBefore is provided.
NotBefore time.Time
// If provided, NotAfter will be used without modification (except
// for canonicalization) as the value of the notAfter field of the
// certificate.
NotAfter time.Time
// If ReturnPrecert is true a certificate with the CT poison extension
// will be returned from the Signer instead of attempting to retrieve
// SCTs and populate the tbsCert with them itself. This precert can then
// be passed to SignFromPrecert with the SCTs in order to create a
// valid certificate.
ReturnPrecert bool
}
// appendIf appends to a if s is not an empty string.
func appendIf(s string, a *[]string) {
if s != "" {
*a = append(*a, s)
}
}
// Name returns the PKIX name for the subject.
func (s *Subject) Name() pkix.Name {
var name pkix.Name
name.CommonName = s.CN
for _, n := range s.Names {
appendIf(n.C, &name.Country)
appendIf(n.ST, &name.Province)
appendIf(n.L, &name.Locality)
appendIf(n.O, &name.Organization)
appendIf(n.OU, &name.OrganizationalUnit)
}
name.SerialNumber = s.SerialNumber
return name
}
// SplitHosts takes a comma-separated list of hosts and returns a slice
// with the individual hosts.
func SplitHosts(hostList string) []string {
if hostList == "" {
return nil
}
return strings.Split(hostList, ",")
}
// A Signer contains a CA's certificate and private key for signing
// certificates, a Signing policy to refer to and a SignatureAlgorithm.
type Signer interface {
Info(info.Req) (*info.Resp, error)
Policy() *config.Signing
SetDBAccessor(certdb.Accessor)
GetDBAccessor() certdb.Accessor
SetPolicy(*config.Signing)
SigAlgo() x509.SignatureAlgorithm
Sign(req SignRequest) (cert []byte, err error)
SetReqModifier(func(*http.Request, []byte))
}
// Profile gets the specific profile from the signer
func Profile(s Signer, profile string) (*config.SigningProfile, error) {
var p *config.SigningProfile
policy := s.Policy()
if policy != nil && policy.Profiles != nil && profile != "" {
p = policy.Profiles[profile]
}
if p == nil && policy != nil {
p = policy.Default
}
if p == nil {
return nil, cferr.Wrap(cferr.APIClientError, cferr.ClientHTTPError, errors.New("profile must not be nil"))
}
return p, nil
}
// DefaultSigAlgo returns an appropriate X.509 signature algorithm given
// the CA's private key.
func DefaultSigAlgo(priv crypto.Signer) x509.SignatureAlgorithm {
pub := priv.Public()
switch pub := pub.(type) {
case *rsa.PublicKey:
keySize := pub.N.BitLen()
switch {
case keySize >= 4096:
return x509.SHA512WithRSA
case keySize >= 3072:
return x509.SHA384WithRSA
case keySize >= 2048:
return x509.SHA256WithRSA
default:
return x509.SHA1WithRSA
}
case *ecdsa.PublicKey:
switch pub.Curve {
case elliptic.P256():
return x509.ECDSAWithSHA256
case elliptic.P384():
return x509.ECDSAWithSHA384
case elliptic.P521():
return x509.ECDSAWithSHA512
default:
return x509.ECDSAWithSHA1
}
default:
return x509.UnknownSignatureAlgorithm
}
}
// ParseCertificateRequest takes an incoming certificate request and
// builds a certificate template from it.
func ParseCertificateRequest(s Signer, csrBytes []byte) (template *x509.Certificate, err error) {
csrv, err := x509.ParseCertificateRequest(csrBytes)
if err != nil {
err = cferr.Wrap(cferr.CSRError, cferr.ParseFailed, err)
return
}
err = csrv.CheckSignature()
if err != nil {
err = cferr.Wrap(cferr.CSRError, cferr.KeyMismatch, err)
return
}
template = &x509.Certificate{
Subject: csrv.Subject,
PublicKeyAlgorithm: csrv.PublicKeyAlgorithm,
PublicKey: csrv.PublicKey,
SignatureAlgorithm: s.SigAlgo(),
DNSNames: csrv.DNSNames,
IPAddresses: csrv.IPAddresses,
EmailAddresses: csrv.EmailAddresses,
URIs: csrv.URIs,
}
for _, val := range csrv.Extensions {
// Check the CSR for the X.509 BasicConstraints (RFC 5280, 4.2.1.9)
// extension and append to template if necessary
if val.Id.Equal(asn1.ObjectIdentifier{2, 5, 29, 19}) {
var constraints csr.BasicConstraints
var rest []byte
if rest, err = asn1.Unmarshal(val.Value, &constraints); err != nil {
return nil, cferr.Wrap(cferr.CSRError, cferr.ParseFailed, err)
} else if len(rest) != 0 {
return nil, cferr.Wrap(cferr.CSRError, cferr.ParseFailed, errors.New("x509: trailing data after X.509 BasicConstraints"))
}
template.BasicConstraintsValid = true
template.IsCA = constraints.IsCA
template.MaxPathLen = constraints.MaxPathLen
template.MaxPathLenZero = template.MaxPathLen == 0
}
}
return
}
type subjectPublicKeyInfo struct {
Algorithm pkix.AlgorithmIdentifier
SubjectPublicKey asn1.BitString
}
// ComputeSKI derives an SKI from the certificate's public key in a
// standard manner. This is done by computing the SHA-1 digest of the
// SubjectPublicKeyInfo component of the certificate.
func ComputeSKI(template *x509.Certificate) ([]byte, error) {
pub := template.PublicKey
encodedPub, err := x509.MarshalPKIXPublicKey(pub)
if err != nil {
return nil, err
}
var subPKI subjectPublicKeyInfo
_, err = asn1.Unmarshal(encodedPub, &subPKI)
if err != nil {
return nil, err
}
pubHash := sha1.Sum(subPKI.SubjectPublicKey.Bytes)
return pubHash[:], nil
}
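A runnable standard-library sketch of the same SKI derivation for a freshly generated ECDSA key:

```go
package main

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/sha1"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/asn1"
	"encoding/hex"
	"fmt"
)

type subjectPublicKeyInfo struct {
	Algorithm        pkix.AlgorithmIdentifier
	SubjectPublicKey asn1.BitString
}

func main() {
	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		panic(err)
	}
	// DER-encode the public key, then peel off the SubjectPublicKeyInfo wrapper.
	encodedPub, err := x509.MarshalPKIXPublicKey(&priv.PublicKey)
	if err != nil {
		panic(err)
	}
	var spki subjectPublicKeyInfo
	if _, err := asn1.Unmarshal(encodedPub, &spki); err != nil {
		panic(err)
	}
	// The SKI is the SHA-1 digest of the raw public key bit string.
	ski := sha1.Sum(spki.SubjectPublicKey.Bytes)
	fmt.Println("SKI:", hex.EncodeToString(ski[:]))
}
```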
// FillTemplate is a utility function that tries to load as much of
// the certificate template as possible from the profiles and current
// template. It fills in the key uses, expiration, revocation URLs
// and SKI.
func FillTemplate(template *x509.Certificate, defaultProfile, profile *config.SigningProfile, notBefore time.Time, notAfter time.Time) error {
ski, err := ComputeSKI(template)
if err != nil {
return err
}
var (
eku []x509.ExtKeyUsage
ku x509.KeyUsage
backdate time.Duration
expiry time.Duration
crlURL, ocspURL string
issuerURL = profile.IssuerURL
)
// The third value returned from Usages is a list of unknown key usages.
// This should be used when validating the profile at load, and isn't used
// here.
ku, eku, _ = profile.Usages()
if profile.IssuerURL == nil {
issuerURL = defaultProfile.IssuerURL
}
if ku == 0 && len(eku) == 0 {
return cferr.New(cferr.PolicyError, cferr.NoKeyUsages)
}
if expiry = profile.Expiry; expiry == 0 {
expiry = defaultProfile.Expiry
}
if crlURL = profile.CRL; crlURL == "" {
crlURL = defaultProfile.CRL
}
if ocspURL = profile.OCSP; ocspURL == "" {
ocspURL = defaultProfile.OCSP
}
if notBefore.IsZero() {
if !profile.NotBefore.IsZero() {
notBefore = profile.NotBefore
} else {
if backdate = profile.Backdate; backdate == 0 {
backdate = -5 * time.Minute
} else {
backdate = -1 * profile.Backdate
}
notBefore = time.Now().Round(time.Minute).Add(backdate)
}
}
notBefore = notBefore.UTC()
if notAfter.IsZero() {
if !profile.NotAfter.IsZero() {
notAfter = profile.NotAfter
} else {
notAfter = notBefore.Add(expiry)
}
}
notAfter = notAfter.UTC()
template.NotBefore = notBefore
template.NotAfter = notAfter
template.KeyUsage = ku
template.ExtKeyUsage = eku
template.BasicConstraintsValid = true
template.IsCA = profile.CAConstraint.IsCA
if template.IsCA {
template.MaxPathLen = profile.CAConstraint.MaxPathLen
if template.MaxPathLen == 0 {
template.MaxPathLenZero = profile.CAConstraint.MaxPathLenZero
}
template.DNSNames = nil
template.EmailAddresses = nil
template.URIs = nil
}
template.SubjectKeyId = ski
if ocspURL != "" {
template.OCSPServer = []string{ocspURL}
}
if crlURL != "" {
template.CRLDistributionPoints = []string{crlURL}
}
if len(issuerURL) != 0 {
template.IssuingCertificateURL = issuerURL
}
if len(profile.Policies) != 0 {
err = addPolicies(template, profile.Policies)
if err != nil {
return cferr.Wrap(cferr.PolicyError, cferr.InvalidPolicy, err)
}
}
if profile.OCSPNoCheck {
ocspNoCheckExtension := pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 48, 1, 5},
Critical: false,
Value: []byte{0x05, 0x00},
}
template.ExtraExtensions = append(template.ExtraExtensions, ocspNoCheckExtension)
}
return nil
}
type policyInformation struct {
PolicyIdentifier asn1.ObjectIdentifier
Qualifiers []interface{} `asn1:"tag:optional,omitempty"`
}
type cpsPolicyQualifier struct {
PolicyQualifierID asn1.ObjectIdentifier
Qualifier string `asn1:"tag:optional,ia5"`
}
type userNotice struct {
ExplicitText string `asn1:"tag:optional,utf8"`
}
type userNoticePolicyQualifier struct {
PolicyQualifierID asn1.ObjectIdentifier
Qualifier userNotice
}
var (
// Per https://tools.ietf.org/html/rfc3280.html#page-106, this represents:
// iso(1) identified-organization(3) dod(6) internet(1) security(5)
// mechanisms(5) pkix(7) id-qt(2) id-qt-cps(1)
iDQTCertificationPracticeStatement = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 2, 1}
// iso(1) identified-organization(3) dod(6) internet(1) security(5)
// mechanisms(5) pkix(7) id-qt(2) id-qt-unotice(2)
iDQTUserNotice = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 2, 2}
// CTPoisonOID is the object ID of the critical poison extension for precertificates
// https://tools.ietf.org/html/rfc6962#page-9
CTPoisonOID = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 11129, 2, 4, 3}
// SCTListOID is the object ID for the Signed Certificate Timestamp certificate extension
// https://tools.ietf.org/html/rfc6962#page-14
SCTListOID = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 11129, 2, 4, 2}
)
// addPolicies adds Certificate Policies and optional Policy Qualifiers to a
// certificate, based on the input config. Go's x509 library allows setting
// Certificate Policies easily, but does not support nested Policy Qualifiers
// under those policies. So we need to construct the ASN.1 structure ourselves.
func addPolicies(template *x509.Certificate, policies []config.CertificatePolicy) error {
asn1PolicyList := []policyInformation{}
for _, policy := range policies {
pi := policyInformation{
// The PolicyIdentifier is an OID assigned to a given issuer.
PolicyIdentifier: asn1.ObjectIdentifier(policy.ID),
}
for _, qualifier := range policy.Qualifiers {
switch qualifier.Type {
case "id-qt-unotice":
pi.Qualifiers = append(pi.Qualifiers,
userNoticePolicyQualifier{
PolicyQualifierID: iDQTUserNotice,
Qualifier: userNotice{
ExplicitText: qualifier.Value,
},
})
case "id-qt-cps":
pi.Qualifiers = append(pi.Qualifiers,
cpsPolicyQualifier{
PolicyQualifierID: iDQTCertificationPracticeStatement,
Qualifier: qualifier.Value,
})
default:
return errors.New("Invalid qualifier type in Policies " + qualifier.Type)
}
}
asn1PolicyList = append(asn1PolicyList, pi)
}
asn1Bytes, err := asn1.Marshal(asn1PolicyList)
if err != nil {
return err
}
template.ExtraExtensions = append(template.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{2, 5, 29, 32},
Critical: false,
Value: asn1Bytes,
})
return nil
}

View File

@ -1 +0,0 @@
examples/examples

View File

@ -1,4 +0,0 @@
language: go
go:
- 1.x

View File

@ -1,16 +0,0 @@
# changes to the go-restful-openapi package
## v1.0.0
- Fix for #19 MapModelTypeNameFunc has incomplete behavior
- prevent array param.Type be overwritten in the else case below (#47)
- Merge paths with existing paths from other webServices (#48)
## v0.11.0
- Register pointer to array/slice of primitives as such rather than as reference to the primitive type definition. (#46)
- Add support for map types using "additional properties" (#44)
## <= v0.10.0
See `git log`.

View File

@ -1,22 +0,0 @@
Copyright (c) 2017 Ernest Micklei
MIT License
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@ -1,26 +0,0 @@
# go-restful-openapi
[![Build Status](https://travis-ci.org/emicklei/go-restful-openapi.png)](https://travis-ci.org/emicklei/go-restful-openapi)
[![GoDoc](https://godoc.org/github.com/emicklei/go-restful-openapi?status.svg)](https://godoc.org/github.com/emicklei/go-restful-openapi)
[openapi](https://www.openapis.org) extension to the go-restful package, targeting [version 2.0](https://github.com/OAI/OpenAPI-Specification)
## The following Go field tags are translated to OpenAPI equivalents
- description
- minimum
- maximum
- optional (if set to "true" then it is not listed in `required`)
- unique
- modelDescription
- type (overrides the Go type String())
- enum
- readOnly
See TestThatExtraTagsAreReadIntoModel for examples.
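A minimal illustrative struct (a sketch, not taken from the package; the type and tag values are placeholders) combining several of these tags:

```go
type Book struct {
	Title  string   `json:"title" description:"title of the book"`
	Pages  int      `json:"pages" minimum:"1" maximum:"5000"`
	Tags   []string `json:"tags" unique:"true"`
	Rating string   `json:"rating" enum:"good|average|poor" optional:"true"`
	ISBN   string   `json:"isbn" readOnly:"true"`
}
```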
## dependencies
- [go-restful](https://github.com/emicklei/go-restful)
- [go-openapi](https://github.com/go-openapi/spec)
© 2018, ernestmicklei.com. MIT License. Contributions welcome.

View File

@ -1,32 +0,0 @@
package restfulspec
import (
"reflect"
restful "github.com/emicklei/go-restful"
"github.com/go-openapi/spec"
)
func buildDefinitions(ws *restful.WebService, cfg Config) (definitions spec.Definitions) {
definitions = spec.Definitions{}
for _, each := range ws.Routes() {
addDefinitionsFromRouteTo(each, cfg, definitions)
}
return
}
func addDefinitionsFromRouteTo(r restful.Route, cfg Config, d spec.Definitions) {
builder := definitionBuilder{Definitions: d, Config: cfg}
if r.ReadSample != nil {
builder.addModel(reflect.TypeOf(r.ReadSample), "")
}
if r.WriteSample != nil {
builder.addModel(reflect.TypeOf(r.WriteSample), "")
}
for _, v := range r.ResponseErrors {
if v.Model == nil {
continue
}
builder.addModel(reflect.TypeOf(v.Model), "")
}
}

View File

@ -1,254 +0,0 @@
package restfulspec
import (
"net/http"
"reflect"
"regexp"
"strconv"
"strings"
restful "github.com/emicklei/go-restful"
"github.com/go-openapi/spec"
)
// KeyOpenAPITags is a Metadata key for a restful Route
const KeyOpenAPITags = "openapi.tags"
func buildPaths(ws *restful.WebService, cfg Config) spec.Paths {
p := spec.Paths{Paths: map[string]spec.PathItem{}}
for _, each := range ws.Routes() {
path, patterns := sanitizePath(each.Path)
existingPathItem, ok := p.Paths[path]
if !ok {
existingPathItem = spec.PathItem{}
}
p.Paths[path] = buildPathItem(ws, each, existingPathItem, patterns, cfg)
}
return p
}
// sanitizePath removes regex expressions from named path params,
// since openapi only supports setting the pattern as a property named "pattern".
// Expressions like "/api/v1/{name:[a-z]}/" are converted to "/api/v1/{name}/".
// The second return value is a map which contains the mapping from the path parameter
// name to the extracted pattern
func sanitizePath(restfulPath string) (string, map[string]string) {
openapiPath := ""
patterns := map[string]string{}
for _, fragment := range strings.Split(restfulPath, "/") {
if fragment == "" {
continue
}
if strings.HasPrefix(fragment, "{") && strings.Contains(fragment, ":") {
split := strings.Split(fragment, ":")
fragment = split[0][1:]
pattern := split[1][:len(split[1])-1]
patterns[fragment] = pattern
fragment = "{" + fragment + "}"
}
openapiPath += "/" + fragment
}
return openapiPath, patterns
}
func buildPathItem(ws *restful.WebService, r restful.Route, existingPathItem spec.PathItem, patterns map[string]string, cfg Config) spec.PathItem {
op := buildOperation(ws, r, patterns, cfg)
switch r.Method {
case "GET":
existingPathItem.Get = op
case "POST":
existingPathItem.Post = op
case "PUT":
existingPathItem.Put = op
case "DELETE":
existingPathItem.Delete = op
case "PATCH":
existingPathItem.Patch = op
case "OPTIONS":
existingPathItem.Options = op
case "HEAD":
existingPathItem.Head = op
}
return existingPathItem
}
func buildOperation(ws *restful.WebService, r restful.Route, patterns map[string]string, cfg Config) *spec.Operation {
o := spec.NewOperation(r.Operation)
o.Description = r.Notes
o.Summary = stripTags(r.Doc)
o.Consumes = r.Consumes
o.Produces = r.Produces
o.Deprecated = r.Deprecated
if r.Metadata != nil {
if tags, ok := r.Metadata[KeyOpenAPITags]; ok {
if tagList, ok := tags.([]string); ok {
o.Tags = tagList
}
}
}
// collect any path parameters
for _, param := range ws.PathParameters() {
o.Parameters = append(o.Parameters, buildParameter(r, param, patterns[param.Data().Name], cfg))
}
// route specific params
for _, each := range r.ParameterDocs {
o.Parameters = append(o.Parameters, buildParameter(r, each, patterns[each.Data().Name], cfg))
}
o.Responses = new(spec.Responses)
props := &o.Responses.ResponsesProps
props.StatusCodeResponses = map[int]spec.Response{}
for k, v := range r.ResponseErrors {
r := buildResponse(v, cfg)
props.StatusCodeResponses[k] = r
if 200 == k { // any 2xx code?
o.Responses.Default = &r
}
}
if len(o.Responses.StatusCodeResponses) == 0 {
o.Responses.StatusCodeResponses[200] = spec.Response{ResponseProps: spec.ResponseProps{Description: http.StatusText(http.StatusOK)}}
}
return o
}
// stringAutoType automatically picks the correct type from an ambiguously typed
// string. Ex. numbers become int, true/false become bool, etc.
func stringAutoType(ambiguous string) interface{} {
if ambiguous == "" {
return nil
}
if parsedInt, err := strconv.ParseInt(ambiguous, 10, 64); err == nil {
return parsedInt
}
if parsedBool, err := strconv.ParseBool(ambiguous); err == nil {
return parsedBool
}
return ambiguous
}
func buildParameter(r restful.Route, restfulParam *restful.Parameter, pattern string, cfg Config) spec.Parameter {
p := spec.Parameter{}
param := restfulParam.Data()
p.In = asParamType(param.Kind)
p.Description = param.Description
p.Name = param.Name
p.Required = param.Required
if param.Kind == restful.PathParameterKind {
p.Pattern = pattern
}
st := reflect.TypeOf(r.ReadSample)
if param.Kind == restful.BodyParameterKind && r.ReadSample != nil && param.DataType == st.String() {
p.Schema = new(spec.Schema)
p.SimpleSchema = spec.SimpleSchema{}
if st.Kind() == reflect.Array || st.Kind() == reflect.Slice {
dataTypeName := keyFrom(st.Elem(), cfg)
p.Schema.Type = []string{"array"}
p.Schema.Items = &spec.SchemaOrArray{
Schema: &spec.Schema{},
}
isPrimitive := isPrimitiveType(dataTypeName)
if isPrimitive {
mapped := jsonSchemaType(dataTypeName)
p.Schema.Items.Schema.Type = []string{mapped}
} else {
p.Schema.Items.Schema.Ref = spec.MustCreateRef("#/definitions/" + dataTypeName)
}
} else {
dataTypeName := keyFrom(st, cfg)
p.Schema.Ref = spec.MustCreateRef("#/definitions/" + dataTypeName)
}
} else {
if param.AllowMultiple {
p.Type = "array"
p.Items = spec.NewItems()
p.Items.Type = param.DataType
p.CollectionFormat = param.CollectionFormat
} else {
p.Type = param.DataType
}
p.Default = stringAutoType(param.DefaultValue)
p.Format = param.DataFormat
}
return p
}
func buildResponse(e restful.ResponseError, cfg Config) (r spec.Response) {
r.Description = e.Message
if e.Model != nil {
st := reflect.TypeOf(e.Model)
if st.Kind() == reflect.Ptr {
// For pointer type, use element type as the key; otherwise we'll
// end up with '#/definitions/*Type' which violates the openapi spec.
st = st.Elem()
}
r.Schema = new(spec.Schema)
if st.Kind() == reflect.Array || st.Kind() == reflect.Slice {
modelName := keyFrom(st.Elem(), cfg)
r.Schema.Type = []string{"array"}
r.Schema.Items = &spec.SchemaOrArray{
Schema: &spec.Schema{},
}
isPrimitive := isPrimitiveType(modelName)
if isPrimitive {
mapped := jsonSchemaType(modelName)
r.Schema.Items.Schema.Type = []string{mapped}
} else {
r.Schema.Items.Schema.Ref = spec.MustCreateRef("#/definitions/" + modelName)
}
} else {
modelName := keyFrom(st, cfg)
if isPrimitiveType(modelName) {
// If the response is a primitive type, then don't reference any definitions.
// Instead, set the schema's "type" to the model name.
r.Schema.AddType(modelName, "")
} else {
modelName := keyFrom(st, cfg)
r.Schema.Ref = spec.MustCreateRef("#/definitions/" + modelName)
}
}
}
return r
}
// stripTags takes a snippet of HTML and returns only the text content.
// For example, `<b>&lt;Hi!&gt;</b> <br>` -> `&lt;Hi!&gt; `.
func stripTags(html string) string {
re := regexp.MustCompile("<[^>]*>")
return re.ReplaceAllString(html, "")
}
func isPrimitiveType(modelName string) bool {
if len(modelName) == 0 {
return false
}
return strings.Contains("uint uint8 uint16 uint32 uint64 int int8 int16 int32 int64 float32 float64 bool string byte rune time.Time", modelName)
}
func jsonSchemaType(modelName string) string {
schemaMap := map[string]string{
"uint": "integer",
"uint8": "integer",
"uint16": "integer",
"uint32": "integer",
"uint64": "integer",
"int": "integer",
"int8": "integer",
"int16": "integer",
"int32": "integer",
"int64": "integer",
"byte": "integer",
"float64": "number",
"float32": "number",
"bool": "boolean",
"time.Time": "string",
}
mapped, ok := schemaMap[modelName]
if !ok {
return modelName // use as is (custom or struct)
}
return mapped
}

View File

@ -1,41 +0,0 @@
package restfulspec
import (
"reflect"
restful "github.com/emicklei/go-restful"
"github.com/go-openapi/spec"
)
// MapSchemaFormatFunc can be used to modify typeName at definition time.
// To use it set the SchemaFormatHandler in the config.
type MapSchemaFormatFunc func(typeName string) string
// MapModelTypeNameFunc can be used to return the desired typeName for a given
// type. It will return false if the default name should be used.
// To use it set the ModelTypeNameHandler in the config.
type MapModelTypeNameFunc func(t reflect.Type) (string, bool)
// PostBuildSwaggerObjectFunc can be used to change the created Swagger Object
// before serving it. To use it set the PostBuildSwaggerObjectHandler in the config.
type PostBuildSwaggerObjectFunc func(s *spec.Swagger)
// Config holds service api metadata.
type Config struct {
// WebServicesURL is a DEPRECATED field; it never had any effect in this package.
WebServicesURL string
// APIPath is the path where the JSON api is available, e.g. /apidocs.json
APIPath string
// api listing is constructed from this list of restful WebServices.
WebServices []*restful.WebService
// [optional] by default CORS (Cross-Origin Resource Sharing) is enabled.
DisableCORS bool
// Top-level API version. Is reflected in the resource listing.
APIVersion string
// [optional] If set, model builder should call this handler to get additional typename-to-swagger-format-field conversion.
SchemaFormatHandler MapSchemaFormatFunc
// [optional] If set, model builder should call this handler to retrieve the name for a given type.
ModelTypeNameHandler MapModelTypeNameFunc
// [optional] If set then call this function with the generated Swagger Object
PostBuildSwaggerObjectHandler PostBuildSwaggerObjectFunc
}

View File

@ -1,493 +0,0 @@
package restfulspec
import (
"encoding/json"
"reflect"
"strings"
"github.com/go-openapi/spec"
)
type definitionBuilder struct {
Definitions spec.Definitions
Config Config
}
// Documented is implemented by types that provide their own Swagger documentation through a SwaggerDoc method.
type Documented interface {
SwaggerDoc() map[string]string
}
// Check if this structure has a method with signature func (<theModel>) SwaggerDoc() map[string]string
// If it exists, retrieve the documentation and overwrite all struct tag descriptions
func getDocFromMethodSwaggerDoc2(model reflect.Type) map[string]string {
if docable, ok := reflect.New(model).Elem().Interface().(Documented); ok {
return docable.SwaggerDoc()
}
return make(map[string]string)
}
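// Illustrative sketch (not part of the original file): a model that documents itself
// via SwaggerDoc. The "" key documents the struct; the other keys document the JSON
// fields. Type and field names are placeholders.
//
//	type User struct {
//		ID   string `json:"id"`
//		Name string `json:"name"`
//	}
//
//	func (u User) SwaggerDoc() map[string]string {
//		return map[string]string{
//			"":     "User represents an account in the system",
//			"id":   "unique identifier",
//			"name": "display name",
//		}
//	}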
// addModelFrom creates and adds a Schema to the builder for the given sample value.
func (b definitionBuilder) addModelFrom(sample interface{}) {
b.addModel(reflect.TypeOf(sample), "")
}
func (b definitionBuilder) addModel(st reflect.Type, nameOverride string) *spec.Schema {
// Turn pointers into simpler types so further checks are
// correct.
if st.Kind() == reflect.Ptr {
st = st.Elem()
}
modelName := keyFrom(st, b.Config)
if nameOverride != "" {
modelName = nameOverride
}
// no models needed for primitive types
if b.isPrimitiveType(modelName) {
return nil
}
// golang encoding/json package says array and slice values encode as
// JSON arrays, except that []byte encodes as a base64-encoded string.
// If we see a []byte here, treat it as a primitive type (string)
// and deal with it in buildArrayTypeProperty.
if (st.Kind() == reflect.Slice || st.Kind() == reflect.Array) &&
st.Elem().Kind() == reflect.Uint8 {
return nil
}
// see if we already have visited this model
if _, ok := b.Definitions[modelName]; ok {
return nil
}
sm := spec.Schema{
SchemaProps: spec.SchemaProps{
Required: []string{},
Properties: map[string]spec.Schema{},
},
}
// reference the model before further initializing (enables recursive structs)
b.Definitions[modelName] = sm
// check for slice or array
if st.Kind() == reflect.Slice || st.Kind() == reflect.Array {
st = st.Elem()
}
// check for structure or primitive type
if st.Kind() != reflect.Struct {
return &sm
}
fullDoc := getDocFromMethodSwaggerDoc2(st)
modelDescriptions := []string{}
for i := 0; i < st.NumField(); i++ {
field := st.Field(i)
jsonName, modelDescription, prop := b.buildProperty(field, &sm, modelName)
if len(modelDescription) > 0 {
modelDescriptions = append(modelDescriptions, modelDescription)
}
// add if not omitted
if len(jsonName) != 0 {
// update description
if fieldDoc, ok := fullDoc[jsonName]; ok {
prop.Description = fieldDoc
}
// update Required
if b.isPropertyRequired(field) {
sm.Required = append(sm.Required, jsonName)
}
sm.Properties[jsonName] = prop
}
}
// We always overwrite documentation if SwaggerDoc method exists
// "" is special for documenting the struct itself
if modelDoc, ok := fullDoc[""]; ok {
sm.Description = modelDoc
} else if len(modelDescriptions) != 0 {
sm.Description = strings.Join(modelDescriptions, "\n")
}
// Needed to pass openapi validation. This field exists for json-schema compatibility,
// but it conflicts with the openapi specification.
// See https://github.com/go-openapi/spec/issues/23 for more context
sm.ID = ""
// update model builder with completed model
b.Definitions[modelName] = sm
return &sm
}
func (b definitionBuilder) isPropertyRequired(field reflect.StructField) bool {
required := true
if optionalTag := field.Tag.Get("optional"); optionalTag == "true" {
return false
}
if jsonTag := field.Tag.Get("json"); jsonTag != "" {
s := strings.Split(jsonTag, ",")
if len(s) > 1 && s[1] == "omitempty" {
return false
}
}
return required
}
func (b definitionBuilder) buildProperty(field reflect.StructField, model *spec.Schema, modelName string) (jsonName, modelDescription string, prop spec.Schema) {
jsonName = b.jsonNameOfField(field)
if len(jsonName) == 0 {
// empty name signals skip property
return "", "", prop
}
if field.Name == "XMLName" && field.Type.String() == "xml.Name" {
// property is metadata for the xml.Name attribute, can be skipped
return "", "", prop
}
if tag := field.Tag.Get("modelDescription"); tag != "" {
modelDescription = tag
}
setPropertyMetadata(&prop, field)
if prop.Type != nil {
return jsonName, modelDescription, prop
}
fieldType := field.Type
// check if type is doing its own marshalling
marshalerType := reflect.TypeOf((*json.Marshaler)(nil)).Elem()
if fieldType.Implements(marshalerType) {
var pType = "string"
if prop.Type == nil {
prop.Type = []string{pType}
}
if prop.Format == "" {
prop.Format = b.jsonSchemaFormat(keyFrom(fieldType, b.Config))
}
return jsonName, modelDescription, prop
}
// check if annotation says it is a string
if jsonTag := field.Tag.Get("json"); jsonTag != "" {
s := strings.Split(jsonTag, ",")
if len(s) > 1 && s[1] == "string" {
stringt := "string"
prop.Type = []string{stringt}
return jsonName, modelDescription, prop
}
}
fieldKind := fieldType.Kind()
switch {
case fieldKind == reflect.Struct:
jsonName, prop := b.buildStructTypeProperty(field, jsonName, model)
return jsonName, modelDescription, prop
case fieldKind == reflect.Slice || fieldKind == reflect.Array:
jsonName, prop := b.buildArrayTypeProperty(field, jsonName, modelName)
return jsonName, modelDescription, prop
case fieldKind == reflect.Ptr:
jsonName, prop := b.buildPointerTypeProperty(field, jsonName, modelName)
return jsonName, modelDescription, prop
case fieldKind == reflect.String:
stringt := "string"
prop.Type = []string{stringt}
return jsonName, modelDescription, prop
case fieldKind == reflect.Map:
jsonName, prop := b.buildMapTypeProperty(field, jsonName, modelName)
return jsonName, modelDescription, prop
}
fieldTypeName := keyFrom(fieldType, b.Config)
if b.isPrimitiveType(fieldTypeName) {
mapped := b.jsonSchemaType(fieldTypeName)
prop.Type = []string{mapped}
prop.Format = b.jsonSchemaFormat(fieldTypeName)
return jsonName, modelDescription, prop
}
modelType := keyFrom(fieldType, b.Config)
prop.Ref = spec.MustCreateRef("#/definitions/" + modelType)
if fieldType.Name() == "" { // override type of anonymous structs
// FIXME: Still need a way to handle anonymous struct model naming.
nestedTypeName := modelName + "." + jsonName
prop.Ref = spec.MustCreateRef("#/definitions/" + nestedTypeName)
b.addModel(fieldType, nestedTypeName)
}
return jsonName, modelDescription, prop
}
func hasNamedJSONTag(field reflect.StructField) bool {
parts := strings.Split(field.Tag.Get("json"), ",")
if len(parts) == 0 {
return false
}
for _, s := range parts[1:] {
if s == "inline" {
return false
}
}
return len(parts[0]) > 0
}
func (b definitionBuilder) buildStructTypeProperty(field reflect.StructField, jsonName string, model *spec.Schema) (nameJson string, prop spec.Schema) {
setPropertyMetadata(&prop, field)
fieldType := field.Type
// check for anonymous
if len(fieldType.Name()) == 0 {
// anonymous
// FIXME: Still need a way to handle anonymous struct model naming.
anonType := model.ID + "." + jsonName
b.addModel(fieldType, anonType)
prop.Ref = spec.MustCreateRef("#/definitions/" + anonType)
return jsonName, prop
}
if field.Name == fieldType.Name() && field.Anonymous && !hasNamedJSONTag(field) {
// embedded struct
sub := definitionBuilder{make(spec.Definitions), b.Config}
sub.addModel(fieldType, "")
subKey := keyFrom(fieldType, b.Config)
// merge properties from sub
subModel, _ := sub.Definitions[subKey]
for k, v := range subModel.Properties {
model.Properties[k] = v
// if subModel says this property is required then include it
required := false
for _, each := range subModel.Required {
if k == each {
required = true
break
}
}
if required {
model.Required = append(model.Required, k)
}
}
// add all new referenced models
for key, sub := range sub.Definitions {
if key != subKey {
if _, ok := b.Definitions[key]; !ok {
b.Definitions[key] = sub
}
}
}
// empty name signals skip property
return "", prop
}
// simple struct
b.addModel(fieldType, "")
var pType = keyFrom(fieldType, b.Config)
prop.Ref = spec.MustCreateRef("#/definitions/" + pType)
return jsonName, prop
}
func (b definitionBuilder) buildArrayTypeProperty(field reflect.StructField, jsonName, modelName string) (nameJson string, prop spec.Schema) {
setPropertyMetadata(&prop, field)
fieldType := field.Type
if fieldType.Elem().Kind() == reflect.Uint8 {
stringt := "string"
prop.Type = []string{stringt}
return jsonName, prop
}
var pType = "array"
prop.Type = []string{pType}
isPrimitive := b.isPrimitiveType(fieldType.Elem().Name())
elemTypeName := b.getElementTypeName(modelName, jsonName, fieldType.Elem())
prop.Items = &spec.SchemaOrArray{
Schema: &spec.Schema{},
}
if isPrimitive {
mapped := b.jsonSchemaType(elemTypeName)
prop.Items.Schema.Type = []string{mapped}
} else {
prop.Items.Schema.Ref = spec.MustCreateRef("#/definitions/" + elemTypeName)
}
// add|overwrite model for element type
if fieldType.Elem().Kind() == reflect.Ptr {
fieldType = fieldType.Elem()
}
if !isPrimitive {
b.addModel(fieldType.Elem(), elemTypeName)
}
return jsonName, prop
}
func (b definitionBuilder) buildMapTypeProperty(field reflect.StructField, jsonName, modelName string) (nameJson string, prop spec.Schema) {
setPropertyMetadata(&prop, field)
fieldType := field.Type
var pType = "object"
prop.Type = []string{pType}
// As long as the element isn't an interface, we should be able to figure out what the
// intended type is and represent it in `AdditionalProperties`.
// See: https://swagger.io/docs/specification/data-models/dictionaries/
if fieldType.Elem().Kind().String() != "interface" {
isPrimitive := b.isPrimitiveType(fieldType.Elem().Name())
elemTypeName := b.getElementTypeName(modelName, jsonName, fieldType.Elem())
prop.AdditionalProperties = &spec.SchemaOrBool{
Schema: &spec.Schema{},
}
if isPrimitive {
mapped := b.jsonSchemaType(elemTypeName)
prop.AdditionalProperties.Schema.Type = []string{mapped}
} else {
prop.AdditionalProperties.Schema.Ref = spec.MustCreateRef("#/definitions/" + elemTypeName)
}
// add|overwrite model for element type
if fieldType.Elem().Kind() == reflect.Ptr {
fieldType = fieldType.Elem()
}
if !isPrimitive {
b.addModel(fieldType.Elem(), elemTypeName)
}
}
return jsonName, prop
}
func (b definitionBuilder) buildPointerTypeProperty(field reflect.StructField, jsonName, modelName string) (nameJson string, prop spec.Schema) {
setPropertyMetadata(&prop, field)
fieldType := field.Type
// override type of pointer to list-likes
if fieldType.Elem().Kind() == reflect.Slice || fieldType.Elem().Kind() == reflect.Array {
var pType = "array"
prop.Type = []string{pType}
isPrimitive := b.isPrimitiveType(fieldType.Elem().Elem().Name())
elemName := b.getElementTypeName(modelName, jsonName, fieldType.Elem().Elem())
prop.Items = &spec.SchemaOrArray{
Schema: &spec.Schema{},
}
if isPrimitive {
primName := b.jsonSchemaType(elemName)
prop.Items.Schema.Type = []string{primName}
} else {
prop.Items.Schema.Ref = spec.MustCreateRef("#/definitions/" + elemName)
}
if !isPrimitive {
// add|overwrite model for element type
b.addModel(fieldType.Elem().Elem(), elemName)
}
} else {
// non-array, pointer type
fieldTypeName := keyFrom(fieldType.Elem(), b.Config)
var pType = b.jsonSchemaType(fieldTypeName) // no star, include pkg path
if b.isPrimitiveType(fieldTypeName) {
prop.Type = []string{pType}
prop.Format = b.jsonSchemaFormat(fieldTypeName)
return jsonName, prop
}
prop.Ref = spec.MustCreateRef("#/definitions/" + pType)
elemName := ""
if fieldType.Elem().Name() == "" {
elemName = modelName + "." + jsonName
prop.Ref = spec.MustCreateRef("#/definitions/" + elemName)
}
b.addModel(fieldType.Elem(), elemName)
}
return jsonName, prop
}
func (b definitionBuilder) getElementTypeName(modelName, jsonName string, t reflect.Type) string {
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
if t.Name() == "" {
return modelName + "." + jsonName
}
return keyFrom(t, b.Config)
}
func keyFrom(st reflect.Type, cfg Config) string {
key := st.String()
if cfg.ModelTypeNameHandler != nil {
if name, ok := cfg.ModelTypeNameHandler(st); ok {
key = name
}
}
if len(st.Name()) == 0 { // unnamed type
// If it is an array, remove the leading []
key = strings.TrimPrefix(key, "[]")
// Swagger UI has special meaning for [
key = strings.Replace(key, "[]", "||", -1)
}
return key
}
// see also https://golang.org/ref/spec#Numeric_types
func (b definitionBuilder) isPrimitiveType(modelName string) bool {
if len(modelName) == 0 {
return false
}
return strings.Contains("uint uint8 uint16 uint32 uint64 int int8 int16 int32 int64 float32 float64 bool string byte rune time.Time", modelName)
}
// jsonNameOfField returns the name of the field as it should appear in JSON format
// An empty string indicates that this field is not part of the JSON representation
func (b definitionBuilder) jsonNameOfField(field reflect.StructField) string {
if jsonTag := field.Tag.Get("json"); jsonTag != "" {
s := strings.Split(jsonTag, ",")
if s[0] == "-" {
// empty name signals skip property
return ""
} else if s[0] != "" {
return s[0]
}
}
return field.Name
}
// see also http://json-schema.org/latest/json-schema-core.html#anchor8
func (b definitionBuilder) jsonSchemaType(modelName string) string {
schemaMap := map[string]string{
"uint": "integer",
"uint8": "integer",
"uint16": "integer",
"uint32": "integer",
"uint64": "integer",
"int": "integer",
"int8": "integer",
"int16": "integer",
"int32": "integer",
"int64": "integer",
"byte": "integer",
"float64": "number",
"float32": "number",
"bool": "boolean",
"time.Time": "string",
}
mapped, ok := schemaMap[modelName]
if !ok {
return modelName // use as is (custom or struct)
}
return mapped
}
func (b definitionBuilder) jsonSchemaFormat(modelName string) string {
if b.Config.SchemaFormatHandler != nil {
if mapped := b.Config.SchemaFormatHandler(modelName); mapped != "" {
return mapped
}
}
schemaMap := map[string]string{
"int": "int32",
"int32": "int32",
"int64": "int64",
"byte": "byte",
"uint": "integer",
"uint8": "byte",
"float64": "double",
"float32": "float",
"time.Time": "date-time",
"*time.Time": "date-time",
}
mapped, ok := schemaMap[modelName]
if !ok {
return "" // no format
}
return mapped
}

View File

@ -1,19 +0,0 @@
package restfulspec
import restful "github.com/emicklei/go-restful"
func asParamType(kind int) string {
switch {
case kind == restful.PathParameterKind:
return "path"
case kind == restful.QueryParameterKind:
return "query"
case kind == restful.BodyParameterKind:
return "body"
case kind == restful.HeaderParameterKind:
return "header"
case kind == restful.FormParameterKind:
return "formData"
}
return ""
}

View File

@ -1,104 +0,0 @@
package restfulspec
import (
"reflect"
"strconv"
"strings"
"github.com/go-openapi/spec"
)
func setDescription(prop *spec.Schema, field reflect.StructField) {
if tag := field.Tag.Get("description"); tag != "" {
prop.Description = tag
}
}
func setDefaultValue(prop *spec.Schema, field reflect.StructField) {
if tag := field.Tag.Get("default"); tag != "" {
prop.Default = stringAutoType(tag)
}
}
func setEnumValues(prop *spec.Schema, field reflect.StructField) {
// We use | to separate the enum values. This value is chosen
// since it's unlikely to be useful in actual enumeration values.
if tag := field.Tag.Get("enum"); tag != "" {
enums := []interface{}{}
for _, s := range strings.Split(tag, "|") {
enums = append(enums, s)
}
prop.Enum = enums
}
}
func setMaximum(prop *spec.Schema, field reflect.StructField) {
if tag := field.Tag.Get("maximum"); tag != "" {
value, err := strconv.ParseFloat(tag, 64)
if err == nil {
prop.Maximum = &value
}
}
}
func setMinimum(prop *spec.Schema, field reflect.StructField) {
if tag := field.Tag.Get("minimum"); tag != "" {
value, err := strconv.ParseFloat(tag, 64)
if err == nil {
prop.Minimum = &value
}
}
}
func setType(prop *spec.Schema, field reflect.StructField) {
if tag := field.Tag.Get("type"); tag != "" {
// Check if the first two characters of the type tag are
// intended to emulate slice/array behaviour.
//
// If type is intended to be a slice/array then add the
// overridden type to the array item instead of the main property
if len(tag) > 2 && tag[0:2] == "[]" {
pType := "array"
prop.Type = []string{pType}
prop.Items = &spec.SchemaOrArray{
Schema: &spec.Schema{},
}
iType := tag[2:]
prop.Items.Schema.Type = []string{iType}
return
}
prop.Type = []string{tag}
}
}
func setUniqueItems(prop *spec.Schema, field reflect.StructField) {
tag := field.Tag.Get("unique")
switch tag {
case "true":
prop.UniqueItems = true
case "false":
prop.UniqueItems = false
}
}
func setReadOnly(prop *spec.Schema, field reflect.StructField) {
tag := field.Tag.Get("readOnly")
switch tag {
case "true":
prop.ReadOnly = true
case "false":
prop.ReadOnly = false
}
}
func setPropertyMetadata(prop *spec.Schema, field reflect.StructField) {
setDescription(prop, field)
setDefaultValue(prop, field)
setEnumValues(prop, field)
setMinimum(prop, field)
setMaximum(prop, field)
setUniqueItems(prop, field)
setType(prop, field)
setReadOnly(prop, field)
}

View File

@ -1,76 +0,0 @@
package restfulspec
import (
restful "github.com/emicklei/go-restful"
"github.com/go-openapi/spec"
)
// NewOpenAPIService returns a new WebService that provides the API documentation of all services,
// conforming to the OpenAPI documentation specification.
func NewOpenAPIService(config Config) *restful.WebService {
ws := new(restful.WebService)
ws.Path(config.APIPath)
ws.Produces(restful.MIME_JSON)
if config.DisableCORS {
ws.Filter(enableCORS)
}
swagger := BuildSwagger(config)
resource := specResource{swagger: swagger}
ws.Route(ws.GET("/").To(resource.getSwagger))
return ws
}
// BuildSwagger returns a Swagger object for all services' API endpoints.
func BuildSwagger(config Config) *spec.Swagger {
// collect paths and model definitions to build Swagger object.
paths := &spec.Paths{Paths: map[string]spec.PathItem{}}
definitions := spec.Definitions{}
for _, each := range config.WebServices {
for path, item := range buildPaths(each, config).Paths {
existingPathItem, ok := paths.Paths[path]
if ok {
for _, r := range each.Routes() {
_, patterns := sanitizePath(r.Path)
item = buildPathItem(each, r, existingPathItem, patterns, config)
}
}
paths.Paths[path] = item
}
for name, def := range buildDefinitions(each, config) {
definitions[name] = def
}
}
swagger := &spec.Swagger{
SwaggerProps: spec.SwaggerProps{
Swagger: "2.0",
Paths: paths,
Definitions: definitions,
},
}
if config.PostBuildSwaggerObjectHandler != nil {
config.PostBuildSwaggerObjectHandler(swagger)
}
return swagger
}
func enableCORS(req *restful.Request, resp *restful.Response, chain *restful.FilterChain) {
if origin := req.HeaderParameter(restful.HEADER_Origin); origin != "" {
// prevent duplicate header
if len(resp.Header().Get(restful.HEADER_AccessControlAllowOrigin)) == 0 {
resp.AddHeader(restful.HEADER_AccessControlAllowOrigin, origin)
}
}
chain.ProcessFilter(req, resp)
}
// specResource is a REST resource to serve the Open-API spec.
type specResource struct {
swagger *spec.Swagger
}
func (s specResource) getSwagger(req *restful.Request, resp *restful.Response) {
resp.WriteAsJson(s.swagger)
}
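// Illustrative usage sketch (not part of the original file): serving the generated
// spec on /apidocs.json for all services registered on the default container.
// It assumes this package is imported as restfulspec and that
// restful.RegisteredWebServices() returns the services on the default container.
//
//	config := restfulspec.Config{
//		WebServices: restful.RegisteredWebServices(),
//		APIPath:     "/apidocs.json",
//	}
//	restful.DefaultContainer.Add(restfulspec.NewOpenAPIService(config))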

View File

@ -1,70 +0,0 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
restful.html
*.out
tmp.prof
go-restful.test
examples/restful-basic-authentication
examples/restful-encoding-filter
examples/restful-filters
examples/restful-hello-world
examples/restful-resource-functions
examples/restful-serve-static
examples/restful-user-service
*.DS_Store
examples/restful-user-resource
examples/restful-multi-containers
examples/restful-form-handling
examples/restful-CORS-filter
examples/restful-options-filter
examples/restful-curly-router
examples/restful-cpuprofiler-service
examples/restful-pre-post-filters
curly.prof
examples/restful-NCSA-logging
examples/restful-html-template
s.html
restful-path-tail

View File

@ -1,6 +0,0 @@
language: go
go:
- 1.x
script: go test -v

View File

@ -1,246 +0,0 @@
Change history of go-restful
=
v2.8.1
- Fix Parameter 'AllowableValues' to populate swagger definition
v2.8.0
- add Request.QueryParameters()
- add json-iterator (via build tag)
- disable vgo module (until log is moved)
v2.7.1
- add vgo module
v2.6.1
- add JSONNewDecoderFunc to allow custom JSON Decoder usage (go 1.10+)
v2.6.0
- Make JSR 311 routing and path param processing consistent
- Adding description to RouteBuilder.Reads()
- Update example for Swagger12 and OpenAPI
2017-09-13
- added route condition functions using `.If(func)` in route building.
2017-02-16
- solved issue #304, make operation names unique
2017-01-30
[IMPORTANT] For swagger users, change your import statement to:
swagger "github.com/emicklei/go-restful-swagger12"
- moved swagger 1.2 code to go-restful-swagger12
- created TAG 2.0.0
2017-01-27
- remove defer request body close
- expose Dispatch for testing filters and Route functions
- swagger response model cannot be array
- created TAG 1.0.0
2016-12-22
- (API change) Remove code related to caching request content. Removes SetCacheReadEntity(doCache bool)
2016-11-26
- Default change! now use CurlyRouter (was RouterJSR311)
- Default change! no more caching of request content
- Default change! do not recover from panics
2016-09-22
- fix the DefaultRequestContentType feature
2016-02-14
- take the quality factor of the Accept header media type into account when deciding the content type of the response
- add constructors for custom entity accessors for xml and json
2015-09-27
- rename new WriteStatusAnd... to WriteHeaderAnd... for consistency
2015-09-25
- fixed problem with changing Header after WriteHeader (issue 235)
2015-09-14
- changed behavior of WriteHeader (immediate write) and WriteEntity (no status write)
- added support for custom EntityReaderWriters.
2015-08-06
- add support for reading entities from compressed request content
- use sync.Pool for compressors of http response and request body
- add Description to Parameter for documentation in Swagger UI
2015-03-20
- add configurable logging
2015-03-18
- if not specified, the Operation is derived from the Route function
2015-03-17
- expose Parameter creation functions
- make trace logger an interface
- fix OPTIONSFilter
- customize rendering of ServiceError
- JSR311 router now handles wildcards
- add Notes to Route
2014-11-27
- (api add) PrettyPrint per response. (as proposed in #167)
2014-11-12
- (api add) ApiVersion(.) for documentation in Swagger UI
2014-11-10
- (api change) struct fields tagged with "description" show up in Swagger UI
2014-10-31
- (api change) ReturnsError -> Returns
- (api add) RouteBuilder.Do(aBuilder) for DRY use of RouteBuilder
- fix swagger nested structs
- sort Swagger response messages by code
2014-10-23
- (api add) ReturnsError allows you to document Http codes in swagger
- fixed problem with greedy CurlyRouter
- (api add) Access-Control-Max-Age in CORS
- add tracing functionality (injectable) for debugging purposes
- support JSON parse 64bit int
- fix empty parameters for swagger
- WebServicesUrl is now optional for swagger
- fixed duplicate AccessControlAllowOrigin in CORS
- (api change) expose ServeMux in container
- (api add) added AllowedDomains in CORS
- (api add) ParameterNamed for detailed documentation
2014-04-16
- (api add) expose constructor of Request for testing.
2014-06-27
- (api add) ParameterNamed gives access to a Parameter definition and its data (for further specification).
- (api add) SetCacheReadEntity allows control over whether or not the request body is being cached (default true for compatibility reasons).
2014-07-03
- (api add) CORS can be configured with a list of allowed domains
2014-03-12
- (api add) Route path parameters can use wildcard or regular expressions. (requires CurlyRouter)
2014-02-26
- (api add) Request now provides information about the matched Route, see method SelectedRoutePath
2014-02-17
- (api change) renamed parameter constants (go-lint checks)
2014-01-10
- (api add) support for CloseNotify, see http://golang.org/pkg/net/http/#CloseNotifier
2014-01-07
- (api change) Write* methods in Response now return the error or nil.
- added example of serving HTML from a Go template.
- fixed comparing Allowed headers in CORS (is now case-insensitive)
2013-11-13
- (api add) Response knows how many bytes are written to the response body.
2013-10-29
- (api add) RecoverHandler(handler RecoverHandleFunction) to change how panic recovery is handled. Default behavior is to log and return a stacktrace. This may be a security issue as it exposes sourcecode information.
2013-10-04
- (api add) Response knows what HTTP status has been written
- (api add) Request can have attributes (map of string->interface, also called request-scoped variables)
2013-09-12
- (api change) Router interface simplified
- Implemented CurlyRouter, a Router that does not use|allow regular expressions in paths
2013-08-05
- add OPTIONS support
- add CORS support
2013-08-27
- fixed some reported issues (see github)
- (api change) deprecated use of WriteError; use WriteErrorString instead
2014-04-15
- (fix) v1.0.1 tag: fix Issue 111: WriteErrorString
2013-08-08
- (api add) Added implementation Container: a WebServices collection with its own http.ServeMux allowing multiple endpoints per program. Existing uses of go-restful will register their services to the DefaultContainer.
- (api add) the swagger package has been extended to have a UI per container.
- if panic is detected then a small stack trace is printed (thanks to runner-mei)
- (api add) WriteErrorString to Response
Important API changes:
- (api remove) package variable DoNotRecover no longer works ; use restful.DefaultContainer.DoNotRecover(true) instead.
- (api remove) package variable EnableContentEncoding no longer works ; use restful.DefaultContainer.EnableContentEncoding(true) instead.
2013-07-06
- (api add) Added support for response encoding (gzip and deflate(zlib)). This feature is disabled on default (for backwards compatibility). Use restful.EnableContentEncoding = true in your initialization to enable this feature.
2013-06-19
- (improve) DoNotRecover option, moved request body closer, improved ReadEntity
2013-06-03
- (api change) removed Dispatcher interface, hide PathExpression
- changed receiver names of type functions to be more idiomatic Go
2013-06-02
- (optimize) Cache the RegExp compilation of Paths.
2013-05-22
- (api add) Added support for request/response filter functions
2013-05-18
- (api add) Added feature to change the default Http Request Dispatch function (travis cline)
- (api change) Moved Swagger Webservice to swagger package (see example restful-user)
[2012-11-14 .. 2013-05-18>
- See https://github.com/emicklei/go-restful/commits
2012-11-14
- Initial commit

View File

@ -1,22 +0,0 @@
Copyright (c) 2012,2013 Ernest Micklei
MIT License
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@ -1,7 +0,0 @@
all: test
test:
go test -v .
ex:
cd examples && ls *.go | xargs go build -o /tmp/ignore

View File

@ -1,88 +0,0 @@
go-restful
==========
package for building REST-style Web Services using Google Go
[![Build Status](https://travis-ci.org/emicklei/go-restful.png)](https://travis-ci.org/emicklei/go-restful)
[![Go Report Card](https://goreportcard.com/badge/github.com/emicklei/go-restful)](https://goreportcard.com/report/github.com/emicklei/go-restful)
[![GoDoc](https://godoc.org/github.com/emicklei/go-restful?status.svg)](https://godoc.org/github.com/emicklei/go-restful)
- [Code examples](https://github.com/emicklei/go-restful/tree/master/examples)
REST asks developers to use HTTP methods explicitly and in a way that's consistent with the protocol definition. This basic REST design principle establishes a one-to-one mapping between create, read, update, and delete (CRUD) operations and HTTP methods. According to this mapping:
- GET = Retrieve a representation of a resource
- POST = Create if you are sending content to the server to create a subordinate of the specified resource collection, using some server-side algorithm.
- PUT = Create if you are sending the full content of the specified resource (URI).
- PUT = Update if you are updating the full content of the specified resource.
- DELETE = Delete if you are requesting the server to delete the resource
- PATCH = Update partial content of a resource
- OPTIONS = Get information about the communication options for the request URI
### Example
```Go
ws := new(restful.WebService)
ws.
Path("/users").
Consumes(restful.MIME_XML, restful.MIME_JSON).
Produces(restful.MIME_JSON, restful.MIME_XML)
ws.Route(ws.GET("/{user-id}").To(u.findUser).
Doc("get a user").
Param(ws.PathParameter("user-id", "identifier of the user").DataType("string")).
Writes(User{}))
...
func (u UserResource) findUser(request *restful.Request, response *restful.Response) {
id := request.PathParameter("user-id")
...
}
```
[Full API of a UserResource](https://github.com/emicklei/go-restful/tree/master/examples/restful-user-resource.go)
### Features
- Routes for request &#8594; function mapping with path parameter (e.g. {id}) support
- Configurable router:
- (default) Fast routing algorithm that allows static elements, regular expressions and dynamic parameters in the URL path (e.g. /meetings/{id} or /static/{subpath:*})
- Routing algorithm after [JSR311](http://jsr311.java.net/nonav/releases/1.1/spec/spec.html) that is implemented using (but does **not** accept) regular expressions
- Request API for reading structs from JSON/XML and accessing parameters (path,query,header)
- Response API for writing structs to JSON/XML and setting headers
- Customizable encoding using EntityReaderWriter registration
- Filters for intercepting the request &#8594; response flow on Service or Route level
- Request-scoped variables using attributes
- Containers for WebServices on different HTTP endpoints
- Content encoding (gzip,deflate) of request and response payloads
- Automatic responses on OPTIONS (using a filter)
- Automatic CORS request handling (using a filter)
- API declaration for Swagger UI ([go-restful-openapi](https://github.com/emicklei/go-restful-openapi), see [go-restful-swagger12](https://github.com/emicklei/go-restful-swagger12))
- Panic recovery to produce HTTP 500, customizable using RecoverHandler(...)
- Route errors produce HTTP 404/405/406/415 errors, customizable using ServiceErrorHandler(...)
- Configurable (trace) logging
- Customizable gzip/deflate readers and writers using CompressorProvider registration
## How to customize
There are several hooks to customize the behavior of the go-restful package.
- Router algorithm
- Panic recovery
- JSON decoder
- Trace logging
- Compression
- Encoders for other serializers
- Use [jsoniter](https://github.com/json-iterator/go) by building this package with a tag, e.g. `go build -tags=jsoniter .`
TODO: write examples of these.
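In the meantime, a few illustrative one-liners (a sketch covering some of the hooks listed above, not exhaustive; the handler body is a placeholder):

```Go
// use the JSR311 router instead of the default CurlyRouter
restful.DefaultContainer.Router(restful.RouterJSR311{})

// customize panic recovery
restful.DefaultContainer.RecoverHandler(func(reason interface{}, w http.ResponseWriter) {
	http.Error(w, "internal error", http.StatusInternalServerError)
})

// enable gzip/deflate encoding of responses
restful.DefaultContainer.EnableContentEncoding(true)
```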
## Resources
- [Example posted on blog](http://ernestmicklei.com/2012/11/go-restful-first-working-example/)
- [Design explained on blog](http://ernestmicklei.com/2012/11/go-restful-api-design/)
- [sourcegraph](https://sourcegraph.com/github.com/emicklei/go-restful)
- [showcase: Zazkia - tcp proxy for testing resiliency](https://github.com/emicklei/zazkia)
- [showcase: Mora - MongoDB REST Api server](https://github.com/emicklei/mora)
Type ```git shortlog -s``` for a full list of contributors.
© 2012 - 2018, http://ernestmicklei.com. MIT License. Contributions are welcome.

View File

@ -1 +0,0 @@
{"SkipDirs": ["examples"]}

View File

@ -1,10 +0,0 @@
#go test -run=none -file bench_test.go -test.bench . -cpuprofile=bench_test.out
go test -c
./go-restful.test -test.run=none -test.cpuprofile=tmp.prof -test.bench=BenchmarkMany
./go-restful.test -test.run=none -test.cpuprofile=curly.prof -test.bench=BenchmarkManyCurly
#go tool pprof go-restful.test tmp.prof
go tool pprof go-restful.test curly.prof

View File

@ -1,123 +0,0 @@
package restful
// Copyright 2013 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
import (
"bufio"
"compress/gzip"
"compress/zlib"
"errors"
"io"
"net"
"net/http"
"strings"
)
// OBSOLETE : use restful.DefaultContainer.EnableContentEncoding(true) to change this setting.
var EnableContentEncoding = false
// CompressingResponseWriter is a http.ResponseWriter that can perform content encoding (gzip and zlib)
type CompressingResponseWriter struct {
writer http.ResponseWriter
compressor io.WriteCloser
encoding string
}
// Header is part of http.ResponseWriter interface
func (c *CompressingResponseWriter) Header() http.Header {
return c.writer.Header()
}
// WriteHeader is part of http.ResponseWriter interface
func (c *CompressingResponseWriter) WriteHeader(status int) {
c.writer.WriteHeader(status)
}
// Write is part of http.ResponseWriter interface
// It is passed through the compressor
func (c *CompressingResponseWriter) Write(bytes []byte) (int, error) {
if c.isCompressorClosed() {
return -1, errors.New("Compressing error: tried to write data using closed compressor")
}
return c.compressor.Write(bytes)
}
// CloseNotify is part of http.CloseNotifier interface
func (c *CompressingResponseWriter) CloseNotify() <-chan bool {
return c.writer.(http.CloseNotifier).CloseNotify()
}
// Close the underlying compressor
func (c *CompressingResponseWriter) Close() error {
if c.isCompressorClosed() {
return errors.New("Compressing error: tried to close already closed compressor")
}
c.compressor.Close()
if ENCODING_GZIP == c.encoding {
currentCompressorProvider.ReleaseGzipWriter(c.compressor.(*gzip.Writer))
}
if ENCODING_DEFLATE == c.encoding {
currentCompressorProvider.ReleaseZlibWriter(c.compressor.(*zlib.Writer))
}
// gc hint needed?
c.compressor = nil
return nil
}
func (c *CompressingResponseWriter) isCompressorClosed() bool {
return nil == c.compressor
}
// Hijack implements the Hijacker interface
// This is especially useful when using Container.EnableContentEncoding
// in combination with websockets (for instance gorilla/websocket)
func (c *CompressingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := c.writer.(http.Hijacker)
if !ok {
return nil, nil, errors.New("ResponseWriter doesn't support Hijacker interface")
}
return hijacker.Hijack()
}
// wantsCompressedResponse reads the Accept-Encoding header to see if and which encoding is requested.
func wantsCompressedResponse(httpRequest *http.Request) (bool, string) {
header := httpRequest.Header.Get(HEADER_AcceptEncoding)
gi := strings.Index(header, ENCODING_GZIP)
zi := strings.Index(header, ENCODING_DEFLATE)
// use in order of appearance
if gi == -1 {
return zi != -1, ENCODING_DEFLATE
} else if zi == -1 {
return gi != -1, ENCODING_GZIP
} else {
if gi < zi {
return true, ENCODING_GZIP
}
return true, ENCODING_DEFLATE
}
}
// NewCompressingResponseWriter creates a CompressingResponseWriter for a known encoding = {gzip,deflate}
func NewCompressingResponseWriter(httpWriter http.ResponseWriter, encoding string) (*CompressingResponseWriter, error) {
httpWriter.Header().Set(HEADER_ContentEncoding, encoding)
c := new(CompressingResponseWriter)
c.writer = httpWriter
var err error
if ENCODING_GZIP == encoding {
w := currentCompressorProvider.AcquireGzipWriter()
w.Reset(httpWriter)
c.compressor = w
c.encoding = ENCODING_GZIP
} else if ENCODING_DEFLATE == encoding {
w := currentCompressorProvider.AcquireZlibWriter()
w.Reset(httpWriter)
c.compressor = w
c.encoding = ENCODING_DEFLATE
} else {
return nil, errors.New("Unknown encoding:" + encoding)
}
return c, err
}

View File

@ -1,103 +0,0 @@
package restful
// Copyright 2015 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
import (
"compress/gzip"
"compress/zlib"
)
// BoundedCachedCompressors is a CompressorProvider that uses a cache with a fixed amount
// of writers and readers (resources).
// If a new resource is acquired and all are in use, it will return a new unmanaged resource.
type BoundedCachedCompressors struct {
gzipWriters chan *gzip.Writer
gzipReaders chan *gzip.Reader
zlibWriters chan *zlib.Writer
writersCapacity int
readersCapacity int
}
// NewBoundedCachedCompressors returns a new BoundedCachedCompressors with a filled cache.
func NewBoundedCachedCompressors(writersCapacity, readersCapacity int) *BoundedCachedCompressors {
b := &BoundedCachedCompressors{
gzipWriters: make(chan *gzip.Writer, writersCapacity),
gzipReaders: make(chan *gzip.Reader, readersCapacity),
zlibWriters: make(chan *zlib.Writer, writersCapacity),
writersCapacity: writersCapacity,
readersCapacity: readersCapacity,
}
for ix := 0; ix < writersCapacity; ix++ {
b.gzipWriters <- newGzipWriter()
b.zlibWriters <- newZlibWriter()
}
for ix := 0; ix < readersCapacity; ix++ {
b.gzipReaders <- newGzipReader()
}
return b
}
// AcquireGzipWriter returns a resettable *gzip.Writer. Needs to be released.
func (b *BoundedCachedCompressors) AcquireGzipWriter() *gzip.Writer {
var writer *gzip.Writer
select {
case writer, _ = <-b.gzipWriters:
default:
// return a new unmanaged one
writer = newGzipWriter()
}
return writer
}
// ReleaseGzipWriter accepts a writer (does not have to be one that was cached)
// only when the cache has room for it. It will ignore it otherwise.
func (b *BoundedCachedCompressors) ReleaseGzipWriter(w *gzip.Writer) {
// forget the unmanaged ones
if len(b.gzipWriters) < b.writersCapacity {
b.gzipWriters <- w
}
}
// AcquireGzipReader returns a *gzip.Reader. Needs to be released.
func (b *BoundedCachedCompressors) AcquireGzipReader() *gzip.Reader {
var reader *gzip.Reader
select {
case reader, _ = <-b.gzipReaders:
default:
// return a new unmanaged one
reader = newGzipReader()
}
return reader
}
// ReleaseGzipReader accepts a reader (does not have to be one that was cached)
// only when the cache has room for it. It will ignore it otherwise.
func (b *BoundedCachedCompressors) ReleaseGzipReader(r *gzip.Reader) {
// forget the unmanaged ones
if len(b.gzipReaders) < b.readersCapacity {
b.gzipReaders <- r
}
}
// AcquireZlibWriter returns a resettable *zlib.Writer. Needs to be released.
func (b *BoundedCachedCompressors) AcquireZlibWriter() *zlib.Writer {
var writer *zlib.Writer
select {
case writer, _ = <-b.zlibWriters:
default:
// return a new unmanaged one
writer = newZlibWriter()
}
return writer
}
// ReleaseZlibWriter accepts a writer (does not have to be one that was cached)
// only when the cache has room for it. It will ignore it otherwise.
func (b *BoundedCachedCompressors) ReleaseZlibWriter(w *zlib.Writer) {
// forget the unmanaged ones
if len(b.zlibWriters) < b.writersCapacity {
b.zlibWriters <- w
}
}
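// Illustrative usage sketch (not part of the original file): swap the default
// sync.Pool based provider for a bounded cache of 20 writers and 20 readers.
//
//	restful.SetCompressorProvider(restful.NewBoundedCachedCompressors(20, 20))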

View File

@ -1,91 +0,0 @@
package restful
// Copyright 2015 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
import (
"bytes"
"compress/gzip"
"compress/zlib"
"sync"
)
// SyncPoolCompessors is a CompressorProvider that uses the standard sync.Pool.
type SyncPoolCompessors struct {
GzipWriterPool *sync.Pool
GzipReaderPool *sync.Pool
ZlibWriterPool *sync.Pool
}
// NewSyncPoolCompessors returns a new ("empty") SyncPoolCompessors.
func NewSyncPoolCompessors() *SyncPoolCompessors {
return &SyncPoolCompessors{
GzipWriterPool: &sync.Pool{
New: func() interface{} { return newGzipWriter() },
},
GzipReaderPool: &sync.Pool{
New: func() interface{} { return newGzipReader() },
},
ZlibWriterPool: &sync.Pool{
New: func() interface{} { return newZlibWriter() },
},
}
}
func (s *SyncPoolCompessors) AcquireGzipWriter() *gzip.Writer {
return s.GzipWriterPool.Get().(*gzip.Writer)
}
func (s *SyncPoolCompessors) ReleaseGzipWriter(w *gzip.Writer) {
s.GzipWriterPool.Put(w)
}
func (s *SyncPoolCompessors) AcquireGzipReader() *gzip.Reader {
return s.GzipReaderPool.Get().(*gzip.Reader)
}
func (s *SyncPoolCompessors) ReleaseGzipReader(r *gzip.Reader) {
s.GzipReaderPool.Put(r)
}
func (s *SyncPoolCompessors) AcquireZlibWriter() *zlib.Writer {
return s.ZlibWriterPool.Get().(*zlib.Writer)
}
func (s *SyncPoolCompessors) ReleaseZlibWriter(w *zlib.Writer) {
s.ZlibWriterPool.Put(w)
}
func newGzipWriter() *gzip.Writer {
// create with an empty bytes writer; it will be replaced before using the gzipWriter
writer, err := gzip.NewWriterLevel(new(bytes.Buffer), gzip.BestSpeed)
if err != nil {
panic(err.Error())
}
return writer
}
func newGzipReader() *gzip.Reader {
// create with an empty reader (but with GZIP header); it will be replaced before using the gzipReader
// we can safely use currentCompressorProvider because it is set on package initialization.
w := currentCompressorProvider.AcquireGzipWriter()
defer currentCompressorProvider.ReleaseGzipWriter(w)
b := new(bytes.Buffer)
w.Reset(b)
w.Flush()
w.Close()
reader, err := gzip.NewReader(bytes.NewReader(b.Bytes()))
if err != nil {
panic(err.Error())
}
return reader
}
func newZlibWriter() *zlib.Writer {
writer, err := zlib.NewWriterLevel(new(bytes.Buffer), gzip.BestSpeed)
if err != nil {
panic(err.Error())
}
return writer
}

View File

@ -1,54 +0,0 @@
package restful
// Copyright 2015 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
import (
"compress/gzip"
"compress/zlib"
)
// CompressorProvider describes a component that can provide compressors for the standard methods.
type CompressorProvider interface {
// Returns a *gzip.Writer which needs to be released later.
// Before using it, call Reset().
AcquireGzipWriter() *gzip.Writer
// Releases an acquired *gzip.Writer.
ReleaseGzipWriter(w *gzip.Writer)
// Returns a *gzip.Reader which needs to be released later.
AcquireGzipReader() *gzip.Reader
// Releases an acquired *gzip.Reader.
ReleaseGzipReader(w *gzip.Reader)
// Returns a *zlib.Writer which needs to be released later.
// Before using it, call Reset().
AcquireZlibWriter() *zlib.Writer
// Releases an acquired *zlib.Writer.
ReleaseZlibWriter(w *zlib.Writer)
}
// currentCompressorProvider is the actual provider of compressors (zlib or gzip).
var currentCompressorProvider CompressorProvider
func init() {
currentCompressorProvider = NewSyncPoolCompessors()
}
// CurrentCompressorProvider returns the current CompressorProvider.
// It is initialized using a SyncPoolCompessors.
func CurrentCompressorProvider() CompressorProvider {
return currentCompressorProvider
}
// SetCompressorProvider sets the actual provider of compressors (zlib or gzip).
func SetCompressorProvider(p CompressorProvider) {
if p == nil {
panic("cannot set compressor provider to nil")
}
currentCompressorProvider = p
}

View File

@ -1,30 +0,0 @@
package restful
// Copyright 2013 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
const (
MIME_XML = "application/xml" // Accept or Content-Type used in Consumes() and/or Produces()
MIME_JSON = "application/json" // Accept or Content-Type used in Consumes() and/or Produces()
MIME_OCTET = "application/octet-stream" // If Content-Type is not present in request, use the default
HEADER_Allow = "Allow"
HEADER_Accept = "Accept"
HEADER_Origin = "Origin"
HEADER_ContentType = "Content-Type"
HEADER_LastModified = "Last-Modified"
HEADER_AcceptEncoding = "Accept-Encoding"
HEADER_ContentEncoding = "Content-Encoding"
HEADER_AccessControlExposeHeaders = "Access-Control-Expose-Headers"
HEADER_AccessControlRequestMethod = "Access-Control-Request-Method"
HEADER_AccessControlRequestHeaders = "Access-Control-Request-Headers"
HEADER_AccessControlAllowMethods = "Access-Control-Allow-Methods"
HEADER_AccessControlAllowOrigin = "Access-Control-Allow-Origin"
HEADER_AccessControlAllowCredentials = "Access-Control-Allow-Credentials"
HEADER_AccessControlAllowHeaders = "Access-Control-Allow-Headers"
HEADER_AccessControlMaxAge = "Access-Control-Max-Age"
ENCODING_GZIP = "gzip"
ENCODING_DEFLATE = "deflate"
)

View File

@ -1,371 +0,0 @@
package restful
// Copyright 2013 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
import (
"bytes"
"errors"
"fmt"
"net/http"
"os"
"runtime"
"strings"
"sync"
"github.com/emicklei/go-restful/log"
)
// Container holds a collection of WebServices and a http.ServeMux to dispatch http requests.
// The requests are further dispatched to routes of WebServices using a RouteSelector
type Container struct {
webServicesLock sync.RWMutex
webServices []*WebService
ServeMux *http.ServeMux
isRegisteredOnRoot bool
containerFilters []FilterFunction
doNotRecover bool // default is true
recoverHandleFunc RecoverHandleFunction
serviceErrorHandleFunc ServiceErrorHandleFunction
router RouteSelector // default is a CurlyRouter (RouterJSR311 is a slower alternative)
contentEncodingEnabled bool // default is false
}
// NewContainer creates a new Container using a new ServeMux and default router (CurlyRouter)
func NewContainer() *Container {
return &Container{
webServices: []*WebService{},
ServeMux: http.NewServeMux(),
isRegisteredOnRoot: false,
containerFilters: []FilterFunction{},
doNotRecover: true,
recoverHandleFunc: logStackOnRecover,
serviceErrorHandleFunc: writeServiceError,
router: CurlyRouter{},
contentEncodingEnabled: false}
}
// RecoverHandleFunction declares functions that can be used to handle a panic situation.
// The first argument is what recover() returns. The second must be used to communicate an error response.
type RecoverHandleFunction func(interface{}, http.ResponseWriter)
// RecoverHandler changes the default function (logStackOnRecover) to be called
// when a panic is detected. DoNotRecover must have been set to false for this handler to be invoked.
func (c *Container) RecoverHandler(handler RecoverHandleFunction) {
c.recoverHandleFunc = handler
}
// ServiceErrorHandleFunction declares functions that can be used to handle a service error situation.
// The first argument is the service error, the second is the request that resulted in the error and
// the third must be used to communicate an error response.
type ServiceErrorHandleFunction func(ServiceError, *Request, *Response)
// ServiceErrorHandler changes the default function (writeServiceError) to be called
// when a ServiceError is detected.
func (c *Container) ServiceErrorHandler(handler ServiceErrorHandleFunction) {
c.serviceErrorHandleFunc = handler
}
// DoNotRecover controls whether panics will be caught to return HTTP 500.
// If set to true, Route functions are responsible for handling any error situation.
// Default value is true.
func (c *Container) DoNotRecover(doNot bool) {
c.doNotRecover = doNot
}
// Router changes the default Router (currently CurlyRouter)
func (c *Container) Router(aRouter RouteSelector) {
c.router = aRouter
}
// EnableContentEncoding (default=false) allows for GZIP or DEFLATE encoding of responses.
func (c *Container) EnableContentEncoding(enabled bool) {
c.contentEncodingEnabled = enabled
}
// Add a WebService to the Container. It will detect duplicate root paths and exit in that case.
func (c *Container) Add(service *WebService) *Container {
c.webServicesLock.Lock()
defer c.webServicesLock.Unlock()
// if rootPath was not set then lazy initialize it
if len(service.rootPath) == 0 {
service.Path("/")
}
// cannot have duplicate root paths
for _, each := range c.webServices {
if each.RootPath() == service.RootPath() {
log.Printf("WebService with duplicate root path detected:['%v']", each)
os.Exit(1)
}
}
// If not registered on root then add specific mapping
if !c.isRegisteredOnRoot {
c.isRegisteredOnRoot = c.addHandler(service, c.ServeMux)
}
c.webServices = append(c.webServices, service)
return c
}
// addHandler may set a new HandleFunc for the serveMux
// this function must run inside the critical region protected by the webServicesLock.
// returns true if the function was registered on root ("/")
func (c *Container) addHandler(service *WebService, serveMux *http.ServeMux) bool {
pattern := fixedPrefixPath(service.RootPath())
// check if root path registration is needed
if "/" == pattern || "" == pattern {
serveMux.HandleFunc("/", c.dispatch)
return true
}
// detect if registration already exists
alreadyMapped := false
for _, each := range c.webServices {
if each.RootPath() == service.RootPath() {
alreadyMapped = true
break
}
}
if !alreadyMapped {
serveMux.HandleFunc(pattern, c.dispatch)
if !strings.HasSuffix(pattern, "/") {
serveMux.HandleFunc(pattern+"/", c.dispatch)
}
}
return false
}
func (c *Container) Remove(ws *WebService) error {
if c.ServeMux == http.DefaultServeMux {
errMsg := fmt.Sprintf("cannot remove a WebService from a Container using the DefaultServeMux: ['%v']", ws)
log.Print(errMsg)
return errors.New(errMsg)
}
c.webServicesLock.Lock()
defer c.webServicesLock.Unlock()
// build a new ServeMux and re-register all WebServices
newServeMux := http.NewServeMux()
newServices := []*WebService{}
newIsRegisteredOnRoot := false
for _, each := range c.webServices {
if each.rootPath != ws.rootPath {
// If not registered on root then add specific mapping
if !newIsRegisteredOnRoot {
newIsRegisteredOnRoot = c.addHandler(each, newServeMux)
}
newServices = append(newServices, each)
}
}
c.webServices, c.ServeMux, c.isRegisteredOnRoot = newServices, newServeMux, newIsRegisteredOnRoot
return nil
}
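Add and Remove are typically driven through a container you construct yourself; a minimal end-to-end sketch (the /hello path and handler are illustrative), following the package documentation further below:

```go
package main

import (
	"log"
	"net/http"

	restful "github.com/emicklei/go-restful"
)

func hello(req *restful.Request, resp *restful.Response) {
	resp.Write([]byte("hello")) // illustrative handler
}

func main() {
	ws := new(restful.WebService)
	ws.Path("/hello")
	ws.Route(ws.GET("/").To(hello))

	container := restful.NewContainer()
	container.Add(ws) // a duplicate root path would abort the program

	server := &http.Server{Addr: ":8081", Handler: container}
	log.Fatal(server.ListenAndServe())
}
```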
// logStackOnRecover is the default RecoverHandleFunction and is called
// when DoNotRecover is false and the recoverHandleFunc is not set for the container.
// Default implementation logs the stacktrace and writes the stacktrace on the response.
// This may be a security issue as it exposes sourcecode information.
func logStackOnRecover(panicReason interface{}, httpWriter http.ResponseWriter) {
var buffer bytes.Buffer
buffer.WriteString(fmt.Sprintf("recover from panic situation: - %v\r\n", panicReason))
for i := 2; ; i += 1 {
_, file, line, ok := runtime.Caller(i)
if !ok {
break
}
buffer.WriteString(fmt.Sprintf(" %s:%d\r\n", file, line))
}
log.Print(buffer.String())
httpWriter.WriteHeader(http.StatusInternalServerError)
httpWriter.Write(buffer.Bytes())
}
// writeServiceError is the default ServiceErrorHandleFunction and is called
// when a ServiceError is returned during route selection. Default implementation
// calls resp.WriteErrorString(err.Code, err.Message)
func writeServiceError(err ServiceError, req *Request, resp *Response) {
resp.WriteErrorString(err.Code, err.Message)
}
// Dispatch the incoming Http Request to a matching WebService.
func (c *Container) Dispatch(httpWriter http.ResponseWriter, httpRequest *http.Request) {
if httpWriter == nil {
panic("httpWriter cannot be nil")
}
if httpRequest == nil {
panic("httpRequest cannot be nil")
}
c.dispatch(httpWriter, httpRequest)
}
// Dispatch the incoming Http Request to a matching WebService.
func (c *Container) dispatch(httpWriter http.ResponseWriter, httpRequest *http.Request) {
writer := httpWriter
// CompressingResponseWriter should be closed after all operations are done
defer func() {
if compressWriter, ok := writer.(*CompressingResponseWriter); ok {
compressWriter.Close()
}
}()
// Install panic recovery unless told otherwise
if !c.doNotRecover { // catch all for 500 response
defer func() {
if r := recover(); r != nil {
c.recoverHandleFunc(r, writer)
return
}
}()
}
// Detect if compression is needed
// assume without compression, test for override
if c.contentEncodingEnabled {
doCompress, encoding := wantsCompressedResponse(httpRequest)
if doCompress {
var err error
writer, err = NewCompressingResponseWriter(httpWriter, encoding)
if err != nil {
log.Print("unable to install compressor: ", err)
httpWriter.WriteHeader(http.StatusInternalServerError)
return
}
}
}
// Find best match Route ; err is non nil if no match was found
var webService *WebService
var route *Route
var err error
func() {
c.webServicesLock.RLock()
defer c.webServicesLock.RUnlock()
webService, route, err = c.router.SelectRoute(
c.webServices,
httpRequest)
}()
if err != nil {
// a non-200 response has already been written
// run container filters anyway ; they should not touch the response...
chain := FilterChain{Filters: c.containerFilters, Target: func(req *Request, resp *Response) {
switch err.(type) {
case ServiceError:
ser := err.(ServiceError)
c.serviceErrorHandleFunc(ser, req, resp)
}
// TODO
}}
chain.ProcessFilter(NewRequest(httpRequest), NewResponse(writer))
return
}
pathProcessor, routerProcessesPath := c.router.(PathProcessor)
if !routerProcessesPath {
pathProcessor = defaultPathProcessor{}
}
pathParams := pathProcessor.ExtractParameters(route, webService, httpRequest.URL.Path)
wrappedRequest, wrappedResponse := route.wrapRequestResponse(writer, httpRequest, pathParams)
// pass through filters (if any)
if len(c.containerFilters)+len(webService.filters)+len(route.Filters) > 0 {
// compose filter chain
allFilters := []FilterFunction{}
allFilters = append(allFilters, c.containerFilters...)
allFilters = append(allFilters, webService.filters...)
allFilters = append(allFilters, route.Filters...)
chain := FilterChain{Filters: allFilters, Target: func(req *Request, resp *Response) {
// handle request by route after passing all filters
route.Function(wrappedRequest, wrappedResponse)
}}
chain.ProcessFilter(wrappedRequest, wrappedResponse)
} else {
// no filters, handle request by route
route.Function(wrappedRequest, wrappedResponse)
}
}
// fixedPrefixPath returns the fixed part of the pathspec: everything up to (but excluding) the first template var {}
func fixedPrefixPath(pathspec string) string {
varBegin := strings.Index(pathspec, "{")
if -1 == varBegin {
return pathspec
}
return pathspec[:varBegin]
}
// ServeHTTP implements net/http.Handler therefore a Container can be a Handler in a http.Server
func (c *Container) ServeHTTP(httpwriter http.ResponseWriter, httpRequest *http.Request) {
c.ServeMux.ServeHTTP(httpwriter, httpRequest)
}
// Handle registers the handler for the given pattern. If a handler already exists for pattern, Handle panics.
func (c *Container) Handle(pattern string, handler http.Handler) {
c.ServeMux.Handle(pattern, handler)
}
// HandleWithFilter registers the handler for the given pattern.
// Container's filter chain is applied for handler.
// If a handler already exists for pattern, HandleWithFilter panics.
func (c *Container) HandleWithFilter(pattern string, handler http.Handler) {
f := func(httpResponse http.ResponseWriter, httpRequest *http.Request) {
if len(c.containerFilters) == 0 {
handler.ServeHTTP(httpResponse, httpRequest)
return
}
chain := FilterChain{Filters: c.containerFilters, Target: func(req *Request, resp *Response) {
handler.ServeHTTP(httpResponse, httpRequest)
}}
chain.ProcessFilter(NewRequest(httpRequest), NewResponse(httpResponse))
}
c.Handle(pattern, http.HandlerFunc(f))
}
// Filter appends a container FilterFunction. These are called before dispatching
// a http.Request to a WebService from the container
func (c *Container) Filter(filter FilterFunction) {
c.containerFilters = append(c.containerFilters, filter)
}
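What such a FilterFunction looks like is documented in doc.go further below; a minimal logging filter, as a sketch (globalLogging is an illustrative name):

```go
package main

import (
	"log"

	restful "github.com/emicklei/go-restful"
)

// globalLogging logs the request and hands control to the next filter
// (or, at the end of the chain, to the RouteFunction).
func globalLogging(req *restful.Request, resp *restful.Response, chain *restful.FilterChain) {
	log.Printf("%s %s", req.Request.Method, req.Request.URL)
	chain.ProcessFilter(req, resp)
}

func main() {
	restful.Filter(globalLogging) // default container; container.Filter(...) works the same way
}
```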
// RegisteredWebServices returns the collections of added WebServices
func (c *Container) RegisteredWebServices() []*WebService {
c.webServicesLock.RLock()
defer c.webServicesLock.RUnlock()
result := make([]*WebService, len(c.webServices))
for ix := range c.webServices {
result[ix] = c.webServices[ix]
}
return result
}
// computeAllowedMethods returns a list of HTTP methods that are valid for a Request
func (c *Container) computeAllowedMethods(req *Request) []string {
// Go through all RegisteredWebServices() and all its Routes to collect the options
methods := []string{}
requestPath := req.Request.URL.Path
for _, ws := range c.RegisteredWebServices() {
matches := ws.pathExpr.Matcher.FindStringSubmatch(requestPath)
if matches != nil {
finalMatch := matches[len(matches)-1]
for _, rt := range ws.Routes() {
matches := rt.pathExpr.Matcher.FindStringSubmatch(finalMatch)
if matches != nil {
lastMatch := matches[len(matches)-1]
if lastMatch == "" || lastMatch == "/" { // do not include if value is neither empty nor /.
methods = append(methods, rt.Method)
}
}
}
}
}
// methods = append(methods, "OPTIONS") not sure about this
return methods
}
// newBasicRequestResponse creates a pair of Request,Response from its http versions.
// It is basic because no parameter or (produces) content-type information is given.
func newBasicRequestResponse(httpWriter http.ResponseWriter, httpRequest *http.Request) (*Request, *Response) {
resp := NewResponse(httpWriter)
resp.requestAccept = httpRequest.Header.Get(HEADER_Accept)
return NewRequest(httpRequest), resp
}

View File

@ -1,202 +0,0 @@
package restful
// Copyright 2013 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
import (
"regexp"
"strconv"
"strings"
)
// CrossOriginResourceSharing is used to create a Container Filter that implements CORS.
// Cross-origin resource sharing (CORS) is a mechanism that allows JavaScript on a web page
// to make XMLHttpRequests to another domain, not the domain the JavaScript originated from.
//
// http://en.wikipedia.org/wiki/Cross-origin_resource_sharing
// http://enable-cors.org/server.html
// http://www.html5rocks.com/en/tutorials/cors/#toc-handling-a-not-so-simple-request
type CrossOriginResourceSharing struct {
ExposeHeaders []string // list of Header names
AllowedHeaders []string // list of Header names
AllowedDomains []string // list of allowed values for Http Origin. An allowed value can be a regular expression to support subdomain matching. If empty all are allowed.
AllowedMethods []string
MaxAge int // number of seconds before requiring new Options request
CookiesAllowed bool
Container *Container
allowedOriginPatterns []*regexp.Regexp // internal field for origin regexp check.
}
// Filter is a filter function that implements the CORS flow as documented on http://enable-cors.org/server.html
// and http://www.html5rocks.com/static/images/cors_server_flowchart.png
func (c CrossOriginResourceSharing) Filter(req *Request, resp *Response, chain *FilterChain) {
origin := req.Request.Header.Get(HEADER_Origin)
if len(origin) == 0 {
if trace {
traceLogger.Print("no Http header Origin set")
}
chain.ProcessFilter(req, resp)
return
}
if !c.isOriginAllowed(origin) { // check whether this origin is allowed
if trace {
traceLogger.Printf("HTTP Origin:%s is not part of %v, neither matches any part of %v", origin, c.AllowedDomains, c.allowedOriginPatterns)
}
chain.ProcessFilter(req, resp)
return
}
if req.Request.Method != "OPTIONS" {
c.doActualRequest(req, resp)
chain.ProcessFilter(req, resp)
return
}
if acrm := req.Request.Header.Get(HEADER_AccessControlRequestMethod); acrm != "" {
c.doPreflightRequest(req, resp)
} else {
c.doActualRequest(req, resp)
chain.ProcessFilter(req, resp)
return
}
}
func (c CrossOriginResourceSharing) doActualRequest(req *Request, resp *Response) {
c.setOptionsHeaders(req, resp)
// continue processing the response
}
func (c *CrossOriginResourceSharing) doPreflightRequest(req *Request, resp *Response) {
if len(c.AllowedMethods) == 0 {
if c.Container == nil {
c.AllowedMethods = DefaultContainer.computeAllowedMethods(req)
} else {
c.AllowedMethods = c.Container.computeAllowedMethods(req)
}
}
acrm := req.Request.Header.Get(HEADER_AccessControlRequestMethod)
if !c.isValidAccessControlRequestMethod(acrm, c.AllowedMethods) {
if trace {
traceLogger.Printf("Http header %s:%s is not in %v",
HEADER_AccessControlRequestMethod,
acrm,
c.AllowedMethods)
}
return
}
acrhs := req.Request.Header.Get(HEADER_AccessControlRequestHeaders)
if len(acrhs) > 0 {
for _, each := range strings.Split(acrhs, ",") {
if !c.isValidAccessControlRequestHeader(strings.Trim(each, " ")) {
if trace {
traceLogger.Printf("Http header %s:%s is not in %v",
HEADER_AccessControlRequestHeaders,
acrhs,
c.AllowedHeaders)
}
return
}
}
}
resp.AddHeader(HEADER_AccessControlAllowMethods, strings.Join(c.AllowedMethods, ","))
resp.AddHeader(HEADER_AccessControlAllowHeaders, acrhs)
c.setOptionsHeaders(req, resp)
// return http 200 response, no body
}
func (c CrossOriginResourceSharing) setOptionsHeaders(req *Request, resp *Response) {
c.checkAndSetExposeHeaders(resp)
c.setAllowOriginHeader(req, resp)
c.checkAndSetAllowCredentials(resp)
if c.MaxAge > 0 {
resp.AddHeader(HEADER_AccessControlMaxAge, strconv.Itoa(c.MaxAge))
}
}
func (c CrossOriginResourceSharing) isOriginAllowed(origin string) bool {
if len(origin) == 0 {
return false
}
if len(c.AllowedDomains) == 0 {
return true
}
allowed := false
for _, domain := range c.AllowedDomains {
if domain == origin {
allowed = true
break
}
}
if !allowed {
if len(c.allowedOriginPatterns) == 0 {
// compile allowed domains to allowed origin patterns
allowedOriginRegexps, err := compileRegexps(c.AllowedDomains)
if err != nil {
return false
}
c.allowedOriginPatterns = allowedOriginRegexps
}
for _, pattern := range c.allowedOriginPatterns {
if allowed = pattern.MatchString(origin); allowed {
break
}
}
}
return allowed
}
func (c CrossOriginResourceSharing) setAllowOriginHeader(req *Request, resp *Response) {
origin := req.Request.Header.Get(HEADER_Origin)
if c.isOriginAllowed(origin) {
resp.AddHeader(HEADER_AccessControlAllowOrigin, origin)
}
}
func (c CrossOriginResourceSharing) checkAndSetExposeHeaders(resp *Response) {
if len(c.ExposeHeaders) > 0 {
resp.AddHeader(HEADER_AccessControlExposeHeaders, strings.Join(c.ExposeHeaders, ","))
}
}
func (c CrossOriginResourceSharing) checkAndSetAllowCredentials(resp *Response) {
if c.CookiesAllowed {
resp.AddHeader(HEADER_AccessControlAllowCredentials, "true")
}
}
func (c CrossOriginResourceSharing) isValidAccessControlRequestMethod(method string, allowedMethods []string) bool {
for _, each := range allowedMethods {
if each == method {
return true
}
}
return false
}
func (c CrossOriginResourceSharing) isValidAccessControlRequestHeader(header string) bool {
for _, each := range c.AllowedHeaders {
if strings.ToLower(each) == strings.ToLower(header) {
return true
}
}
return false
}
// Take a list of strings and compile them into a list of regular expressions.
func compileRegexps(regexpStrings []string) ([]*regexp.Regexp, error) {
regexps := []*regexp.Regexp{}
for _, regexpStr := range regexpStrings {
r, err := regexp.Compile(regexpStr)
if err != nil {
return regexps, err
}
regexps = append(regexps, r)
}
return regexps, nil
}
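A usage sketch of this filter, close to the one in the package documentation below (the header and domain values are illustrative):

```go
package main

import restful "github.com/emicklei/go-restful"

func main() {
	cors := restful.CrossOriginResourceSharing{
		ExposeHeaders:  []string{"X-My-Header"},            // illustrative
		AllowedHeaders: []string{"Content-Type", "Accept"}, // illustrative
		AllowedDomains: []string{"example.com"},            // empty means: all origins allowed
		CookiesAllowed: false,
		Container:      restful.DefaultContainer,
	}
	restful.DefaultContainer.Filter(cors.Filter)
}
```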

View File

@ -1,2 +0,0 @@
go test -coverprofile=coverage.out
go tool cover -html=coverage.out

View File

@ -1,164 +0,0 @@
package restful
// Copyright 2013 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
import (
"net/http"
"regexp"
"sort"
"strings"
)
// CurlyRouter expects Routes with paths that contain zero or more parameters in curly brackets.
type CurlyRouter struct{}
// SelectRoute is part of the Router interface and returns the best match
// for the WebService and its Route for the given Request.
func (c CurlyRouter) SelectRoute(
webServices []*WebService,
httpRequest *http.Request) (selectedService *WebService, selected *Route, err error) {
requestTokens := tokenizePath(httpRequest.URL.Path)
detectedService := c.detectWebService(requestTokens, webServices)
if detectedService == nil {
if trace {
traceLogger.Printf("no WebService was found to match URL path:%s\n", httpRequest.URL.Path)
}
return nil, nil, NewError(http.StatusNotFound, "404: Page Not Found")
}
candidateRoutes := c.selectRoutes(detectedService, requestTokens)
if len(candidateRoutes) == 0 {
if trace {
traceLogger.Printf("no Route in WebService with path %s was found to match URL path:%s\n", detectedService.rootPath, httpRequest.URL.Path)
}
return detectedService, nil, NewError(http.StatusNotFound, "404: Page Not Found")
}
selectedRoute, err := c.detectRoute(candidateRoutes, httpRequest)
if selectedRoute == nil {
return detectedService, nil, err
}
return detectedService, selectedRoute, nil
}
// selectRoutes returns a collection of Route from a WebService that matches the path tokens from the request.
func (c CurlyRouter) selectRoutes(ws *WebService, requestTokens []string) sortableCurlyRoutes {
candidates := sortableCurlyRoutes{}
for _, each := range ws.routes {
matches, paramCount, staticCount := c.matchesRouteByPathTokens(each.pathParts, requestTokens)
if matches {
candidates.add(curlyRoute{each, paramCount, staticCount}) // TODO make sure Routes() return pointers?
}
}
sort.Sort(sort.Reverse(candidates))
return candidates
}
// matchesRouteByPathTokens computes whether the route matches, how many parameters match and how many static path elements there are.
func (c CurlyRouter) matchesRouteByPathTokens(routeTokens, requestTokens []string) (matches bool, paramCount int, staticCount int) {
if len(routeTokens) < len(requestTokens) {
// proceed in matching only if last routeToken is wildcard
count := len(routeTokens)
if count == 0 || !strings.HasSuffix(routeTokens[count-1], "*}") {
return false, 0, 0
}
// proceed
}
for i, routeToken := range routeTokens {
if i == len(requestTokens) {
// reached end of request path
return false, 0, 0
}
requestToken := requestTokens[i]
if strings.HasPrefix(routeToken, "{") {
paramCount++
if colon := strings.Index(routeToken, ":"); colon != -1 {
// match by regex
matchesToken, matchesRemainder := c.regularMatchesPathToken(routeToken, colon, requestToken)
if !matchesToken {
return false, 0, 0
}
if matchesRemainder {
break
}
}
} else { // no { prefix
if requestToken != routeToken {
return false, 0, 0
}
staticCount++
}
}
return true, paramCount, staticCount
}
// regularMatchesPathToken tests whether the regular expression part of routeToken matches the requestToken or all remaining tokens
// format routeToken is {someVar:someExpression}, e.g. {zipcode:[\d][\d][\d][\d][A-Z][A-Z]}
func (c CurlyRouter) regularMatchesPathToken(routeToken string, colon int, requestToken string) (matchesToken bool, matchesRemainder bool) {
regPart := routeToken[colon+1 : len(routeToken)-1]
if regPart == "*" {
if trace {
traceLogger.Printf("wildcard parameter detected in route token %s that matches %s\n", routeToken, requestToken)
}
return true, true
}
matched, err := regexp.MatchString(regPart, requestToken)
return (matched && err == nil), false
}
var jsr311Router = RouterJSR311{}
// detectRoute selects from a list of Route the first match by inspecting both the Accept and Content-Type
// headers of the Request. See also RouterJSR311 in jsr311.go
func (c CurlyRouter) detectRoute(candidateRoutes sortableCurlyRoutes, httpRequest *http.Request) (*Route, error) {
// tracing is done inside detectRoute
return jsr311Router.detectRoute(candidateRoutes.routes(), httpRequest)
}
// detectWebService returns the best matching webService given the list of path tokens.
// see also computeWebserviceScore
func (c CurlyRouter) detectWebService(requestTokens []string, webServices []*WebService) *WebService {
var best *WebService
score := -1
for _, each := range webServices {
matches, eachScore := c.computeWebserviceScore(requestTokens, each.pathExpr.tokens)
if matches && (eachScore > score) {
best = each
score = eachScore
}
}
return best
}
// computeWebserviceScore returns whether tokens match and
// the weighted score of the longest matching consecutive tokens from the beginning.
func (c CurlyRouter) computeWebserviceScore(requestTokens []string, tokens []string) (bool, int) {
if len(tokens) > len(requestTokens) {
return false, 0
}
score := 0
for i := 0; i < len(tokens); i++ {
each := requestTokens[i]
other := tokens[i]
if len(each) == 0 && len(other) == 0 {
score++
continue
}
if len(other) > 0 && strings.HasPrefix(other, "{") {
// no empty match
if len(each) == 0 {
return false, score
}
score += 1
} else {
// not a parameter
if each != other {
return false, score
}
score += (len(tokens) - i) * 10 //fuzzy
}
}
return true, score
}
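The kind of routes this router scores is easiest to see from the builder side; a sketch with a regular-expression parameter and a tail wildcard (handlers and paths are illustrative):

```go
package main

import restful "github.com/emicklei/go-restful"

// placeholder handlers for the sketch
func findPerson(req *restful.Request, resp *restful.Response) {}
func serveFile(req *restful.Request, resp *restful.Response)  {}

func main() {
	ws := new(restful.WebService)
	ws.Path("/persons")
	// {name} only matches two capital letters
	ws.Route(ws.GET("/{name:[A-Z][A-Z]}").To(findPerson))
	// {path:*} matches the remaining tail of the URL path
	ws.Route(ws.GET("/{name}/files/{path:*}").To(serveFile))
	restful.Add(ws)
}
```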

View File

@ -1,52 +0,0 @@
package restful
// Copyright 2013 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
// curlyRoute exists for sorting Routes by the CurlyRouter based on number of parameters and number of static path elements.
type curlyRoute struct {
route Route
paramCount int
staticCount int
}
type sortableCurlyRoutes []curlyRoute
func (s *sortableCurlyRoutes) add(route curlyRoute) {
*s = append(*s, route)
}
func (s sortableCurlyRoutes) routes() (routes []Route) {
for _, each := range s {
routes = append(routes, each.route) // TODO change return type
}
return routes
}
func (s sortableCurlyRoutes) Len() int {
return len(s)
}
func (s sortableCurlyRoutes) Swap(i, j int) {
s[i], s[j] = s[j], s[i]
}
func (s sortableCurlyRoutes) Less(i, j int) bool {
ci := s[i]
cj := s[j]
// primary key
if ci.staticCount < cj.staticCount {
return true
}
if ci.staticCount > cj.staticCount {
return false
}
// secondary key
if ci.paramCount < cj.paramCount {
return true
}
if ci.paramCount > cj.paramCount {
return false
}
return ci.route.Path < cj.route.Path
}

View File

@ -1,185 +0,0 @@
/*
Package restful , a lean package for creating REST-style WebServices without magic.
WebServices and Routes
A WebService has a collection of Route objects that dispatch incoming Http Requests to function calls.
Typically, a WebService has a root path (e.g. /users) and defines common MIME types for its routes.
WebServices must be added to a container (see below) in order to handle Http requests from a server.
A Route is defined by an HTTP method, a URL path and (optionally) the MIME types it consumes (Content-Type) and produces (Accept).
This package has the logic to find the best matching Route and if found, call its Function.
ws := new(restful.WebService)
ws.
Path("/users").
Consumes(restful.MIME_JSON, restful.MIME_XML).
Produces(restful.MIME_JSON, restful.MIME_XML)
ws.Route(ws.GET("/{user-id}").To(u.findUser)) // u is a UserResource
...
// GET http://localhost:8080/users/1
func (u UserResource) findUser(request *restful.Request, response *restful.Response) {
id := request.PathParameter("user-id")
...
}
The (*Request, *Response) arguments provide functions for reading information from the request and writing information back to the response.
See the example https://github.com/emicklei/go-restful/blob/master/examples/restful-user-resource.go with a full implementation.
Regular expression matching Routes
A Route parameter can be specified using the format "uri/{var[:regexp]}" or the special version "uri/{var:*}" for matching the tail of the path.
For example, /persons/{name:[A-Z][A-Z]} can be used to restrict values for the parameter "name" to only contain capital alphabetic characters.
Regular expressions must use the standard Go syntax as described in the regexp package. (https://code.google.com/p/re2/wiki/Syntax)
This feature requires the use of a CurlyRouter.
Containers
A Container holds a collection of WebServices, Filters and a http.ServeMux for multiplexing http requests.
Using the statements "restful.Add(...) and restful.Filter(...)" will register WebServices and Filters to the Default Container.
The Default container of go-restful uses the http.DefaultServeMux.
You can create your own Container and create a new http.Server for that particular container.
container := restful.NewContainer()
server := &http.Server{Addr: ":8081", Handler: container}
Filters
A filter dynamically intercepts requests and responses to transform or use the information contained in the requests or responses.
You can use filters to perform generic logging, measurement, authentication, redirect, set response headers etc.
In the restful package there are three hooks into the request,response flow where filters can be added.
Each filter must define a FilterFunction:
func (req *restful.Request, resp *restful.Response, chain *restful.FilterChain)
Use the following statement to pass the request,response pair to the next filter or RouteFunction
chain.ProcessFilter(req, resp)
Container Filters
These are processed before any registered WebService.
// install a (global) filter for the default container (processed before any webservice)
restful.Filter(globalLogging)
WebService Filters
These are processed before any Route of a WebService.
// install a webservice filter (processed before any route)
ws.Filter(webserviceLogging).Filter(measureTime)
Route Filters
These are processed before calling the function associated with the Route.
// install 2 chained route filters (processed before calling findUser)
ws.Route(ws.GET("/{user-id}").Filter(routeLogging).Filter(NewCountFilter().routeCounter).To(findUser))
See the example https://github.com/emicklei/go-restful/blob/master/examples/restful-filters.go with full implementations.
Response Encoding
Two encodings are supported: gzip and deflate. To enable this for all responses:
restful.DefaultContainer.EnableContentEncoding(true)
If a Http request includes the Accept-Encoding header then the response content will be compressed using the specified encoding.
Alternatively, you can create a Filter that performs the encoding and install it per WebService or Route.
See the example https://github.com/emicklei/go-restful/blob/master/examples/restful-encoding-filter.go
OPTIONS support
By installing a pre-defined container filter, your Webservice(s) can respond to the OPTIONS Http request.
Filter(OPTIONSFilter())
CORS
By installing the filter of a CrossOriginResourceSharing (CORS), your WebService(s) can handle CORS requests.
cors := CrossOriginResourceSharing{ExposeHeaders: []string{"X-My-Header"}, CookiesAllowed: false, Container: DefaultContainer}
Filter(cors.Filter)
Error Handling
Unexpected things happen. If a request cannot be processed because of a failure, your service needs to tell the client, via the response, what happened and why.
For this reason HTTP status codes exist and it is important to use the correct code in every exceptional situation.
400: Bad Request
If path or query parameters are not valid (content or type) then use http.StatusBadRequest.
404: Not Found
Despite a valid URI, the resource requested may not be available
500: Internal Server Error
If the application logic could not process the request (or write the response) then use http.StatusInternalServerError.
405: Method Not Allowed
The request has a valid URL but the method (GET,PUT,POST,...) is not allowed.
406: Not Acceptable
The request does not have or has an unknown Accept Header set for this operation.
415: Unsupported Media Type
The request does not have or has an unknown Content-Type Header set for this operation.
ServiceError
In addition to setting the correct (error) Http status code, you can choose to write a ServiceError message on the response.
Performance options
This package has several options that affect the performance of your service. It is important to understand them and how you can change them.
restful.DefaultContainer.DoNotRecover(false)
DoNotRecover controls whether panics will be caught to return HTTP 500.
If set to false, the container will recover from panics.
Default value is true
restful.SetCompressorProvider(NewBoundedCachedCompressors(20, 20))
If content encoding is enabled then the default strategy for getting new gzip/zlib writers and readers is to use a sync.Pool.
Because writers are expensive to create, performance improves further when using a preloaded cache of them. You can also inject your own implementation.
Troubleshooting
This package has the means to produce detailed logging of the complete Http request matching process and filter invocation.
Enabling this feature requires you to set a restful.StdLogger implementation (e.g. a log.Logger), for example:
restful.TraceLogger(log.New(os.Stdout, "[restful] ", log.LstdFlags|log.Lshortfile))
Logging
The restful.SetLogger() method allows you to override the logger used by the package. By default restful
uses the standard library `log` package and logs to stdout. Different logging packages are supported as
long as they conform to the `StdLogger` interface defined in the `log` sub-package; writing an adapter for your
preferred package is simple.
Resources
[project]: https://github.com/emicklei/go-restful
[examples]: https://github.com/emicklei/go-restful/blob/master/examples
[design]: http://ernestmicklei.com/2012/11/11/go-restful-api-design/
[showcases]: https://github.com/emicklei/mora, https://github.com/emicklei/landskape
(c) 2012-2015, http://ernestmicklei.com. MIT License
*/
package restful

View File

@ -1,162 +0,0 @@
package restful
// Copyright 2015 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
import (
"encoding/xml"
"strings"
"sync"
)
// EntityReaderWriter can read and write values using an encoding such as JSON,XML.
type EntityReaderWriter interface {
// Read a serialized version of the value from the request.
// The Request may have a decompressing reader. Depends on Content-Encoding.
Read(req *Request, v interface{}) error
// Write a serialized version of the value on the response.
// The Response may have a compressing writer. Depends on Accept-Encoding.
// status should be a valid Http Status code
Write(resp *Response, status int, v interface{}) error
}
// entityAccessRegistry is a singleton
var entityAccessRegistry = &entityReaderWriters{
protection: new(sync.RWMutex),
accessors: map[string]EntityReaderWriter{},
}
// entityReaderWriters associates MIME to an EntityReaderWriter
type entityReaderWriters struct {
protection *sync.RWMutex
accessors map[string]EntityReaderWriter
}
func init() {
RegisterEntityAccessor(MIME_JSON, NewEntityAccessorJSON(MIME_JSON))
RegisterEntityAccessor(MIME_XML, NewEntityAccessorXML(MIME_XML))
}
// RegisterEntityAccessor add/overrides the ReaderWriter for encoding content with this MIME type.
func RegisterEntityAccessor(mime string, erw EntityReaderWriter) {
entityAccessRegistry.protection.Lock()
defer entityAccessRegistry.protection.Unlock()
entityAccessRegistry.accessors[mime] = erw
}
// NewEntityAccessorJSON returns a new EntityReaderWriter for accessing JSON content.
// This package is already initialized with such an accessor using the MIME_JSON contentType.
func NewEntityAccessorJSON(contentType string) EntityReaderWriter {
return entityJSONAccess{ContentType: contentType}
}
// NewEntityAccessorXML returns a new EntityReaderWriter for accessing XML content.
// This package is already initialized with such an accessor using the MIME_XML contentType.
func NewEntityAccessorXML(contentType string) EntityReaderWriter {
return entityXMLAccess{ContentType: contentType}
}
// accessorAt returns the registered ReaderWriter for this MIME type.
func (r *entityReaderWriters) accessorAt(mime string) (EntityReaderWriter, bool) {
r.protection.RLock()
defer r.protection.RUnlock()
er, ok := r.accessors[mime]
if !ok {
// retry with reverse lookup
// more expensive but we are in an exceptional situation anyway
for k, v := range r.accessors {
if strings.Contains(mime, k) {
return v, true
}
}
}
return er, ok
}
// entityXMLAccess is a EntityReaderWriter for XML encoding
type entityXMLAccess struct {
// This is used for setting the Content-Type header when writing
ContentType string
}
// Read unmarshals the value from XML
func (e entityXMLAccess) Read(req *Request, v interface{}) error {
return xml.NewDecoder(req.Request.Body).Decode(v)
}
// Write marshals the value to XML and sets the Content-Type header.
func (e entityXMLAccess) Write(resp *Response, status int, v interface{}) error {
return writeXML(resp, status, e.ContentType, v)
}
// writeXML marshals the value to XML and sets the Content-Type header.
func writeXML(resp *Response, status int, contentType string, v interface{}) error {
if v == nil {
resp.WriteHeader(status)
// do not write a nil representation
return nil
}
if resp.prettyPrint {
// pretty output must be created and written explicitly
output, err := xml.MarshalIndent(v, " ", " ")
if err != nil {
return err
}
resp.Header().Set(HEADER_ContentType, contentType)
resp.WriteHeader(status)
_, err = resp.Write([]byte(xml.Header))
if err != nil {
return err
}
_, err = resp.Write(output)
return err
}
// not-so-pretty
resp.Header().Set(HEADER_ContentType, contentType)
resp.WriteHeader(status)
return xml.NewEncoder(resp).Encode(v)
}
// entityJSONAccess is a EntityReaderWriter for JSON encoding
type entityJSONAccess struct {
// This is used for setting the Content-Type header when writing
ContentType string
}
// Read unmarshals the value from JSON
func (e entityJSONAccess) Read(req *Request, v interface{}) error {
decoder := NewDecoder(req.Request.Body)
decoder.UseNumber()
return decoder.Decode(v)
}
// Write marshals the value to JSON and sets the Content-Type header.
func (e entityJSONAccess) Write(resp *Response, status int, v interface{}) error {
return writeJSON(resp, status, e.ContentType, v)
}
// writeJSON marshals the value to JSON and sets the Content-Type header.
func writeJSON(resp *Response, status int, contentType string, v interface{}) error {
if v == nil {
resp.WriteHeader(status)
// do not write a nil representation
return nil
}
if resp.prettyPrint {
// pretty output must be created and written explicitly
output, err := MarshalIndent(v, "", " ")
if err != nil {
return err
}
resp.Header().Set(HEADER_ContentType, contentType)
resp.WriteHeader(status)
_, err = resp.Write(output)
return err
}
// not-so-pretty
resp.Header().Set(HEADER_ContentType, contentType)
resp.WriteHeader(status)
return NewEncoder(resp).Encode(v)
}
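Registering an accessor for an additional MIME type reuses the same constructors; a sketch (the vendor-specific MIME type is illustrative):

```go
package main

import restful "github.com/emicklei/go-restful"

func main() {
	// Let a vendor-specific JSON MIME type be read and written with the
	// built-in JSON accessor.
	mime := "application/vnd.example+json" // illustrative
	restful.RegisterEntityAccessor(mime, restful.NewEntityAccessorJSON(mime))
}
```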

View File

@ -1,35 +0,0 @@
package restful
// Copyright 2013 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
// FilterChain is a request scoped object to process one or more filters before calling the target RouteFunction.
type FilterChain struct {
Filters []FilterFunction // ordered list of FilterFunction
Index int // index into filters that is currently in progress
Target RouteFunction // function to call after passing all filters
}
// ProcessFilter passes the request,response pair on to the next of the Filters.
// Each filter can decide to proceed to the next Filter or handle the Response itself.
func (f *FilterChain) ProcessFilter(request *Request, response *Response) {
if f.Index < len(f.Filters) {
f.Index++
f.Filters[f.Index-1](request, response, f)
} else {
f.Target(request, response)
}
}
// FilterFunction definitions must call ProcessFilter on the FilterChain to pass on the control and eventually call the RouteFunction
type FilterFunction func(*Request, *Response, *FilterChain)
// NoBrowserCacheFilter is a filter function to set HTTP headers that disable browser caching
// See examples/restful-no-cache-filter.go for usage
func NoBrowserCacheFilter(req *Request, resp *Response, chain *FilterChain) {
resp.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate") // HTTP 1.1.
resp.Header().Set("Pragma", "no-cache") // HTTP 1.0.
resp.Header().Set("Expires", "0") // Proxies.
chain.ProcessFilter(req, resp)
}

View File

@ -1,11 +0,0 @@
// +build !jsoniter
package restful
import "encoding/json"
var (
MarshalIndent = json.MarshalIndent
NewDecoder = json.NewDecoder
NewEncoder = json.NewEncoder
)

View File

@ -1,12 +0,0 @@
// +build jsoniter
package restful
import "github.com/json-iterator/go"
var (
json = jsoniter.ConfigCompatibleWithStandardLibrary
MarshalIndent = json.MarshalIndent
NewDecoder = json.NewDecoder
NewEncoder = json.NewEncoder
)
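The two json files above select the encoder at compile time: the standard library encoding/json variant is built by default, and building with the jsoniter build tag (e.g. `go build -tags jsoniter`) swaps in json-iterator. This is ordinary Go build-tag behaviour rather than anything specific to this package.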

View File

@ -1,293 +0,0 @@
package restful
// Copyright 2013 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
import (
"errors"
"fmt"
"net/http"
"sort"
)
// RouterJSR311 implements the flow for matching Requests to Routes (and consequently Resource Functions)
// as specified by the JSR311 http://jsr311.java.net/nonav/releases/1.1/spec/spec.html.
// RouterJSR311 implements the Router interface.
// Concept of locators is not implemented.
type RouterJSR311 struct{}
// SelectRoute is part of the Router interface and returns the best match
// for the WebService and its Route for the given Request.
func (r RouterJSR311) SelectRoute(
webServices []*WebService,
httpRequest *http.Request) (selectedService *WebService, selectedRoute *Route, err error) {
// Identify the root resource class (WebService)
dispatcher, finalMatch, err := r.detectDispatcher(httpRequest.URL.Path, webServices)
if err != nil {
return nil, nil, NewError(http.StatusNotFound, "")
}
// Obtain the set of candidate methods (Routes)
routes := r.selectRoutes(dispatcher, finalMatch)
if len(routes) == 0 {
return dispatcher, nil, NewError(http.StatusNotFound, "404: Page Not Found")
}
// Identify the method (Route) that will handle the request
route, ok := r.detectRoute(routes, httpRequest)
return dispatcher, route, ok
}
// ExtractParameters is used to obtain the path parameters from the route using the same matching
// engine as the JSR 311 router.
func (r RouterJSR311) ExtractParameters(route *Route, webService *WebService, urlPath string) map[string]string {
webServiceExpr := webService.pathExpr
webServiceMatches := webServiceExpr.Matcher.FindStringSubmatch(urlPath)
pathParameters := r.extractParams(webServiceExpr, webServiceMatches)
routeExpr := route.pathExpr
routeMatches := routeExpr.Matcher.FindStringSubmatch(webServiceMatches[len(webServiceMatches)-1])
routeParams := r.extractParams(routeExpr, routeMatches)
for key, value := range routeParams {
pathParameters[key] = value
}
return pathParameters
}
func (RouterJSR311) extractParams(pathExpr *pathExpression, matches []string) map[string]string {
params := map[string]string{}
for i := 1; i < len(matches); i++ {
if len(pathExpr.VarNames) >= i {
params[pathExpr.VarNames[i-1]] = matches[i]
}
}
return params
}
// http://jsr311.java.net/nonav/releases/1.1/spec/spec3.html#x3-360003.7.2
func (r RouterJSR311) detectRoute(routes []Route, httpRequest *http.Request) (*Route, error) {
ifOk := []Route{}
for _, each := range routes {
ok := true
for _, fn := range each.If {
if !fn(httpRequest) {
ok = false
break
}
}
if ok {
ifOk = append(ifOk, each)
}
}
if len(ifOk) == 0 {
if trace {
traceLogger.Printf("no Route found (from %d) that passes conditional checks", len(routes))
}
return nil, NewError(http.StatusNotFound, "404: Not Found")
}
// http method
methodOk := []Route{}
for _, each := range ifOk {
if httpRequest.Method == each.Method {
methodOk = append(methodOk, each)
}
}
if len(methodOk) == 0 {
if trace {
traceLogger.Printf("no Route found (in %d routes) that matches HTTP method %s\n", len(routes), httpRequest.Method)
}
return nil, NewError(http.StatusMethodNotAllowed, "405: Method Not Allowed")
}
inputMediaOk := methodOk
// content-type
contentType := httpRequest.Header.Get(HEADER_ContentType)
inputMediaOk = []Route{}
for _, each := range methodOk {
if each.matchesContentType(contentType) {
inputMediaOk = append(inputMediaOk, each)
}
}
if len(inputMediaOk) == 0 {
if trace {
traceLogger.Printf("no Route found (from %d) that matches HTTP Content-Type: %s\n", len(methodOk), contentType)
}
return nil, NewError(http.StatusUnsupportedMediaType, "415: Unsupported Media Type")
}
// accept
outputMediaOk := []Route{}
accept := httpRequest.Header.Get(HEADER_Accept)
if len(accept) == 0 {
accept = "*/*"
}
for _, each := range inputMediaOk {
if each.matchesAccept(accept) {
outputMediaOk = append(outputMediaOk, each)
}
}
if len(outputMediaOk) == 0 {
if trace {
traceLogger.Printf("no Route found (from %d) that matches HTTP Accept: %s\n", len(inputMediaOk), accept)
}
return nil, NewError(http.StatusNotAcceptable, "406: Not Acceptable")
}
// return r.bestMatchByMedia(outputMediaOk, contentType, accept), nil
return &outputMediaOk[0], nil
}
// http://jsr311.java.net/nonav/releases/1.1/spec/spec3.html#x3-360003.7.2
// n/m > n/* > */*
func (r RouterJSR311) bestMatchByMedia(routes []Route, contentType string, accept string) *Route {
// TODO
return &routes[0]
}
// http://jsr311.java.net/nonav/releases/1.1/spec/spec3.html#x3-360003.7.2 (step 2)
func (r RouterJSR311) selectRoutes(dispatcher *WebService, pathRemainder string) []Route {
filtered := &sortableRouteCandidates{}
for _, each := range dispatcher.Routes() {
pathExpr := each.pathExpr
matches := pathExpr.Matcher.FindStringSubmatch(pathRemainder)
if matches != nil {
lastMatch := matches[len(matches)-1]
if len(lastMatch) == 0 || lastMatch == "/" { // include the route only if the remaining match is empty or "/"
filtered.candidates = append(filtered.candidates,
routeCandidate{each, len(matches) - 1, pathExpr.LiteralCount, pathExpr.VarCount})
}
}
}
if len(filtered.candidates) == 0 {
if trace {
traceLogger.Printf("WebService on path %s has no routes that match URL path remainder:%s\n", dispatcher.rootPath, pathRemainder)
}
return []Route{}
}
sort.Sort(sort.Reverse(filtered))
// select other routes from candidates whose expression also matches the path remainder
matchingRoutes := []Route{filtered.candidates[0].route}
for c := 1; c < len(filtered.candidates); c++ {
each := filtered.candidates[c]
if each.route.pathExpr.Matcher.MatchString(pathRemainder) {
matchingRoutes = append(matchingRoutes, each.route)
}
}
return matchingRoutes
}
// http://jsr311.java.net/nonav/releases/1.1/spec/spec3.html#x3-360003.7.2 (step 1)
func (r RouterJSR311) detectDispatcher(requestPath string, dispatchers []*WebService) (*WebService, string, error) {
filtered := &sortableDispatcherCandidates{}
for _, each := range dispatchers {
matches := each.pathExpr.Matcher.FindStringSubmatch(requestPath)
if matches != nil {
filtered.candidates = append(filtered.candidates,
dispatcherCandidate{each, matches[len(matches)-1], len(matches), each.pathExpr.LiteralCount, each.pathExpr.VarCount})
}
}
if len(filtered.candidates) == 0 {
if trace {
traceLogger.Printf("no WebService was found to match URL path:%s\n", requestPath)
}
return nil, "", errors.New("not found")
}
sort.Sort(sort.Reverse(filtered))
return filtered.candidates[0].dispatcher, filtered.candidates[0].finalMatch, nil
}
// Types and functions to support the sorting of Routes
type routeCandidate struct {
route Route
matchesCount int // the number of capturing groups
literalCount int // the number of literal characters (i.e. those not resulting from template variable substitution)
nonDefaultCount int // the number of capturing groups with non-default regular expressions (i.e. not ([^ /]+?))
}
func (r routeCandidate) expressionToMatch() string {
return r.route.pathExpr.Source
}
func (r routeCandidate) String() string {
return fmt.Sprintf("(m=%d,l=%d,n=%d)", r.matchesCount, r.literalCount, r.nonDefaultCount)
}
type sortableRouteCandidates struct {
candidates []routeCandidate
}
func (rcs *sortableRouteCandidates) Len() int {
return len(rcs.candidates)
}
func (rcs *sortableRouteCandidates) Swap(i, j int) {
rcs.candidates[i], rcs.candidates[j] = rcs.candidates[j], rcs.candidates[i]
}
func (rcs *sortableRouteCandidates) Less(i, j int) bool {
ci := rcs.candidates[i]
cj := rcs.candidates[j]
// primary key
if ci.literalCount < cj.literalCount {
return true
}
if ci.literalCount > cj.literalCount {
return false
}
// secondary key
if ci.matchesCount < cj.matchesCount {
return true
}
if ci.matchesCount > cj.matchesCount {
return false
}
// tertiary key
if ci.nonDefaultCount < cj.nonDefaultCount {
return true
}
if ci.nonDefaultCount > cj.nonDefaultCount {
return false
}
// quaternary key ("source" is interpreted as Path)
return ci.route.Path < cj.route.Path
}
// Types and functions to support the sorting of Dispatchers
type dispatcherCandidate struct {
dispatcher *WebService
finalMatch string
matchesCount int // the number of capturing groups
literalCount int // the number of literal characters (i.e. those not resulting from template variable substitution)
nonDefaultCount int // the number of capturing groups with non-default regular expressions (i.e. not ([^ /]+?))
}
type sortableDispatcherCandidates struct {
candidates []dispatcherCandidate
}
func (dc *sortableDispatcherCandidates) Len() int {
return len(dc.candidates)
}
func (dc *sortableDispatcherCandidates) Swap(i, j int) {
dc.candidates[i], dc.candidates[j] = dc.candidates[j], dc.candidates[i]
}
func (dc *sortableDispatcherCandidates) Less(i, j int) bool {
ci := dc.candidates[i]
cj := dc.candidates[j]
// primary key
if ci.matchesCount < cj.matchesCount {
return true
}
if ci.matchesCount > cj.matchesCount {
return false
}
// secondary key
if ci.literalCount < cj.literalCount {
return true
}
if ci.literalCount > cj.literalCount {
return false
}
// tertiary key
return ci.nonDefaultCount < cj.nonDefaultCount
}

View File

@ -1,34 +0,0 @@
package log
import (
stdlog "log"
"os"
)
// StdLogger corresponds to a minimal subset of the interface satisfied by stdlib log.Logger
type StdLogger interface {
Print(v ...interface{})
Printf(format string, v ...interface{})
}
var Logger StdLogger
func init() {
// default Logger
SetLogger(stdlog.New(os.Stderr, "[restful] ", stdlog.LstdFlags|stdlog.Lshortfile))
}
// SetLogger sets the logger for this package
func SetLogger(customLogger StdLogger) {
Logger = customLogger
}
// Print delegates to the Logger
func Print(v ...interface{}) {
Logger.Print(v...)
}
// Printf delegates to the Logger
func Printf(format string, v ...interface{}) {
Logger.Printf(format, v...)
}

View File

@ -1,32 +0,0 @@
package restful
// Copyright 2014 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
import (
"github.com/emicklei/go-restful/log"
)
var trace bool = false
var traceLogger log.StdLogger
func init() {
traceLogger = log.Logger // use the package logger by default
}
// TraceLogger enables detailed logging of Http request matching and filter invocation. By default no logger is set.
// You may call EnableTracing() directly to enable trace logging to the package-wide logger.
func TraceLogger(logger log.StdLogger) {
traceLogger = logger
EnableTracing(logger != nil)
}
// SetLogger exposes the setter for the global logger on the top-level package
func SetLogger(customLogger log.StdLogger) {
log.SetLogger(customLogger)
}
// EnableTracing can be used to turn trace logging on and off.
func EnableTracing(enabled bool) {
trace = enabled
}
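Wiring both hooks up matches the snippet in the package documentation above; repeated here as a compact sketch:

```go
package main

import (
	"log"
	"os"

	restful "github.com/emicklei/go-restful"
)

func main() {
	logger := log.New(os.Stdout, "[restful] ", log.LstdFlags|log.Lshortfile)
	restful.SetLogger(logger)   // route all package logging through this logger
	restful.TraceLogger(logger) // additionally enable detailed trace logging
}
```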

View File

@ -1,45 +0,0 @@
package restful
import (
"strconv"
"strings"
)
type mime struct {
media string
quality float64
}
// insertMime adds a mime to a list and keeps it sorted by quality.
func insertMime(l []mime, e mime) []mime {
for i, each := range l {
// if the current mime has lower quality than the new one then insert the new one before it
if e.quality > each.quality {
left := append([]mime{}, l[0:i]...)
return append(append(left, e), l[i:]...)
}
}
return append(l, e)
}
// sortedMimes returns a list of mime sorted (desc) by its specified quality.
func sortedMimes(accept string) (sorted []mime) {
for _, each := range strings.Split(accept, ",") {
typeAndQuality := strings.Split(strings.Trim(each, " "), ";")
if len(typeAndQuality) == 1 {
sorted = insertMime(sorted, mime{typeAndQuality[0], 1.0})
} else {
// take factor
parts := strings.Split(typeAndQuality[1], "=")
if len(parts) == 2 {
f, err := strconv.ParseFloat(parts[1], 64)
if err != nil {
traceLogger.Printf("unable to parse quality in %s, %v", each, err)
} else {
sorted = insertMime(sorted, mime{typeAndQuality[0], f})
}
}
}
}
return
}
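Since sortedMimes is unexported, its ordering is easiest to pin down with an in-package test; a sketch (the Accept value is illustrative):

```go
package restful

import "testing"

func TestSortedMimesOrdersByQuality(t *testing.T) {
	// "application/xml" carries an implicit quality of 1.0 and must sort
	// before "application/json;q=0.2".
	sorted := sortedMimes("application/json;q=0.2, application/xml")
	if len(sorted) != 2 || sorted[0].media != "application/xml" {
		t.Fatalf("unexpected order: %v", sorted)
	}
}
```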

View File

@ -1,34 +0,0 @@
package restful
import "strings"
// Copyright 2013 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
// OPTIONSFilter is a filter function that inspects the Http Request for the OPTIONS method
// and provides the response with a set of allowed methods for the request URL Path.
// As for any filter, you can also install it for a particular WebService within a Container.
// Note: this filter is not needed when using CrossOriginResourceSharing (for CORS).
func (c *Container) OPTIONSFilter(req *Request, resp *Response, chain *FilterChain) {
if "OPTIONS" != req.Request.Method {
chain.ProcessFilter(req, resp)
return
}
archs := req.Request.Header.Get(HEADER_AccessControlRequestHeaders)
methods := strings.Join(c.computeAllowedMethods(req), ",")
origin := req.Request.Header.Get(HEADER_Origin)
resp.AddHeader(HEADER_Allow, methods)
resp.AddHeader(HEADER_AccessControlAllowOrigin, origin)
resp.AddHeader(HEADER_AccessControlAllowHeaders, archs)
resp.AddHeader(HEADER_AccessControlAllowMethods, methods)
}
// OPTIONSFilter is a filter function that inspects the Http Request for the OPTIONS method
// and provides the response with a set of allowed methods for the request URL Path.
// Note: this filter is not needed when using CrossOriginResourceSharing (for CORS).
func OPTIONSFilter() FilterFunction {
return DefaultContainer.OPTIONSFilter
}
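Installing the filter matches the one-liner in the package documentation above; as a sketch:

```go
package main

import restful "github.com/emicklei/go-restful"

func main() {
	// Answer OPTIONS requests for all routes on the default container.
	restful.Filter(restful.OPTIONSFilter())

	// On a container you own, install the method variant instead:
	//   c := restful.NewContainer()
	//   c.Filter(c.OPTIONSFilter)
}
```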

View File

@ -1,143 +0,0 @@
package restful
// Copyright 2013 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
const (
// PathParameterKind = indicator of Request parameter type "path"
PathParameterKind = iota
// QueryParameterKind = indicator of Request parameter type "query"
QueryParameterKind
// BodyParameterKind = indicator of Request parameter type "body"
BodyParameterKind
// HeaderParameterKind = indicator of Request parameter type "header"
HeaderParameterKind
// FormParameterKind = indicator of Request parameter type "form"
FormParameterKind
// CollectionFormatCSV comma separated values `foo,bar`
CollectionFormatCSV = CollectionFormat("csv")
// CollectionFormatSSV space separated values `foo bar`
CollectionFormatSSV = CollectionFormat("ssv")
// CollectionFormatTSV tab separated values `foo\tbar`
CollectionFormatTSV = CollectionFormat("tsv")
// CollectionFormatPipes pipe separated values `foo|bar`
CollectionFormatPipes = CollectionFormat("pipes")
// CollectionFormatMulti corresponds to multiple parameter instances instead of multiple values for a single
// instance `foo=bar&foo=baz`. This is valid only for QueryParameters and FormParameters
CollectionFormatMulti = CollectionFormat("multi")
)
type CollectionFormat string
func (cf CollectionFormat) String() string {
return string(cf)
}
// Parameter is for documenting the parameter used in a Http Request
// ParameterData kinds are Path,Query and Body
type Parameter struct {
data *ParameterData
}
// ParameterData represents the state of a Parameter.
// It is made public to make it accessible to e.g. the Swagger package.
type ParameterData struct {
Name, Description, DataType, DataFormat string
Kind int
Required bool
AllowableValues map[string]string
AllowMultiple bool
DefaultValue string
CollectionFormat string
}
// Data returns the state of the Parameter
func (p *Parameter) Data() ParameterData {
return *p.data
}
// Kind returns the parameter type indicator (see const for valid values)
func (p *Parameter) Kind() int {
return p.data.Kind
}
func (p *Parameter) bePath() *Parameter {
p.data.Kind = PathParameterKind
return p
}
func (p *Parameter) beQuery() *Parameter {
p.data.Kind = QueryParameterKind
return p
}
func (p *Parameter) beBody() *Parameter {
p.data.Kind = BodyParameterKind
return p
}
func (p *Parameter) beHeader() *Parameter {
p.data.Kind = HeaderParameterKind
return p
}
func (p *Parameter) beForm() *Parameter {
p.data.Kind = FormParameterKind
return p
}
// Required sets the required field and returns the receiver
func (p *Parameter) Required(required bool) *Parameter {
p.data.Required = required
return p
}
// AllowMultiple sets the allowMultiple field and returns the receiver
func (p *Parameter) AllowMultiple(multiple bool) *Parameter {
p.data.AllowMultiple = multiple
return p
}
// AllowableValues sets the allowableValues field and returns the receiver
func (p *Parameter) AllowableValues(values map[string]string) *Parameter {
p.data.AllowableValues = values
return p
}
// DataType sets the dataType field and returns the receiver
func (p *Parameter) DataType(typeName string) *Parameter {
p.data.DataType = typeName
return p
}
// DataFormat sets the dataFormat field for Swagger UI
func (p *Parameter) DataFormat(formatName string) *Parameter {
p.data.DataFormat = formatName
return p
}
// DefaultValue sets the default value field and returns the receiver
func (p *Parameter) DefaultValue(stringRepresentation string) *Parameter {
p.data.DefaultValue = stringRepresentation
return p
}
// Description sets the description value field and returns the receiver
func (p *Parameter) Description(doc string) *Parameter {
p.data.Description = doc
return p
}
// CollectionFormat sets the collection format for an array type
func (p *Parameter) CollectionFormat(format CollectionFormat) *Parameter {
p.data.CollectionFormat = format.String()
return p
}
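A sketch of the fluent builder in use; QueryParameter is the WebService constructor for a query-kind Parameter (defined elsewhere in this package), and the handler name is illustrative:

```go
package main

import restful "github.com/emicklei/go-restful"

func listItems(req *restful.Request, resp *restful.Response) {} // placeholder handler

func main() {
	ws := new(restful.WebService)
	count := ws.QueryParameter("count", "maximum number of results").
		DataType("integer").
		DefaultValue("10").
		Required(false).
		CollectionFormat(restful.CollectionFormatCSV)
	ws.Route(ws.GET("/items").Param(count).To(listItems))
	restful.Add(ws)
}
```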

View File

@ -1,74 +0,0 @@
package restful
// Copyright 2013 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
import (
"bytes"
"fmt"
"regexp"
"strings"
)
// PathExpression holds a compiled path expression (RegExp) needed to match against
// Http request paths and to extract path parameter values.
type pathExpression struct {
LiteralCount int // the number of literal characters (i.e. those not resulting from template variable substitution)
VarNames []string // the names of parameters (enclosed by {}) in the path
VarCount int // the number of named parameters (enclosed by {}) in the path
Matcher *regexp.Regexp
Source string // Path as defined by the RouteBuilder
tokens []string
}
// NewPathExpression creates a PathExpression from the input URL path.
// Returns an error if the path is invalid.
func newPathExpression(path string) (*pathExpression, error) {
expression, literalCount, varNames, varCount, tokens := templateToRegularExpression(path)
compiled, err := regexp.Compile(expression)
if err != nil {
return nil, err
}
return &pathExpression{literalCount, varNames, varCount, compiled, expression, tokens}, nil
}
// http://jsr311.java.net/nonav/releases/1.1/spec/spec3.html#x3-370003.7.3
func templateToRegularExpression(template string) (expression string, literalCount int, varNames []string, varCount int, tokens []string) {
var buffer bytes.Buffer
buffer.WriteString("^")
//tokens = strings.Split(template, "/")
tokens = tokenizePath(template)
for _, each := range tokens {
if each == "" {
continue
}
buffer.WriteString("/")
if strings.HasPrefix(each, "{") {
// check for regular expression in variable
colon := strings.Index(each, ":")
var varName string
if colon != -1 {
// extract expression
varName = strings.TrimSpace(each[1:colon])
paramExpr := strings.TrimSpace(each[colon+1 : len(each)-1])
if paramExpr == "*" { // special case
buffer.WriteString("(.*)")
} else {
buffer.WriteString(fmt.Sprintf("(%s)", paramExpr)) // between colon and closing moustache
}
} else {
// plain var
varName = strings.TrimSpace(each[1 : len(each)-1])
buffer.WriteString("([^/]+?)")
}
varNames = append(varNames, varName)
varCount += 1
} else {
literalCount += len(each)
encoded := each // TODO URI encode
buffer.WriteString(regexp.QuoteMeta(encoded))
}
}
return strings.TrimRight(buffer.String(), "/") + "(/.*)?$", literalCount, varNames, varCount, tokens
}
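Because newPathExpression and templateToRegularExpression are unexported, the sketch below is written as if it lived inside package restful (for example as a test). The expected values follow directly from the translation above: a plain variable becomes `([^/]+?)` and the literal segment `users` contributes 5 literal characters.

```go
package restful

import "testing"

// TestTemplateToRegularExpression illustrates the template translation above.
func TestTemplateToRegularExpression(t *testing.T) {
	expr, literalCount, varNames, varCount, _ := templateToRegularExpression("/users/{id}")
	if expr != "^/users/([^/]+?)(/.*)?$" {
		t.Errorf("unexpected expression: %s", expr)
	}
	if literalCount != 5 || varCount != 1 || varNames[0] != "id" {
		t.Errorf("unexpected counts: %d %d %v", literalCount, varCount, varNames)
	}
}
```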

View File

@ -1,63 +0,0 @@
package restful
import (
"bytes"
"strings"
)
// Copyright 2018 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
// PathProcessor is extra behaviour that a Router can provide to extract path parameters from the path.
// If a Router does not implement this interface then the default behaviour will be used.
type PathProcessor interface {
// ExtractParameters gets the path parameters defined in the route and webService from the urlPath
ExtractParameters(route *Route, webService *WebService, urlPath string) map[string]string
}
type defaultPathProcessor struct{}
// Extract the parameters from the request url path
func (d defaultPathProcessor) ExtractParameters(r *Route, _ *WebService, urlPath string) map[string]string {
urlParts := tokenizePath(urlPath)
pathParameters := map[string]string{}
for i, key := range r.pathParts {
var value string
if i >= len(urlParts) {
value = ""
} else {
value = urlParts[i]
}
if strings.HasPrefix(key, "{") { // path-parameter
if colon := strings.Index(key, ":"); colon != -1 {
// extract by regex
regPart := key[colon+1 : len(key)-1]
keyPart := key[1:colon]
if regPart == "*" {
pathParameters[keyPart] = untokenizePath(i, urlParts)
break
} else {
pathParameters[keyPart] = value
}
} else {
// without enclosing {}
pathParameters[key[1:len(key)-1]] = value
}
}
}
return pathParameters
}
// Untokenize back into a URL path using the slash separator
func untokenizePath(offset int, parts []string) string {
var buffer bytes.Buffer
for p := offset; p < len(parts); p++ {
buffer.WriteString(parts[p])
// do not end
if p < len(parts)-1 {
buffer.WriteString("/")
}
}
return buffer.String()
}
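defaultPathProcessor is unexported as well, so this sketch also assumes it runs inside package restful; the route path and request path are hypothetical.

```go
package restful

import "testing"

// TestExtractParameters shows the default extraction: the tokenized route
// "/users/{id}" is matched position by position against the request path.
func TestExtractParameters(t *testing.T) {
	route := Route{Path: "/users/{id}"}
	route.postBuild() // fills route.pathParts from route.Path

	params := defaultPathProcessor{}.ExtractParameters(&route, nil, "/users/42")
	if params["id"] != "42" {
		t.Errorf("expected id=42, got %v", params)
	}
}
```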

View File

@ -1,118 +0,0 @@
package restful
// Copyright 2013 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
import (
"compress/zlib"
"net/http"
)
var defaultRequestContentType string
// Request is a wrapper for a http Request that provides convenience methods
type Request struct {
Request *http.Request
pathParameters map[string]string
attributes map[string]interface{} // for storing request-scoped values
selectedRoutePath string // root path + route path that matched the request, e.g. /meetings/{id}/attendees
}
func NewRequest(httpRequest *http.Request) *Request {
return &Request{
Request: httpRequest,
pathParameters: map[string]string{},
attributes: map[string]interface{}{},
} // empty parameters, attributes
}
// If ContentType is missing or */* is given then fall back to this type, otherwise
// a "Unable to unmarshal content of type:" response is returned.
// Valid values are restful.MIME_JSON and restful.MIME_XML
// Example:
// restful.DefaultRequestContentType(restful.MIME_JSON)
func DefaultRequestContentType(mime string) {
defaultRequestContentType = mime
}
// PathParameter accesses the Path parameter value by its name
func (r *Request) PathParameter(name string) string {
return r.pathParameters[name]
}
// PathParameters accesses the Path parameter values
func (r *Request) PathParameters() map[string]string {
return r.pathParameters
}
// QueryParameter returns the (first) Query parameter value by its name
func (r *Request) QueryParameter(name string) string {
return r.Request.FormValue(name)
}
// QueryParameters returns all the query parameter values by name
func (r *Request) QueryParameters(name string) []string {
return r.Request.URL.Query()[name]
}
// BodyParameter parses the body of the request (typically once, for a POST or a PUT) and returns the value of the given name or an error.
func (r *Request) BodyParameter(name string) (string, error) {
err := r.Request.ParseForm()
if err != nil {
return "", err
}
return r.Request.PostFormValue(name), nil
}
// HeaderParameter returns the HTTP Header value of a Header name or empty if missing
func (r *Request) HeaderParameter(name string) string {
return r.Request.Header.Get(name)
}
// ReadEntity checks the Accept header and reads the content into the entityPointer.
func (r *Request) ReadEntity(entityPointer interface{}) (err error) {
contentType := r.Request.Header.Get(HEADER_ContentType)
contentEncoding := r.Request.Header.Get(HEADER_ContentEncoding)
// check if the request body needs decompression
if ENCODING_GZIP == contentEncoding {
gzipReader := currentCompressorProvider.AcquireGzipReader()
defer currentCompressorProvider.ReleaseGzipReader(gzipReader)
gzipReader.Reset(r.Request.Body)
r.Request.Body = gzipReader
} else if ENCODING_DEFLATE == contentEncoding {
zlibReader, err := zlib.NewReader(r.Request.Body)
if err != nil {
return err
}
r.Request.Body = zlibReader
}
// lookup the EntityReader, use defaultRequestContentType if needed and provided
entityReader, ok := entityAccessRegistry.accessorAt(contentType)
if !ok {
if len(defaultRequestContentType) != 0 {
entityReader, ok = entityAccessRegistry.accessorAt(defaultRequestContentType)
}
if !ok {
return NewError(http.StatusBadRequest, "Unable to unmarshal content of type:"+contentType)
}
}
return entityReader.Read(r, entityPointer)
}
// SetAttribute adds or replaces the attribute with the given value.
func (r *Request) SetAttribute(name string, value interface{}) {
r.attributes[name] = value
}
// Attribute returns the value associated to the given name. Returns nil if absent.
func (r Request) Attribute(name string) interface{} {
return r.attributes[name]
}
// SelectedRoutePath root path + route path that matched the request, e.g. /meetings/{id}/attendees
func (r Request) SelectedRoutePath() string {
return r.selectedRoutePath
}
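A hedged sketch of how these accessors are typically used inside a RouteFunction; the User type, package name and parameter names are made up for illustration.

```go
package handlers

import (
	"log"
	"net/http"

	restful "github.com/emicklei/go-restful"
)

// User is a hypothetical payload type used only for this sketch.
type User struct {
	Name string `json:"name"`
}

func createUser(req *restful.Request, resp *restful.Response) {
	// Path and query parameters come straight from the matched route and URL.
	group := req.PathParameter("group")      // e.g. from /groups/{group}/users
	verbose := req.QueryParameter("verbose") // e.g. from ?verbose=true
	log.Printf("creating user in group %q (verbose=%q)", group, verbose)

	// ReadEntity decodes the body according to its Content-Type
	// (and the Content-Encoding handling shown above).
	var u User
	if err := req.ReadEntity(&u); err != nil {
		resp.WriteError(http.StatusBadRequest, err)
		return
	}
	resp.WriteHeaderAndEntity(http.StatusCreated, u)
}
```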

View File

@ -1,250 +0,0 @@
package restful
// Copyright 2013 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
import (
"bufio"
"errors"
"net"
"net/http"
)
// DefaultResponseMimeType is DEPRECATED, use DefaultResponseContentType(mime)
var DefaultResponseMimeType string
//PrettyPrintResponses controls the indentation feature of XML and JSON serialization
var PrettyPrintResponses = true
// Response is a wrapper on the actual http ResponseWriter
// It provides several convenience methods to prepare and write response content.
type Response struct {
http.ResponseWriter
requestAccept string // mime-type what the Http Request says it wants to receive
routeProduces []string // mime-types what the Route says it can produce
statusCode int // HTTP status code that has been written explicitly (if zero then net/http has written 200)
contentLength int // number of bytes written for the response body
prettyPrint bool // controls the indentation feature of XML and JSON serialization. It is initialized using var PrettyPrintResponses.
err error // err property is kept when WriteError is called
hijacker http.Hijacker // if underlying ResponseWriter supports it
}
// NewResponse creates a new response based on a http ResponseWriter.
func NewResponse(httpWriter http.ResponseWriter) *Response {
hijacker, _ := httpWriter.(http.Hijacker)
return &Response{ResponseWriter: httpWriter, routeProduces: []string{}, statusCode: http.StatusOK, prettyPrint: PrettyPrintResponses, hijacker: hijacker}
}
// DefaultResponseContentType set a default.
// If Accept header matching fails, fall back to this type.
// Valid values are restful.MIME_JSON and restful.MIME_XML
// Example:
// restful.DefaultResponseContentType(restful.MIME_JSON)
func DefaultResponseContentType(mime string) {
DefaultResponseMimeType = mime
}
// InternalServerError writes the StatusInternalServerError header.
// DEPRECATED, use WriteErrorString(http.StatusInternalServerError,reason)
func (r Response) InternalServerError() Response {
r.WriteHeader(http.StatusInternalServerError)
return r
}
// Hijack implements the http.Hijacker interface. This expands
// the Response to fulfill http.Hijacker if the underlying
// http.ResponseWriter supports it.
func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if r.hijacker == nil {
return nil, nil, errors.New("http.Hijacker not implemented by underlying http.ResponseWriter")
}
return r.hijacker.Hijack()
}
// PrettyPrint changes whether this response must produce pretty (line-by-line, indented) JSON or XML output.
func (r *Response) PrettyPrint(bePretty bool) {
r.prettyPrint = bePretty
}
// AddHeader is a shortcut for .Header().Add(header,value)
func (r Response) AddHeader(header string, value string) Response {
r.Header().Add(header, value)
return r
}
// SetRequestAccepts tells the response what Mime-type(s) the HTTP request said it wants to accept. Exposed for testing.
func (r *Response) SetRequestAccepts(mime string) {
r.requestAccept = mime
}
// EntityWriter returns the registered EntityWriter that the entity (requested resource)
// can write according to what the request wants (Accept) and what the Route can produce or what the restful defaults say.
// If called before WriteEntity and WriteHeader then a false return value can be used to write a 406: Not Acceptable.
func (r *Response) EntityWriter() (EntityReaderWriter, bool) {
sorted := sortedMimes(r.requestAccept)
for _, eachAccept := range sorted {
for _, eachProduce := range r.routeProduces {
if eachProduce == eachAccept.media {
if w, ok := entityAccessRegistry.accessorAt(eachAccept.media); ok {
return w, true
}
}
}
if eachAccept.media == "*/*" {
for _, each := range r.routeProduces {
if w, ok := entityAccessRegistry.accessorAt(each); ok {
return w, true
}
}
}
}
// if requestAccept is empty
writer, ok := entityAccessRegistry.accessorAt(r.requestAccept)
if !ok {
// if not registered then fallback to the defaults (if set)
if DefaultResponseMimeType == MIME_JSON {
return entityAccessRegistry.accessorAt(MIME_JSON)
}
if DefaultResponseMimeType == MIME_XML {
return entityAccessRegistry.accessorAt(MIME_XML)
}
// Fallback to whatever the route says it can produce.
// https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html
for _, each := range r.routeProduces {
if w, ok := entityAccessRegistry.accessorAt(each); ok {
return w, true
}
}
if trace {
traceLogger.Printf("no registered EntityReaderWriter found for %s", r.requestAccept)
}
}
return writer, ok
}
// WriteEntity calls WriteHeaderAndEntity with Http Status OK (200)
func (r *Response) WriteEntity(value interface{}) error {
return r.WriteHeaderAndEntity(http.StatusOK, value)
}
// WriteHeaderAndEntity marshals the value using the representation denoted by the Accept Header and the registered EntityWriters.
// If no Accept header is specified (or */*) then respond with the Content-Type as specified by the first in the Route.Produces.
// If an Accept header is specified then respond with the Content-Type as specified by the first in the Route.Produces that is matched with the Accept header.
// If the value is nil then no response is sent except for the Http status. You may want to call WriteHeader(http.StatusNotFound) instead.
// If there is no writer available that can represent the value in the requested MIME type then Http Status NotAcceptable is written.
// Current implementation ignores any q-parameters in the Accept Header.
// Returns an error if the value could not be written on the response.
func (r *Response) WriteHeaderAndEntity(status int, value interface{}) error {
writer, ok := r.EntityWriter()
if !ok {
r.WriteHeader(http.StatusNotAcceptable)
return nil
}
return writer.Write(r, status, value)
}
// WriteAsXml is a convenience method for writing a value in xml (requires Xml tags on the value)
// It uses the standard encoding/xml package for marshalling the value ; not using a registered EntityReaderWriter.
func (r *Response) WriteAsXml(value interface{}) error {
return writeXML(r, http.StatusOK, MIME_XML, value)
}
// WriteHeaderAndXml is a convenience method for writing a status and value in xml (requires Xml tags on the value)
// It uses the standard encoding/xml package for marshalling the value ; not using a registered EntityReaderWriter.
func (r *Response) WriteHeaderAndXml(status int, value interface{}) error {
return writeXML(r, status, MIME_XML, value)
}
// WriteAsJson is a convenience method for writing a value in json.
// It uses the standard encoding/json package for marshalling the value ; not using a registered EntityReaderWriter.
func (r *Response) WriteAsJson(value interface{}) error {
return writeJSON(r, http.StatusOK, MIME_JSON, value)
}
// WriteJson is a convenience method for writing a value in Json with a given Content-Type.
// It uses the standard encoding/json package for marshalling the value ; not using a registered EntityReaderWriter.
func (r *Response) WriteJson(value interface{}, contentType string) error {
return writeJSON(r, http.StatusOK, contentType, value)
}
// WriteHeaderAndJson is a convenience method for writing the status and a value in Json with a given Content-Type.
// It uses the standard encoding/json package for marshalling the value ; not using a registered EntityReaderWriter.
func (r *Response) WriteHeaderAndJson(status int, value interface{}, contentType string) error {
return writeJSON(r, status, contentType, value)
}
// WriteError write the http status and the error string on the response.
func (r *Response) WriteError(httpStatus int, err error) error {
r.err = err
return r.WriteErrorString(httpStatus, err.Error())
}
// WriteServiceError is a convenience method for responding with a status and a ServiceError
func (r *Response) WriteServiceError(httpStatus int, err ServiceError) error {
r.err = err
return r.WriteHeaderAndEntity(httpStatus, err)
}
// WriteErrorString is a convenience method for an error status with the actual error
func (r *Response) WriteErrorString(httpStatus int, errorReason string) error {
if r.err == nil {
// if not called from WriteError
r.err = errors.New(errorReason)
}
r.WriteHeader(httpStatus)
if _, err := r.Write([]byte(errorReason)); err != nil {
return err
}
return nil
}
// Flush implements http.Flusher interface, which sends any buffered data to the client.
func (r *Response) Flush() {
if f, ok := r.ResponseWriter.(http.Flusher); ok {
f.Flush()
} else if trace {
traceLogger.Printf("ResponseWriter %v doesn't support Flush", r)
}
}
// WriteHeader is overridden to remember the Status Code that has been written.
// Changes to the Header of the response have no effect after this.
func (r *Response) WriteHeader(httpStatus int) {
r.statusCode = httpStatus
r.ResponseWriter.WriteHeader(httpStatus)
}
// StatusCode returns the code that has been written using WriteHeader.
func (r Response) StatusCode() int {
if 0 == r.statusCode {
// no status code has been written yet; assume OK
return http.StatusOK
}
return r.statusCode
}
// Write writes the data to the connection as part of an HTTP reply.
// Write is part of http.ResponseWriter interface.
func (r *Response) Write(bytes []byte) (int, error) {
written, err := r.ResponseWriter.Write(bytes)
r.contentLength += written
return written, err
}
// ContentLength returns the number of bytes written for the response content.
// Note that this value is only correct if all data is written through the Response using its Write* methods.
// Data written directly using the underlying http.ResponseWriter is not accounted for.
func (r Response) ContentLength() int {
return r.contentLength
}
// CloseNotify is part of http.CloseNotifier interface
func (r Response) CloseNotify() <-chan bool {
return r.ResponseWriter.(http.CloseNotifier).CloseNotify()
}
// Error returns the err created by WriteError
func (r Response) Error() error {
return r.err
}
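A minimal sketch of the two common write paths: a negotiated entity write and an error write. The Item type and package name are hypothetical.

```go
package handlers

import (
	"errors"
	"net/http"

	restful "github.com/emicklei/go-restful"
)

// Item is a hypothetical payload type used only for this sketch.
type Item struct {
	ID string `json:"id"`
}

func getItem(req *restful.Request, resp *restful.Response) {
	id := req.PathParameter("id")
	if id == "" {
		// WriteError records the error on the Response and writes the reason as body.
		resp.WriteError(http.StatusNotFound, errors.New("unknown item"))
		return
	}
	// WriteHeaderAndEntity picks an EntityWriter based on Accept and Route.Produces.
	resp.WriteHeaderAndEntity(http.StatusOK, Item{ID: id})
}
```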

View File

@ -1,149 +0,0 @@
package restful
// Copyright 2013 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
import (
"net/http"
"strings"
)
// RouteFunction declares the signature of a function that can be bound to a Route.
type RouteFunction func(*Request, *Response)
// RouteSelectionConditionFunction declares the signature of a function that
// can be used to add extra conditional logic when selecting whether the route
// matches the HTTP request.
type RouteSelectionConditionFunction func(httpRequest *http.Request) bool
// Route binds a HTTP Method,Path,Consumes combination to a RouteFunction.
type Route struct {
Method string
Produces []string
Consumes []string
Path string // webservice root path + described path
Function RouteFunction
Filters []FilterFunction
If []RouteSelectionConditionFunction
// cached values for dispatching
relativePath string
pathParts []string
pathExpr *pathExpression // cached compilation of relativePath as RegExp
// documentation
Doc string
Notes string
Operation string
ParameterDocs []*Parameter
ResponseErrors map[int]ResponseError
ReadSample, WriteSample interface{} // structs that model an example request or response payload
// Extra information used to store custom information about the route.
Metadata map[string]interface{}
// marks a route as deprecated
Deprecated bool
}
// Initialize for Route
func (r *Route) postBuild() {
r.pathParts = tokenizePath(r.Path)
}
// Create Request and Response from their http versions
func (r *Route) wrapRequestResponse(httpWriter http.ResponseWriter, httpRequest *http.Request, pathParams map[string]string) (*Request, *Response) {
wrappedRequest := NewRequest(httpRequest)
wrappedRequest.pathParameters = pathParams
wrappedRequest.selectedRoutePath = r.Path
wrappedResponse := NewResponse(httpWriter)
wrappedResponse.requestAccept = httpRequest.Header.Get(HEADER_Accept)
wrappedResponse.routeProduces = r.Produces
return wrappedRequest, wrappedResponse
}
// dispatchWithFilters call the function after passing through its own filters
func (r *Route) dispatchWithFilters(wrappedRequest *Request, wrappedResponse *Response) {
if len(r.Filters) > 0 {
chain := FilterChain{Filters: r.Filters, Target: r.Function}
chain.ProcessFilter(wrappedRequest, wrappedResponse)
} else {
// unfiltered
r.Function(wrappedRequest, wrappedResponse)
}
}
// Return whether the mimeType matches what this Route can produce.
func (r Route) matchesAccept(mimeTypesWithQuality string) bool {
parts := strings.Split(mimeTypesWithQuality, ",")
for _, each := range parts {
var withoutQuality string
if strings.Contains(each, ";") {
withoutQuality = strings.Split(each, ";")[0]
} else {
withoutQuality = each
}
// trim before compare
withoutQuality = strings.Trim(withoutQuality, " ")
if withoutQuality == "*/*" {
return true
}
for _, producibleType := range r.Produces {
if producibleType == "*/*" || producibleType == withoutQuality {
return true
}
}
}
return false
}
// Return whether this Route can consume content with a type specified by mimeTypes (can be empty).
func (r Route) matchesContentType(mimeTypes string) bool {
if len(r.Consumes) == 0 {
// did not specify what it can consume ; any media type (“*/*”) is assumed
return true
}
if len(mimeTypes) == 0 {
// idempotent methods with (most-likely or guaranteed) empty content match missing Content-Type
m := r.Method
if m == "GET" || m == "HEAD" || m == "OPTIONS" || m == "DELETE" || m == "TRACE" {
return true
}
// proceed with default
mimeTypes = MIME_OCTET
}
parts := strings.Split(mimeTypes, ",")
for _, each := range parts {
var contentType string
if strings.Contains(each, ";") {
contentType = strings.Split(each, ";")[0]
} else {
contentType = each
}
// trim before compare
contentType = strings.Trim(contentType, " ")
for _, consumeableType := range r.Consumes {
if consumeableType == "*/*" || consumeableType == contentType {
return true
}
}
}
return false
}
// Tokenize a URL path using the slash separator ; the result does not have empty tokens
func tokenizePath(path string) []string {
if "/" == path {
return []string{}
}
return strings.Split(strings.Trim(path, "/"), "/")
}
// for debugging
func (r Route) String() string {
return r.Method + " " + r.Path
}

View File

@ -1,321 +0,0 @@
package restful
// Copyright 2013 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
import (
"fmt"
"os"
"reflect"
"runtime"
"strings"
"sync/atomic"
"github.com/emicklei/go-restful/log"
)
// RouteBuilder is a helper to construct Routes.
type RouteBuilder struct {
rootPath string
currentPath string
produces []string
consumes []string
httpMethod string // required
function RouteFunction // required
filters []FilterFunction
conditions []RouteSelectionConditionFunction
typeNameHandleFunc TypeNameHandleFunction // required
// documentation
doc string
notes string
operation string
readSample, writeSample interface{}
parameters []*Parameter
errorMap map[int]ResponseError
metadata map[string]interface{}
deprecated bool
}
// Do evaluates each argument with the RouteBuilder itself.
// This allows you to follow DRY principles without breaking the fluent programming style.
// Example:
// ws.Route(ws.DELETE("/{name}").To(t.deletePerson).Do(Returns200, Returns500))
//
// func Returns500(b *RouteBuilder) {
// b.Returns(500, "Internal Server Error", restful.ServiceError{})
// }
func (b *RouteBuilder) Do(oneArgBlocks ...func(*RouteBuilder)) *RouteBuilder {
for _, each := range oneArgBlocks {
each(b)
}
return b
}
// To bind the route to a function.
// If this route is matched with the incoming Http Request then call this function with the *Request,*Response pair. Required.
func (b *RouteBuilder) To(function RouteFunction) *RouteBuilder {
b.function = function
return b
}
// Method specifies what HTTP method to match. Required.
func (b *RouteBuilder) Method(method string) *RouteBuilder {
b.httpMethod = method
return b
}
// Produces specifies what MIME types can be produced ; the matched one will appear in the Content-Type Http header.
func (b *RouteBuilder) Produces(mimeTypes ...string) *RouteBuilder {
b.produces = mimeTypes
return b
}
// Consumes specifies what MIME types can be consumed ; the Content-Type Http header must match one of these
func (b *RouteBuilder) Consumes(mimeTypes ...string) *RouteBuilder {
b.consumes = mimeTypes
return b
}
// Path specifies the relative (w.r.t WebService root path) URL path to match. Default is "/".
func (b *RouteBuilder) Path(subPath string) *RouteBuilder {
b.currentPath = subPath
return b
}
// Doc tells what this route is all about. Optional.
func (b *RouteBuilder) Doc(documentation string) *RouteBuilder {
b.doc = documentation
return b
}
// Notes is a verbose explanation of the operation behavior. Optional.
func (b *RouteBuilder) Notes(notes string) *RouteBuilder {
b.notes = notes
return b
}
// Reads tells what resource type will be read from the request payload. Optional.
// A parameter of type "body" is added, required is set to true and the dataType is set to the qualified name of the sample's type.
func (b *RouteBuilder) Reads(sample interface{}, optionalDescription ...string) *RouteBuilder {
fn := b.typeNameHandleFunc
if fn == nil {
fn = reflectTypeName
}
typeAsName := fn(sample)
description := ""
if len(optionalDescription) > 0 {
description = optionalDescription[0]
}
b.readSample = sample
bodyParameter := &Parameter{&ParameterData{Name: "body", Description: description}}
bodyParameter.beBody()
bodyParameter.Required(true)
bodyParameter.DataType(typeAsName)
b.Param(bodyParameter)
return b
}
// ParameterNamed returns a Parameter already known to the RouteBuilder. Returns nil if not found.
// Use this to modify or extend information for the Parameter (through its Data()).
func (b RouteBuilder) ParameterNamed(name string) (p *Parameter) {
for _, each := range b.parameters {
if each.Data().Name == name {
return each
}
}
return p
}
// Writes tells what resource type will be written as the response payload. Optional.
func (b *RouteBuilder) Writes(sample interface{}) *RouteBuilder {
b.writeSample = sample
return b
}
// Param allows you to document the parameters of the Route. It adds a new Parameter (does not check for duplicates).
func (b *RouteBuilder) Param(parameter *Parameter) *RouteBuilder {
if b.parameters == nil {
b.parameters = []*Parameter{}
}
b.parameters = append(b.parameters, parameter)
return b
}
// Operation allows you to document what the actual method/function call is of the Route.
// Unless called, the operation name is derived from the RouteFunction set using To(..).
func (b *RouteBuilder) Operation(name string) *RouteBuilder {
b.operation = name
return b
}
// ReturnsError is deprecated, use Returns instead.
func (b *RouteBuilder) ReturnsError(code int, message string, model interface{}) *RouteBuilder {
log.Print("ReturnsError is deprecated, use Returns instead.")
return b.Returns(code, message, model)
}
// Returns allows you to document what responses (errors or regular) can be expected.
// The model parameter is optional ; either pass a struct instance or use nil if not applicable.
func (b *RouteBuilder) Returns(code int, message string, model interface{}) *RouteBuilder {
err := ResponseError{
Code: code,
Message: message,
Model: model,
IsDefault: false,
}
// lazy init because there is no NewRouteBuilder (yet)
if b.errorMap == nil {
b.errorMap = map[int]ResponseError{}
}
b.errorMap[code] = err
return b
}
// DefaultReturns is a special Returns call that sets the default of the response ; the code is zero.
func (b *RouteBuilder) DefaultReturns(message string, model interface{}) *RouteBuilder {
b.Returns(0, message, model)
// Modify the ResponseError just added/updated
re := b.errorMap[0]
// errorMap is initialized
b.errorMap[0] = ResponseError{
Code: re.Code,
Message: re.Message,
Model: re.Model,
IsDefault: true,
}
return b
}
// Metadata adds or updates a key=value pair to the metadata map.
func (b *RouteBuilder) Metadata(key string, value interface{}) *RouteBuilder {
if b.metadata == nil {
b.metadata = map[string]interface{}{}
}
b.metadata[key] = value
return b
}
// Deprecate sets the value of deprecated to true. Deprecated routes have a special UI treatment to warn against use
func (b *RouteBuilder) Deprecate() *RouteBuilder {
b.deprecated = true
return b
}
// ResponseError represents a response; not necessarily an error.
type ResponseError struct {
Code int
Message string
Model interface{}
IsDefault bool
}
func (b *RouteBuilder) servicePath(path string) *RouteBuilder {
b.rootPath = path
return b
}
// Filter appends a FilterFunction to the end of filters for this Route to build.
func (b *RouteBuilder) Filter(filter FilterFunction) *RouteBuilder {
b.filters = append(b.filters, filter)
return b
}
// If sets a condition function that controls matching the Route based on custom logic.
// The condition function is provided the HTTP request and should return true if the route
// should be considered.
//
// Efficiency note: the condition function is called before checking the method, produces, and
// consumes criteria, so that the correct HTTP status code can be returned.
//
// Lifecycle note: no filter functions have been called prior to calling the condition function,
// so the condition function should not depend on any context that might be set up by container
// or route filters.
func (b *RouteBuilder) If(condition RouteSelectionConditionFunction) *RouteBuilder {
b.conditions = append(b.conditions, condition)
return b
}
// If no specific Produces then set to rootProduces
// If no specific Consumes then set to rootConsumes
func (b *RouteBuilder) copyDefaults(rootProduces, rootConsumes []string) {
if len(b.produces) == 0 {
b.produces = rootProduces
}
if len(b.consumes) == 0 {
b.consumes = rootConsumes
}
}
// typeNameHandler sets the function that will convert types to strings in the parameter
// and model definitions.
func (b *RouteBuilder) typeNameHandler(handler TypeNameHandleFunction) *RouteBuilder {
b.typeNameHandleFunc = handler
return b
}
// Build creates a new Route using the specification details collected by the RouteBuilder
func (b *RouteBuilder) Build() Route {
pathExpr, err := newPathExpression(b.currentPath)
if err != nil {
log.Printf("Invalid path:%s because:%v", b.currentPath, err)
os.Exit(1)
}
if b.function == nil {
log.Printf("No function specified for route:" + b.currentPath)
os.Exit(1)
}
operationName := b.operation
if len(operationName) == 0 && b.function != nil {
// extract from definition
operationName = nameOfFunction(b.function)
}
route := Route{
Method: b.httpMethod,
Path: concatPath(b.rootPath, b.currentPath),
Produces: b.produces,
Consumes: b.consumes,
Function: b.function,
Filters: b.filters,
If: b.conditions,
relativePath: b.currentPath,
pathExpr: pathExpr,
Doc: b.doc,
Notes: b.notes,
Operation: operationName,
ParameterDocs: b.parameters,
ResponseErrors: b.errorMap,
ReadSample: b.readSample,
WriteSample: b.writeSample,
Metadata: b.metadata,
Deprecated: b.deprecated}
route.postBuild()
return route
}
func concatPath(path1, path2 string) string {
return strings.TrimRight(path1, "/") + "/" + strings.TrimLeft(path2, "/")
}
var anonymousFuncCount int32
// nameOfFunction returns the short name of the function f for documentation.
// It uses a runtime feature for debugging ; its value may change for later Go versions.
func nameOfFunction(f interface{}) string {
fun := runtime.FuncForPC(reflect.ValueOf(f).Pointer())
tokenized := strings.Split(fun.Name(), ".")
last := tokenized[len(tokenized)-1]
last = strings.TrimSuffix(last, ")·fm") // < Go 1.5
last = strings.TrimSuffix(last, ")-fm") // Go 1.5
last = strings.TrimSuffix(last, "·fm") // < Go 1.5
last = strings.TrimSuffix(last, "-fm") // Go 1.5
if last == "func1" { // this could mean conflicts in API docs
val := atomic.AddInt32(&anonymousFuncCount, 1)
last = "func" + fmt.Sprintf("%d", val)
atomic.StoreInt32(&anonymousFuncCount, val)
}
return last
}
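A hedged sketch of the builder in use, including a reusable Do(...) block as described above; the paths, names and handler are invented for illustration.

```go
package main

import (
	"net/http"

	restful "github.com/emicklei/go-restful"
)

// returns500 is a reusable documentation block for Do(...).
func returns500(b *restful.RouteBuilder) {
	b.Returns(http.StatusInternalServerError, "internal error", nil)
}

func readUser(req *restful.Request, resp *restful.Response) {}

func main() {
	ws := new(restful.WebService).Path("/users").Produces(restful.MIME_JSON)

	// The builder collects method, path, documentation and parameters;
	// ws.Route(...) calls Build() on it and appends the resulting Route.
	ws.Route(ws.GET("/{id}").To(readUser).
		Doc("read a user").
		Operation("readUser").
		Param(ws.PathParameter("id", "identifier of the user").DataType("string")).
		Returns(http.StatusOK, "OK", nil).
		Do(returns500))

	restful.Add(ws)
}
```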

View File

@ -1,20 +0,0 @@
package restful
// Copyright 2013 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
import "net/http"
// A RouteSelector finds the best matching Route given the input HTTP Request.
// RouteSelectors can optionally also implement the PathProcessor interface to calculate the
// path parameters after the route has been selected.
type RouteSelector interface {
// SelectRoute finds a Route given the input HTTP Request and a list of WebServices.
// It returns a selected Route and its containing WebService or an error indicating
// a problem.
SelectRoute(
webServices []*WebService,
httpRequest *http.Request) (selectedService *WebService, selected *Route, err error)
}

View File

@ -1,23 +0,0 @@
package restful
// Copyright 2013 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
import "fmt"
// ServiceError is a transport object to pass information about a non-Http error that occurred in a WebService while processing a request.
type ServiceError struct {
Code int
Message string
}
// NewError returns a ServiceError using the code and reason
func NewError(code int, message string) ServiceError {
return ServiceError{Code: code, Message: message}
}
// Error returns a text representation of the service error
func (s ServiceError) Error() string {
return fmt.Sprintf("[ServiceError:%v] %v", s.Code, s.Message)
}
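A small sketch of how a ServiceError is typically returned from a handler (the handler and package name are hypothetical):

```go
package handlers

import (
	"net/http"

	restful "github.com/emicklei/go-restful"
)

func notImplemented(req *restful.Request, resp *restful.Response) {
	svcErr := restful.NewError(http.StatusNotImplemented, "not implemented yet")
	// WriteServiceError records the error on the Response and writes the
	// ServiceError itself as the (content-negotiated) response entity.
	resp.WriteServiceError(http.StatusNotImplemented, svcErr)
}
```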

View File

@ -1,290 +0,0 @@
package restful
import (
"errors"
"os"
"reflect"
"sync"
"github.com/emicklei/go-restful/log"
)
// Copyright 2013 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
// WebService holds a collection of Route values that bind a Http Method + URL Path to a function.
type WebService struct {
rootPath string
pathExpr *pathExpression // cached compilation of rootPath as RegExp
routes []Route
produces []string
consumes []string
pathParameters []*Parameter
filters []FilterFunction
documentation string
apiVersion string
typeNameHandleFunc TypeNameHandleFunction
dynamicRoutes bool
// protects 'routes' if dynamic routes are enabled
routesLock sync.RWMutex
}
func (w *WebService) SetDynamicRoutes(enable bool) {
w.dynamicRoutes = enable
}
// TypeNameHandleFunction declares functions that can handle translating the name of a sample object
// into the restful documentation for the service.
type TypeNameHandleFunction func(sample interface{}) string
// TypeNameHandler sets the function that will convert types to strings in the parameter
// and model definitions. If not set, the web service will invoke
// reflect.TypeOf(object).String().
func (w *WebService) TypeNameHandler(handler TypeNameHandleFunction) *WebService {
w.typeNameHandleFunc = handler
return w
}
// reflectTypeName is the default TypeNameHandleFunction and for a given object
// returns the name that Go identifies it with (e.g. "string" or "v1.Object") via
// the reflection API.
func reflectTypeName(sample interface{}) string {
return reflect.TypeOf(sample).String()
}
// compilePathExpression ensures that the path is compiled into a RegEx for those routers that need it.
func (w *WebService) compilePathExpression() {
compiled, err := newPathExpression(w.rootPath)
if err != nil {
log.Printf("invalid path:%s because:%v", w.rootPath, err)
os.Exit(1)
}
w.pathExpr = compiled
}
// ApiVersion sets the API version for documentation purposes.
func (w *WebService) ApiVersion(apiVersion string) *WebService {
w.apiVersion = apiVersion
return w
}
// Version returns the API version for documentation purposes.
func (w *WebService) Version() string { return w.apiVersion }
// Path specifies the root URL template path of the WebService.
// All Routes will be relative to this path.
func (w *WebService) Path(root string) *WebService {
w.rootPath = root
if len(w.rootPath) == 0 {
w.rootPath = "/"
}
w.compilePathExpression()
return w
}
// Param adds a PathParameter to document parameters used in the root path.
func (w *WebService) Param(parameter *Parameter) *WebService {
if w.pathParameters == nil {
w.pathParameters = []*Parameter{}
}
w.pathParameters = append(w.pathParameters, parameter)
return w
}
// PathParameter creates a new Parameter of kind Path for documentation purposes.
// It is initialized as required with string as its DataType.
func (w *WebService) PathParameter(name, description string) *Parameter {
return PathParameter(name, description)
}
// PathParameter creates a new Parameter of kind Path for documentation purposes.
// It is initialized as required with string as its DataType.
func PathParameter(name, description string) *Parameter {
p := &Parameter{&ParameterData{Name: name, Description: description, Required: true, DataType: "string"}}
p.bePath()
return p
}
// QueryParameter creates a new Parameter of kind Query for documentation purposes.
// It is initialized as not required with string as its DataType.
func (w *WebService) QueryParameter(name, description string) *Parameter {
return QueryParameter(name, description)
}
// QueryParameter creates a new Parameter of kind Query for documentation purposes.
// It is initialized as not required with string as its DataType.
func QueryParameter(name, description string) *Parameter {
p := &Parameter{&ParameterData{Name: name, Description: description, Required: false, DataType: "string", CollectionFormat: CollectionFormatCSV.String()}}
p.beQuery()
return p
}
// BodyParameter creates a new Parameter of kind Body for documentation purposes.
// It is initialized as required without a DataType.
func (w *WebService) BodyParameter(name, description string) *Parameter {
return BodyParameter(name, description)
}
// BodyParameter creates a new Parameter of kind Body for documentation purposes.
// It is initialized as required without a DataType.
func BodyParameter(name, description string) *Parameter {
p := &Parameter{&ParameterData{Name: name, Description: description, Required: true}}
p.beBody()
return p
}
// HeaderParameter creates a new Parameter of kind (Http) Header for documentation purposes.
// It is initialized as not required with string as its DataType.
func (w *WebService) HeaderParameter(name, description string) *Parameter {
return HeaderParameter(name, description)
}
// HeaderParameter creates a new Parameter of kind (Http) Header for documentation purposes.
// It is initialized as not required with string as its DataType.
func HeaderParameter(name, description string) *Parameter {
p := &Parameter{&ParameterData{Name: name, Description: description, Required: false, DataType: "string"}}
p.beHeader()
return p
}
// FormParameter creates a new Parameter of kind Form (using application/x-www-form-urlencoded) for documentation purposes.
// It is initialized as required with string as its DataType.
func (w *WebService) FormParameter(name, description string) *Parameter {
return FormParameter(name, description)
}
// FormParameter creates a new Parameter of kind Form (using application/x-www-form-urlencoded) for documentation purposes.
// It is initialized as required with string as its DataType.
func FormParameter(name, description string) *Parameter {
p := &Parameter{&ParameterData{Name: name, Description: description, Required: false, DataType: "string"}}
p.beForm()
return p
}
// Route creates a new Route using the RouteBuilder and adds it to the ordered list of Routes.
func (w *WebService) Route(builder *RouteBuilder) *WebService {
w.routesLock.Lock()
defer w.routesLock.Unlock()
builder.copyDefaults(w.produces, w.consumes)
w.routes = append(w.routes, builder.Build())
return w
}
// RemoveRoute removes the route that matches the given 'path' and 'method'.
func (w *WebService) RemoveRoute(path, method string) error {
if !w.dynamicRoutes {
return errors.New("dynamic routes are not enabled.")
}
w.routesLock.Lock()
defer w.routesLock.Unlock()
newRoutes := []Route{}
for ix := range w.routes {
if w.routes[ix].Method == method && w.routes[ix].Path == path {
continue
}
newRoutes = append(newRoutes, w.routes[ix])
}
w.routes = newRoutes
return nil
}
// Method creates a new RouteBuilder and initializes its http method
func (w *WebService) Method(httpMethod string) *RouteBuilder {
return new(RouteBuilder).typeNameHandler(w.typeNameHandleFunc).servicePath(w.rootPath).Method(httpMethod)
}
// Produces specifies that this WebService can produce one or more MIME types.
// Http requests must have one of these values set for the Accept header.
func (w *WebService) Produces(contentTypes ...string) *WebService {
w.produces = contentTypes
return w
}
// Consumes specifies that this WebService can consume one or more MIME types.
// Http requests must have one of these values set for the Content-Type header.
func (w *WebService) Consumes(accepts ...string) *WebService {
w.consumes = accepts
return w
}
// Routes returns the Routes associated with this WebService
func (w *WebService) Routes() []Route {
if !w.dynamicRoutes {
return w.routes
}
// Make a copy of the array to prevent concurrency problems
w.routesLock.RLock()
defer w.routesLock.RUnlock()
result := make([]Route, len(w.routes))
for ix := range w.routes {
result[ix] = w.routes[ix]
}
return result
}
// RootPath returns the RootPath associated with this WebService. Default "/"
func (w *WebService) RootPath() string {
return w.rootPath
}
// PathParameters returns the path parameters shared among all Routes of this WebService
func (w *WebService) PathParameters() []*Parameter {
return w.pathParameters
}
// Filter adds a filter function to the chain of filters applicable to all its Routes
func (w *WebService) Filter(filter FilterFunction) *WebService {
w.filters = append(w.filters, filter)
return w
}
// Doc is used to set the documentation of this service.
func (w *WebService) Doc(plainText string) *WebService {
w.documentation = plainText
return w
}
// Documentation returns the documentation of this WebService.
func (w *WebService) Documentation() string {
return w.documentation
}
/*
Convenience methods
*/
// HEAD is a shortcut for .Method("HEAD").Path(subPath)
func (w *WebService) HEAD(subPath string) *RouteBuilder {
return new(RouteBuilder).typeNameHandler(w.typeNameHandleFunc).servicePath(w.rootPath).Method("HEAD").Path(subPath)
}
// GET is a shortcut for .Method("GET").Path(subPath)
func (w *WebService) GET(subPath string) *RouteBuilder {
return new(RouteBuilder).typeNameHandler(w.typeNameHandleFunc).servicePath(w.rootPath).Method("GET").Path(subPath)
}
// POST is a shortcut for .Method("POST").Path(subPath)
func (w *WebService) POST(subPath string) *RouteBuilder {
return new(RouteBuilder).typeNameHandler(w.typeNameHandleFunc).servicePath(w.rootPath).Method("POST").Path(subPath)
}
// PUT is a shortcut for .Method("PUT").Path(subPath)
func (w *WebService) PUT(subPath string) *RouteBuilder {
return new(RouteBuilder).typeNameHandler(w.typeNameHandleFunc).servicePath(w.rootPath).Method("PUT").Path(subPath)
}
// PATCH is a shortcut for .Method("PATCH").Path(subPath)
func (w *WebService) PATCH(subPath string) *RouteBuilder {
return new(RouteBuilder).typeNameHandler(w.typeNameHandleFunc).servicePath(w.rootPath).Method("PATCH").Path(subPath)
}
// DELETE is a shortcut for .Method("DELETE").Path(subPath)
func (w *WebService) DELETE(subPath string) *RouteBuilder {
return new(RouteBuilder).typeNameHandler(w.typeNameHandleFunc).servicePath(w.rootPath).Method("DELETE").Path(subPath)
}
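A minimal sketch of assembling a WebService with the API above; the resource type and paths are made up for illustration.

```go
package main

import (
	"fmt"

	restful "github.com/emicklei/go-restful"
)

// Book is a hypothetical resource type for this sketch.
type Book struct {
	Title string `json:"title"`
}

func findBook(req *restful.Request, resp *restful.Response) {
	resp.WriteEntity(Book{Title: req.PathParameter("title")})
}

func main() {
	// Root path and default MIME types are declared once on the WebService;
	// every Route inherits them via copyDefaults.
	ws := new(restful.WebService)
	ws.Path("/books").
		Consumes(restful.MIME_JSON).
		Produces(restful.MIME_JSON)

	ws.Route(ws.GET("/{title}").To(findBook).
		Param(ws.PathParameter("title", "title of the book")))

	for _, r := range ws.Routes() {
		fmt.Println(r.String()) // e.g. "GET /books/{title}"
	}
}
```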

View File

@ -1,39 +0,0 @@
package restful
// Copyright 2013 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
import (
"net/http"
)
// DefaultContainer is a restful.Container that uses http.DefaultServeMux
var DefaultContainer *Container
func init() {
DefaultContainer = NewContainer()
DefaultContainer.ServeMux = http.DefaultServeMux
}
// If set to true then panics will not be caught to return HTTP 500.
// In that case, Route functions are responsible for handling any error situation.
// Default value is false = recover from panics. This has performance implications.
// OBSOLETE ; use restful.DefaultContainer.DoNotRecover(true)
var DoNotRecover = false
// Add registers a new WebService add it to the DefaultContainer.
func Add(service *WebService) {
DefaultContainer.Add(service)
}
// Filter appends a container FilterFunction to the DefaultContainer.
// These are called before dispatching a http.Request to a WebService.
func Filter(filter FilterFunction) {
DefaultContainer.Filter(filter)
}
// RegisteredWebServices returns the collection of WebServices from the DefaultContainer
func RegisteredWebServices() []*WebService {
return DefaultContainer.RegisteredWebServices()
}
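A hedged sketch of serving through the DefaultContainer (which is backed by http.DefaultServeMux); the handler, filter and port are invented.

```go
package main

import (
	"log"
	"net/http"

	restful "github.com/emicklei/go-restful"
)

func hello(req *restful.Request, resp *restful.Response) {
	resp.Write([]byte("hello"))
}

// logFilter is a container-level filter; it runs before any WebService route.
func logFilter(req *restful.Request, resp *restful.Response, chain *restful.FilterChain) {
	log.Printf("%s %s", req.Request.Method, req.Request.URL)
	chain.ProcessFilter(req, resp)
}

func main() {
	ws := new(restful.WebService)
	ws.Route(ws.GET("/hello").To(hello))

	restful.Filter(logFilter)
	restful.Add(ws) // registers on the DefaultContainer
	log.Fatal(http.ListenAndServe(":8080", nil))
}
```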

View File

@ -1,41 +0,0 @@
Copyright (c) 2015, Emir Pasic
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-------------------------------------------------------------------------------
AVL Tree:
Copyright (c) 2017 Benjamin Scher Purcell <benjapurcell@gmail.com>
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

View File

@ -1,35 +0,0 @@
// Copyright (c) 2015, Emir Pasic. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package containers provides core interfaces and functions for data structures.
//
// Container is the base interface for all data structures to implement.
//
// Iterators provide stateful iterators.
//
// Enumerable provides Ruby inspired (each, select, map, find, any?, etc.) container functions.
//
// Serialization provides serializers (marshalers) and deserializers (unmarshalers).
package containers
import "github.com/emirpasic/gods/utils"
// Container is base interface that all data structures implement.
type Container interface {
Empty() bool
Size() int
Clear()
Values() []interface{}
}
// GetSortedValues returns the container's elements sorted with respect to the passed comparator.
// Does not affect the ordering of elements within the container.
func GetSortedValues(container Container, comparator utils.Comparator) []interface{} {
values := container.Values()
if len(values) < 2 {
return values
}
utils.Sort(values, comparator)
return values
}
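A minimal sketch of GetSortedValues using an ArrayList (which implements Container) and a comparator from gods/utils:

```go
package main

import (
	"fmt"

	"github.com/emirpasic/gods/containers"
	"github.com/emirpasic/gods/lists/arraylist"
	"github.com/emirpasic/gods/utils"
)

func main() {
	list := arraylist.New()
	list.Add(3, 1, 2)

	// GetSortedValues sorts a copy of the values; the list keeps its insertion order.
	sorted := containers.GetSortedValues(list, utils.IntComparator)
	fmt.Println(sorted)        // [1 2 3]
	fmt.Println(list.Values()) // [3 1 2]
}
```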

View File

@ -1,61 +0,0 @@
// Copyright (c) 2015, Emir Pasic. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package containers
// EnumerableWithIndex provides functions for ordered containers whose values can be fetched by an index.
type EnumerableWithIndex interface {
// Each calls the given function once for each element, passing that element's index and value.
Each(func(index int, value interface{}))
// Map invokes the given function once for each element and returns a
// container containing the values returned by the given function.
// TODO need help on how to enforce this in containers (don't want to type assert when chaining)
// Map(func(index int, value interface{}) interface{}) Container
// Select returns a new container containing all elements for which the given function returns a true value.
// TODO need help on how to enforce this in containers (don't want to type assert when chaining)
// Select(func(index int, value interface{}) bool) Container
// Any passes each element of the container to the given function and
// returns true if the function ever returns true for any element.
Any(func(index int, value interface{}) bool) bool
// All passes each element of the container to the given function and
// returns true if the function returns true for all elements.
All(func(index int, value interface{}) bool) bool
// Find passes each element of the container to the given function and returns
// the first (index,value) for which the function is true or -1,nil otherwise
// if no element matches the criteria.
Find(func(index int, value interface{}) bool) (int, interface{})
}
// EnumerableWithKey provides functions for ordered containers whose elements are key/value pairs.
type EnumerableWithKey interface {
// Each calls the given function once for each element, passing that element's key and value.
Each(func(key interface{}, value interface{}))
// Map invokes the given function once for each element and returns a container
// containing the values returned by the given function as key/value pairs.
// TODO need help on how to enforce this in containers (don't want to type assert when chaining)
// Map(func(key interface{}, value interface{}) (interface{}, interface{})) Container
// Select returns a new container containing all elements for which the given function returns a true value.
// TODO need help on how to enforce this in containers (don't want to type assert when chaining)
// Select(func(key interface{}, value interface{}) bool) Container
// Any passes each element of the container to the given function and
// returns true if the function ever returns true for any element.
Any(func(key interface{}, value interface{}) bool) bool
// All passes each element of the container to the given function and
// returns true if the function returns true for all elements.
All(func(key interface{}, value interface{}) bool) bool
// Find passes each element of the container to the given function and returns
// the first (key,value) for which the function is true or nil,nil otherwise if no element
// matches the criteria.
Find(func(key interface{}, value interface{}) bool) (interface{}, interface{})
}
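A short sketch of the EnumerableWithIndex functions, assuming the ArrayList implementation (its enumerable methods are defined in the next file):

```go
package main

import (
	"fmt"

	"github.com/emirpasic/gods/lists/arraylist"
)

func main() {
	list := arraylist.New()
	list.Add("a", "bb", "ccc")

	// Each visits every (index, value) pair in order.
	list.Each(func(index int, value interface{}) {
		fmt.Println(index, value)
	})

	// Any returns true as soon as the predicate holds for one element.
	anyLong := list.Any(func(index int, value interface{}) bool {
		return len(value.(string)) > 2
	})

	// Find returns the first matching (index, value) pair, or (-1, nil).
	index, value := list.Find(func(index int, value interface{}) bool {
		return value == "bb"
	})
	fmt.Println(anyLong, index, value) // true 1 bb
}
```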

View File

@ -1,109 +0,0 @@
// Copyright (c) 2015, Emir Pasic. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package containers
// IteratorWithIndex is a stateful iterator for ordered containers whose values can be fetched by an index.
type IteratorWithIndex interface {
// Next moves the iterator to the next element and returns true if there was a next element in the container.
// If Next() returns true, then next element's index and value can be retrieved by Index() and Value().
// If Next() was called for the first time, then it will point the iterator to the first element if it exists.
// Modifies the state of the iterator.
Next() bool
// Value returns the current element's value.
// Does not modify the state of the iterator.
Value() interface{}
// Index returns the current element's index.
// Does not modify the state of the iterator.
Index() int
// Begin resets the iterator to its initial state (one-before-first)
// Call Next() to fetch the first element if any.
Begin()
// First moves the iterator to the first element and returns true if there was a first element in the container.
// If First() returns true, then first element's index and value can be retrieved by Index() and Value().
// Modifies the state of the iterator.
First() bool
}
// IteratorWithKey is a stateful iterator for ordered containers whose elements are key value pairs.
type IteratorWithKey interface {
// Next moves the iterator to the next element and returns true if there was a next element in the container.
// If Next() returns true, then next element's key and value can be retrieved by Key() and Value().
// If Next() was called for the first time, then it will point the iterator to the first element if it exists.
// Modifies the state of the iterator.
Next() bool
// Value returns the current element's value.
// Does not modify the state of the iterator.
Value() interface{}
// Key returns the current element's key.
// Does not modify the state of the iterator.
Key() interface{}
// Begin resets the iterator to its initial state (one-before-first)
// Call Next() to fetch the first element if any.
Begin()
// First moves the iterator to the first element and returns true if there was a first element in the container.
// If First() returns true, then first element's key and value can be retrieved by Key() and Value().
// Modifies the state of the iterator.
First() bool
}
// ReverseIteratorWithIndex is a stateful iterator for ordered containers whose values can be fetched by an index.
//
// Essentially it is the same as IteratorWithIndex, but provides additional:
//
// Prev() function to enable traversal in reverse
//
// Last() function to move the iterator to the last element.
//
// End() function to move the iterator past the last element (one-past-the-end).
type ReverseIteratorWithIndex interface {
// Prev moves the iterator to the previous element and returns true if there was a previous element in the container.
// If Prev() returns true, then previous element's index and value can be retrieved by Index() and Value().
// Modifies the state of the iterator.
Prev() bool
// End moves the iterator past the last element (one-past-the-end).
// Call Prev() to fetch the last element if any.
End()
// Last moves the iterator to the last element and returns true if there was a last element in the container.
// If Last() returns true, then last element's index and value can be retrieved by Index() and Value().
// Modifies the state of the iterator.
Last() bool
IteratorWithIndex
}
// ReverseIteratorWithKey is a stateful iterator for ordered containers whose elements are key value pairs.
//
// Essentially it is the same as IteratorWithKey, but provides additional:
//
// Prev() function to enable traversal in reverse
//
// Last() function to move the iterator to the last element.
type ReverseIteratorWithKey interface {
// Prev moves the iterator to the previous element and returns true if there was a previous element in the container.
// If Prev() returns true, then previous element's key and value can be retrieved by Key() and Value().
// Modifies the state of the iterator.
Prev() bool
// End moves the iterator past the last element (one-past-the-end).
// Call Prev() to fetch the last element if any.
End()
// Last moves the iterator to the last element and returns true if there was a last element in the container.
// If Last() returns true, then last element's key and value can be retrieved by Key() and Value().
// Modifies the state of the iterator.
Last() bool
IteratorWithKey
}
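A short sketch of forward and reverse iteration, assuming ArrayList's iterator (which satisfies ReverseIteratorWithIndex):

```go
package main

import (
	"fmt"

	"github.com/emirpasic/gods/lists/arraylist"
)

func main() {
	list := arraylist.New()
	list.Add("x", "y", "z")

	it := list.Iterator()
	for it.Next() { // forward: one-before-first, then Next() until exhausted
		fmt.Println(it.Index(), it.Value())
	}

	it.End() // reverse: one-past-the-end, then Prev() until exhausted
	for it.Prev() {
		fmt.Println(it.Index(), it.Value())
	}
}
```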

View File

@ -1,17 +0,0 @@
// Copyright (c) 2015, Emir Pasic. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package containers
// JSONSerializer provides JSON serialization
type JSONSerializer interface {
// ToJSON outputs the JSON representation of the container's elements.
ToJSON() ([]byte, error)
}
// JSONDeserializer provides JSON deserialization
type JSONDeserializer interface {
// FromJSON populates the container's elements from the input JSON representation.
FromJSON([]byte) error
}

View File

@ -1,200 +0,0 @@
// Copyright (c) 2015, Emir Pasic. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package arraylist implements the array list.
//
// Structure is not thread safe.
//
// Reference: https://en.wikipedia.org/wiki/List_%28abstract_data_type%29
package arraylist
import (
"fmt"
"github.com/emirpasic/gods/lists"
"github.com/emirpasic/gods/utils"
"strings"
)
func assertListImplementation() {
var _ lists.List = (*List)(nil)
}
// List holds the elements in a slice
type List struct {
elements []interface{}
size int
}
const (
growthFactor = float32(2.0) // growth by 100%
shrinkFactor = float32(0.25) // shrink when size is 25% of capacity (0 means never shrink)
)
// New instantiates a new empty list
func New() *List {
return &List{}
}
// Add appends the given values to the end of the list
func (list *List) Add(values ...interface{}) {
list.growBy(len(values))
for _, value := range values {
list.elements[list.size] = value
list.size++
}
}
// Get returns the element at index.
// Second return parameter is true if index is within bounds of the array and array is not empty, otherwise false.
func (list *List) Get(index int) (interface{}, bool) {
if !list.withinRange(index) {
return nil, false
}
return list.elements[index], true
}
// Remove removes the element at the given index from the list.
func (list *List) Remove(index int) {
if !list.withinRange(index) {
return
}
list.elements[index] = nil // cleanup reference
copy(list.elements[index:], list.elements[index+1:list.size]) // shift to the left by one (slow operation, need ways to optimize this)
list.size--
list.shrink()
}
// Contains checks if elements (one or more) are present in the list.
// All elements have to be present in the list for the method to return true.
// Time complexity is O(n^2).
// Returns true if no arguments are passed at all, i.e. the list is always a super-set of the empty set.
func (list *List) Contains(values ...interface{}) bool {
for _, searchValue := range values {
found := false
for _, element := range list.elements {
if element == searchValue {
found = true
break
}
}
if !found {
return false
}
}
return true
}
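
// exampleContains is an illustrative, editor-added sketch (not part of the
// original library): Contains requires every argument to be present, and an
// empty argument list is always reported as contained.
func exampleContains() {
	list := New()
	list.Add(1, 2, 3)
	_ = list.Contains(2)    // true
	_ = list.Contains(2, 4) // false: 4 is not in the list
	_ = list.Contains()     // true: vacuously satisfied
}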
// Values returns all elements in the list.
func (list *List) Values() []interface{} {
newElements := make([]interface{}, list.size, list.size)
copy(newElements, list.elements[:list.size])
return newElements
}
// Empty returns true if list does not contain any elements.
func (list *List) Empty() bool {
return list.size == 0
}
// Size returns number of elements within the list.
func (list *List) Size() int {
return list.size
}
// Clear removes all elements from the list.
func (list *List) Clear() {
list.size = 0
list.elements = []interface{}{}
}
// Sort sorts values (in-place) using the given comparator.
func (list *List) Sort(comparator utils.Comparator) {
if len(list.elements) < 2 {
return
}
utils.Sort(list.elements[:list.size], comparator)
}
// Swap swaps the two values at the specified positions.
func (list *List) Swap(i, j int) {
if list.withinRange(i) && list.withinRange(j) {
list.elements[i], list.elements[j] = list.elements[j], list.elements[i]
}
}
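
// exampleSortSwap is an illustrative, editor-added sketch (not part of the
// original library), using the IntComparator helper from the already-imported
// gods utils package.
func exampleSortSwap() {
	list := New()
	list.Add(3, 1, 2)
	list.Sort(utils.IntComparator) // list is now 1, 2, 3
	list.Swap(0, 2)                // list is now 3, 2, 1
}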
// Insert inserts values at specified index position shifting the value at that position (if any) and any subsequent elements to the right.
// Does not do anything if position is negative or greater than the list's size.
// Note: position equal to list's size is valid, i.e. append.
func (list *List) Insert(index int, values ...interface{}) {
if !list.withinRange(index) {
// Append
if index == list.size {
list.Add(values...)
}
return
}
l := len(values)
list.growBy(l)
list.size += l
// Shift old to right
for i := list.size - 1; i >= index+l; i-- {
list.elements[i] = list.elements[i-l]
}
// Insert new
for i, value := range values {
list.elements[index+i] = value
}
}
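
// exampleInsert is an illustrative, editor-added sketch (not part of the
// original library): inserting at an occupied index shifts elements to the
// right, index == size appends, and any other out-of-range index is ignored.
func exampleInsert() {
	list := New()
	list.Add("a", "c")
	list.Insert(1, "b")    // a, b, c
	list.Insert(3, "d")    // index equals size, so this appends: a, b, c, d
	list.Insert(42, "zzz") // out of range, no effect
}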
// String returns a string representation of the container.
func (list *List) String() string {
str := "ArrayList\n"
values := []string{}
for _, value := range list.elements[:list.size] {
values = append(values, fmt.Sprintf("%v", value))
}
str += strings.Join(values, ", ")
return str
}
// withinRange reports whether the index is within the bounds of the list.
func (list *List) withinRange(index int) bool {
return index >= 0 && index < list.size
}
func (list *List) resize(cap int) {
newElements := make([]interface{}, cap, cap)
copy(newElements, list.elements)
list.elements = newElements
}
// Expand the array if necessary, i.e. if adding n elements would reach the current capacity.
func (list *List) growBy(n int) {
	// When capacity would be reached, grow to growthFactor times (current capacity + n).
currentCapacity := cap(list.elements)
if list.size+n >= currentCapacity {
newCapacity := int(growthFactor * float32(currentCapacity+n))
list.resize(newCapacity)
}
}
// Shrink the array if necessary, i.e. when size drops to shrinkFactor of the current capacity or below.
func (list *List) shrink() {
if shrinkFactor == 0.0 {
return
}
// Shrink when size is at shrinkFactor * capacity
currentCapacity := cap(list.elements)
if list.size <= int(float32(currentCapacity)*shrinkFactor) {
list.resize(list.size)
}
}
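
// exampleResizing is an illustrative, editor-added sketch (not part of the
// original library), tracing the arithmetic of growBy and shrink with the
// constants defined above (growthFactor = 2, shrinkFactor = 0.25).
func exampleResizing() {
	list := New()
	list.Add(0) // capacity grows to 2*(0+1) = 2
	list.Add(1) // size+1 reaches capacity 2, so capacity grows to 2*(2+1) = 6
	list.Add(2, 3, 4)
	// Removing elements reallocates the backing array to exactly `size` once
	// size <= 0.25 * capacity, i.e. here once only 1 of the 6 slots is used.
	list.Remove(4)
	list.Remove(3)
	list.Remove(2)
	list.Remove(1) // size 1 <= 0.25*6, so shrink() reallocates to capacity 1
}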

View File

@ -1,79 +0,0 @@
// Copyright (c) 2015, Emir Pasic. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package arraylist
import "github.com/emirpasic/gods/containers"
func assertEnumerableImplementation() {
var _ containers.EnumerableWithIndex = (*List)(nil)
}
// Each calls the given function once for each element, passing that element's index and value.
func (list *List) Each(f func(index int, value interface{})) {
iterator := list.Iterator()
for iterator.Next() {
f(iterator.Index(), iterator.Value())
}
}
// Map invokes the given function once for each element and returns a
// container containing the values returned by the given function.
func (list *List) Map(f func(index int, value interface{}) interface{}) *List {
newList := &List{}
iterator := list.Iterator()
for iterator.Next() {
newList.Add(f(iterator.Index(), iterator.Value()))
}
return newList
}
// Select returns a new container containing all elements for which the given function returns a true value.
func (list *List) Select(f func(index int, value interface{}) bool) *List {
newList := &List{}
iterator := list.Iterator()
for iterator.Next() {
if f(iterator.Index(), iterator.Value()) {
newList.Add(iterator.Value())
}
}
return newList
}
// Any passes each element of the collection to the given function and
// returns true if the function ever returns true for any element.
func (list *List) Any(f func(index int, value interface{}) bool) bool {
iterator := list.Iterator()
for iterator.Next() {
if f(iterator.Index(), iterator.Value()) {
return true
}
}
return false
}
// All passes each element of the collection to the given function and
// returns true if the function returns true for all elements.
func (list *List) All(f func(index int, value interface{}) bool) bool {
iterator := list.Iterator()
for iterator.Next() {
if !f(iterator.Index(), iterator.Value()) {
return false
}
}
return true
}
// Find passes each element of the container to the given function and returns
// the first (index, value) for which the function is true, or (-1, nil)
// if no element matches the criteria.
func (list *List) Find(f func(index int, value interface{}) bool) (int, interface{}) {
iterator := list.Iterator()
for iterator.Next() {
if f(iterator.Index(), iterator.Value()) {
return iterator.Index(), iterator.Value()
}
}
return -1, nil
}
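
// exampleEnumerable is an illustrative, editor-added sketch (not part of the
// original library), chaining the enumerable helpers defined above.
func exampleEnumerable() {
	list := New()
	list.Add(1, 2, 3, 4)
	doubled := list.Map(func(index int, value interface{}) interface{} {
		return value.(int) * 2
	}) // 2, 4, 6, 8
	evens := doubled.Select(func(index int, value interface{}) bool {
		return value.(int)%4 == 0
	}) // 4, 8
	index, value := evens.Find(func(index int, value interface{}) bool {
		return value.(int) > 4
	})
	_, _ = index, value // (1, 8): 8 is the first element greater than 4
}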

View File

@ -1,83 +0,0 @@
// Copyright (c) 2015, Emir Pasic. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package arraylist
import "github.com/emirpasic/gods/containers"
func assertIteratorImplementation() {
var _ containers.ReverseIteratorWithIndex = (*Iterator)(nil)
}
// Iterator holds the iterator's state.
type Iterator struct {
list *List
index int
}
// Iterator returns a stateful iterator whose values can be fetched by an index.
func (list *List) Iterator() Iterator {
return Iterator{list: list, index: -1}
}
// Next moves the iterator to the next element and returns true if there was a next element in the container.
// If Next() returns true, then next element's index and value can be retrieved by Index() and Value().
// If Next() was called for the first time, then it will point the iterator to the first element if it exists.
// Modifies the state of the iterator.
func (iterator *Iterator) Next() bool {
if iterator.index < iterator.list.size {
iterator.index++
}
return iterator.list.withinRange(iterator.index)
}
// Prev moves the iterator to the previous element and returns true if there was a previous element in the container.
// If Prev() returns true, then previous element's index and value can be retrieved by Index() and Value().
// Modifies the state of the iterator.
func (iterator *Iterator) Prev() bool {
if iterator.index >= 0 {
iterator.index--
}
return iterator.list.withinRange(iterator.index)
}
// Value returns the current element's value.
// Does not modify the state of the iterator.
func (iterator *Iterator) Value() interface{} {
return iterator.list.elements[iterator.index]
}
// Index returns the current element's index.
// Does not modify the state of the iterator.
func (iterator *Iterator) Index() int {
return iterator.index
}
// Begin resets the iterator to its initial state (one-before-first)
// Call Next() to fetch the first element if any.
func (iterator *Iterator) Begin() {
iterator.index = -1
}
// End moves the iterator past the last element (one-past-the-end).
// Call Prev() to fetch the last element if any.
func (iterator *Iterator) End() {
iterator.index = iterator.list.size
}
// First moves the iterator to the first element and returns true if there was a first element in the container.
// If First() returns true, then first element's index and value can be retrieved by Index() and Value().
// Modifies the state of the iterator.
func (iterator *Iterator) First() bool {
iterator.Begin()
return iterator.Next()
}
// Last moves the iterator to the last element and returns true if there was a last element in the container.
// If Last() returns true, then last element's index and value can be retrieved by Index() and Value().
// Modifies the state of the iterator.
func (iterator *Iterator) Last() bool {
iterator.End()
return iterator.Prev()
}
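
// exampleIteration is an illustrative, editor-added sketch (not part of the
// original library): a forward pass with Next() followed by a reverse pass
// with End()/Prev(), using only the iterator methods defined above.
func exampleIteration(list *List) {
	it := list.Iterator()
	for it.Next() { // visits elements first to last
		_, _ = it.Index(), it.Value()
	}
	it.End()
	for it.Prev() { // visits elements last to first
		_, _ = it.Index(), it.Value()
	}
}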

Some files were not shown because too many files have changed in this diff.