Merge pull request #484 from black-dragon74/DFBUGS-1214

DFBUGS-1214: [release-4.16] Non-linear parsing of case-insensitive content (CVE-2024-45338)
This commit is contained in:
openshift-merge-bot[bot] 2025-02-20 14:32:43 +00:00 committed by GitHub
commit 9463de4eff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 1122 additions and 859 deletions

2
go.mod
View File

@ -28,7 +28,7 @@ require (
github.com/prometheus/client_golang v1.18.0 github.com/prometheus/client_golang v1.18.0
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
golang.org/x/crypto v0.32.0 golang.org/x/crypto v0.32.0
golang.org/x/net v0.25.0 golang.org/x/net v0.34.0
golang.org/x/sys v0.29.0 golang.org/x/sys v0.29.0
google.golang.org/grpc v1.62.1 google.golang.org/grpc v1.62.1
google.golang.org/protobuf v1.33.0 google.golang.org/protobuf v1.33.0

4
go.sum
View File

@ -1906,8 +1906,8 @@ golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=

4
vendor/golang.org/x/net/LICENSE generated vendored
View File

@ -1,4 +1,4 @@
Copyright (c) 2009 The Go Authors. All rights reserved. Copyright 2009 The Go Authors.
Redistribution and use in source and binary forms, with or without Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are modification, are permitted provided that the following conditions are
@ -10,7 +10,7 @@ notice, this list of conditions and the following disclaimer.
copyright notice, this list of conditions and the following disclaimer copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the in the documentation and/or other materials provided with the
distribution. distribution.
* Neither the name of Google Inc. nor the names of its * Neither the name of Google LLC nor the names of its
contributors may be used to endorse or promote products derived from contributors may be used to endorse or promote products derived from
this software without specific prior written permission. this software without specific prior written permission.

View File

@ -78,16 +78,11 @@ example, to process each anchor node in depth-first order:
if err != nil { if err != nil {
// ... // ...
} }
var f func(*html.Node) for n := range doc.Descendants() {
f = func(n *html.Node) {
if n.Type == html.ElementNode && n.Data == "a" { if n.Type == html.ElementNode && n.Data == "a" {
// Do something with n... // Do something with n...
} }
for c := n.FirstChild; c != nil; c = c.NextSibling {
f(c)
} }
}
f(doc)
The relevant specifications include: The relevant specifications include:
https://html.spec.whatwg.org/multipage/syntax.html and https://html.spec.whatwg.org/multipage/syntax.html and

View File

@ -87,7 +87,7 @@ func parseDoctype(s string) (n *Node, quirks bool) {
} }
} }
if lastAttr := n.Attr[len(n.Attr)-1]; lastAttr.Key == "system" && if lastAttr := n.Attr[len(n.Attr)-1]; lastAttr.Key == "system" &&
strings.ToLower(lastAttr.Val) == "http://www.ibm.com/data/dtd/v11/ibmxhtml1-transitional.dtd" { strings.EqualFold(lastAttr.Val, "http://www.ibm.com/data/dtd/v11/ibmxhtml1-transitional.dtd") {
quirks = true quirks = true
} }
} }

View File

@ -40,8 +40,7 @@ func htmlIntegrationPoint(n *Node) bool {
if n.Data == "annotation-xml" { if n.Data == "annotation-xml" {
for _, a := range n.Attr { for _, a := range n.Attr {
if a.Key == "encoding" { if a.Key == "encoding" {
val := strings.ToLower(a.Val) if strings.EqualFold(a.Val, "text/html") || strings.EqualFold(a.Val, "application/xhtml+xml") {
if val == "text/html" || val == "application/xhtml+xml" {
return true return true
} }
} }

56
vendor/golang.org/x/net/html/iter.go generated vendored Normal file
View File

@ -0,0 +1,56 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.23
package html
import "iter"
// Ancestors returns an iterator over the ancestors of n, starting with n.Parent.
//
// Mutating a Node or its parents while iterating may have unexpected results.
func (n *Node) Ancestors() iter.Seq[*Node] {
_ = n.Parent // eager nil check
return func(yield func(*Node) bool) {
for p := n.Parent; p != nil && yield(p); p = p.Parent {
}
}
}
// ChildNodes returns an iterator over the immediate children of n,
// starting with n.FirstChild.
//
// Mutating a Node or its children while iterating may have unexpected results.
func (n *Node) ChildNodes() iter.Seq[*Node] {
_ = n.FirstChild // eager nil check
return func(yield func(*Node) bool) {
for c := n.FirstChild; c != nil && yield(c); c = c.NextSibling {
}
}
}
// Descendants returns an iterator over all nodes recursively beneath
// n, excluding n itself. Nodes are visited in depth-first preorder.
//
// Mutating a Node or its descendants while iterating may have unexpected results.
func (n *Node) Descendants() iter.Seq[*Node] {
_ = n.FirstChild // eager nil check
return func(yield func(*Node) bool) {
n.descendants(yield)
}
}
func (n *Node) descendants(yield func(*Node) bool) bool {
for c := range n.ChildNodes() {
if !yield(c) || !c.descendants(yield) {
return false
}
}
return true
}

View File

@ -38,6 +38,10 @@ var scopeMarker = Node{Type: scopeMarkerNode}
// that it looks like "a<b" rather than "a&lt;b". For element nodes, DataAtom // that it looks like "a<b" rather than "a&lt;b". For element nodes, DataAtom
// is the atom for Data, or zero if Data is not a known tag name. // is the atom for Data, or zero if Data is not a known tag name.
// //
// Node trees may be navigated using the link fields (Parent,
// FirstChild, and so on) or a range loop over iterators such as
// [Node.Descendants].
//
// An empty Namespace implies a "http://www.w3.org/1999/xhtml" namespace. // An empty Namespace implies a "http://www.w3.org/1999/xhtml" namespace.
// Similarly, "math" is short for "http://www.w3.org/1998/Math/MathML", and // Similarly, "math" is short for "http://www.w3.org/1998/Math/MathML", and
// "svg" is short for "http://www.w3.org/2000/svg". // "svg" is short for "http://www.w3.org/2000/svg".

View File

@ -840,6 +840,10 @@ func afterHeadIM(p *parser) bool {
p.parseImpliedToken(StartTagToken, a.Body, a.Body.String()) p.parseImpliedToken(StartTagToken, a.Body, a.Body.String())
p.framesetOK = true p.framesetOK = true
if p.tok.Type == ErrorToken {
// Stop parsing.
return true
}
return false return false
} }
@ -1031,7 +1035,7 @@ func inBodyIM(p *parser) bool {
if p.tok.DataAtom == a.Input { if p.tok.DataAtom == a.Input {
for _, t := range p.tok.Attr { for _, t := range p.tok.Attr {
if t.Key == "type" { if t.Key == "type" {
if strings.ToLower(t.Val) == "hidden" { if strings.EqualFold(t.Val, "hidden") {
// Skip setting framesetOK = false // Skip setting framesetOK = false
return true return true
} }
@ -1459,7 +1463,7 @@ func inTableIM(p *parser) bool {
return inHeadIM(p) return inHeadIM(p)
case a.Input: case a.Input:
for _, t := range p.tok.Attr { for _, t := range p.tok.Attr {
if t.Key == "type" && strings.ToLower(t.Val) == "hidden" { if t.Key == "type" && strings.EqualFold(t.Val, "hidden") {
p.addElement() p.addElement()
p.oe.pop() p.oe.pop()
return true return true

View File

@ -8,8 +8,8 @@ package http2
import ( import (
"context" "context"
"crypto/tls"
"errors" "errors"
"net"
"net/http" "net/http"
"sync" "sync"
) )
@ -158,7 +158,7 @@ func (c *dialCall) dial(ctx context.Context, addr string) {
// This code decides which ones live or die. // This code decides which ones live or die.
// The return value used is whether c was used. // The return value used is whether c was used.
// c is never closed. // c is never closed.
func (p *clientConnPool) addConnIfNeeded(key string, t *Transport, c *tls.Conn) (used bool, err error) { func (p *clientConnPool) addConnIfNeeded(key string, t *Transport, c net.Conn) (used bool, err error) {
p.mu.Lock() p.mu.Lock()
for _, cc := range p.conns[key] { for _, cc := range p.conns[key] {
if cc.CanTakeNewRequest() { if cc.CanTakeNewRequest() {
@ -194,8 +194,8 @@ type addConnCall struct {
err error err error
} }
func (c *addConnCall) run(t *Transport, key string, tc *tls.Conn) { func (c *addConnCall) run(t *Transport, key string, nc net.Conn) {
cc, err := t.NewClientConn(tc) cc, err := t.NewClientConn(nc)
p := c.p p := c.p
p.mu.Lock() p.mu.Lock()

122
vendor/golang.org/x/net/http2/config.go generated vendored Normal file
View File

@ -0,0 +1,122 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http2
import (
"math"
"net/http"
"time"
)
// http2Config is a package-internal version of net/http.HTTP2Config.
//
// http.HTTP2Config was added in Go 1.24.
// When running with a version of net/http that includes HTTP2Config,
// we merge the configuration with the fields in Transport or Server
// to produce an http2Config.
//
// Zero valued fields in http2Config are interpreted as in the
// net/http.HTTPConfig documentation.
//
// Precedence order for reconciling configurations is:
//
// - Use the net/http.{Server,Transport}.HTTP2Config value, when non-zero.
// - Otherwise use the http2.{Server.Transport} value.
// - If the resulting value is zero or out of range, use a default.
type http2Config struct {
MaxConcurrentStreams uint32
MaxDecoderHeaderTableSize uint32
MaxEncoderHeaderTableSize uint32
MaxReadFrameSize uint32
MaxUploadBufferPerConnection int32
MaxUploadBufferPerStream int32
SendPingTimeout time.Duration
PingTimeout time.Duration
WriteByteTimeout time.Duration
PermitProhibitedCipherSuites bool
CountError func(errType string)
}
// configFromServer merges configuration settings from
// net/http.Server.HTTP2Config and http2.Server.
func configFromServer(h1 *http.Server, h2 *Server) http2Config {
conf := http2Config{
MaxConcurrentStreams: h2.MaxConcurrentStreams,
MaxEncoderHeaderTableSize: h2.MaxEncoderHeaderTableSize,
MaxDecoderHeaderTableSize: h2.MaxDecoderHeaderTableSize,
MaxReadFrameSize: h2.MaxReadFrameSize,
MaxUploadBufferPerConnection: h2.MaxUploadBufferPerConnection,
MaxUploadBufferPerStream: h2.MaxUploadBufferPerStream,
SendPingTimeout: h2.ReadIdleTimeout,
PingTimeout: h2.PingTimeout,
WriteByteTimeout: h2.WriteByteTimeout,
PermitProhibitedCipherSuites: h2.PermitProhibitedCipherSuites,
CountError: h2.CountError,
}
fillNetHTTPServerConfig(&conf, h1)
setConfigDefaults(&conf, true)
return conf
}
// configFromTransport merges configuration settings from h2 and h2.t1.HTTP2
// (the net/http Transport).
func configFromTransport(h2 *Transport) http2Config {
conf := http2Config{
MaxEncoderHeaderTableSize: h2.MaxEncoderHeaderTableSize,
MaxDecoderHeaderTableSize: h2.MaxDecoderHeaderTableSize,
MaxReadFrameSize: h2.MaxReadFrameSize,
SendPingTimeout: h2.ReadIdleTimeout,
PingTimeout: h2.PingTimeout,
WriteByteTimeout: h2.WriteByteTimeout,
}
// Unlike most config fields, where out-of-range values revert to the default,
// Transport.MaxReadFrameSize clips.
if conf.MaxReadFrameSize < minMaxFrameSize {
conf.MaxReadFrameSize = minMaxFrameSize
} else if conf.MaxReadFrameSize > maxFrameSize {
conf.MaxReadFrameSize = maxFrameSize
}
if h2.t1 != nil {
fillNetHTTPTransportConfig(&conf, h2.t1)
}
setConfigDefaults(&conf, false)
return conf
}
func setDefault[T ~int | ~int32 | ~uint32 | ~int64](v *T, minval, maxval, defval T) {
if *v < minval || *v > maxval {
*v = defval
}
}
func setConfigDefaults(conf *http2Config, server bool) {
setDefault(&conf.MaxConcurrentStreams, 1, math.MaxUint32, defaultMaxStreams)
setDefault(&conf.MaxEncoderHeaderTableSize, 1, math.MaxUint32, initialHeaderTableSize)
setDefault(&conf.MaxDecoderHeaderTableSize, 1, math.MaxUint32, initialHeaderTableSize)
if server {
setDefault(&conf.MaxUploadBufferPerConnection, initialWindowSize, math.MaxInt32, 1<<20)
} else {
setDefault(&conf.MaxUploadBufferPerConnection, initialWindowSize, math.MaxInt32, transportDefaultConnFlow)
}
if server {
setDefault(&conf.MaxUploadBufferPerStream, 1, math.MaxInt32, 1<<20)
} else {
setDefault(&conf.MaxUploadBufferPerStream, 1, math.MaxInt32, transportDefaultStreamFlow)
}
setDefault(&conf.MaxReadFrameSize, minMaxFrameSize, maxFrameSize, defaultMaxReadFrameSize)
setDefault(&conf.PingTimeout, 1, math.MaxInt64, 15*time.Second)
}
// adjustHTTP1MaxHeaderSize converts a limit in bytes on the size of an HTTP/1 header
// to an HTTP/2 MAX_HEADER_LIST_SIZE value.
func adjustHTTP1MaxHeaderSize(n int64) int64 {
// http2's count is in a slightly different unit and includes 32 bytes per pair.
// So, take the net/http.Server value and pad it up a bit, assuming 10 headers.
const perFieldOverhead = 32 // per http2 spec
const typicalHeaders = 10 // conservative
return n + typicalHeaders*perFieldOverhead
}

61
vendor/golang.org/x/net/http2/config_go124.go generated vendored Normal file
View File

@ -0,0 +1,61 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.24
package http2
import "net/http"
// fillNetHTTPServerConfig sets fields in conf from srv.HTTP2.
func fillNetHTTPServerConfig(conf *http2Config, srv *http.Server) {
fillNetHTTPConfig(conf, srv.HTTP2)
}
// fillNetHTTPTransportConfig sets fields in conf from tr.HTTP2.
func fillNetHTTPTransportConfig(conf *http2Config, tr *http.Transport) {
fillNetHTTPConfig(conf, tr.HTTP2)
}
func fillNetHTTPConfig(conf *http2Config, h2 *http.HTTP2Config) {
if h2 == nil {
return
}
if h2.MaxConcurrentStreams != 0 {
conf.MaxConcurrentStreams = uint32(h2.MaxConcurrentStreams)
}
if h2.MaxEncoderHeaderTableSize != 0 {
conf.MaxEncoderHeaderTableSize = uint32(h2.MaxEncoderHeaderTableSize)
}
if h2.MaxDecoderHeaderTableSize != 0 {
conf.MaxDecoderHeaderTableSize = uint32(h2.MaxDecoderHeaderTableSize)
}
if h2.MaxConcurrentStreams != 0 {
conf.MaxConcurrentStreams = uint32(h2.MaxConcurrentStreams)
}
if h2.MaxReadFrameSize != 0 {
conf.MaxReadFrameSize = uint32(h2.MaxReadFrameSize)
}
if h2.MaxReceiveBufferPerConnection != 0 {
conf.MaxUploadBufferPerConnection = int32(h2.MaxReceiveBufferPerConnection)
}
if h2.MaxReceiveBufferPerStream != 0 {
conf.MaxUploadBufferPerStream = int32(h2.MaxReceiveBufferPerStream)
}
if h2.SendPingTimeout != 0 {
conf.SendPingTimeout = h2.SendPingTimeout
}
if h2.PingTimeout != 0 {
conf.PingTimeout = h2.PingTimeout
}
if h2.WriteByteTimeout != 0 {
conf.WriteByteTimeout = h2.WriteByteTimeout
}
if h2.PermitProhibitedCipherSuites {
conf.PermitProhibitedCipherSuites = true
}
if h2.CountError != nil {
conf.CountError = h2.CountError
}
}

16
vendor/golang.org/x/net/http2/config_pre_go124.go generated vendored Normal file
View File

@ -0,0 +1,16 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.24
package http2
import "net/http"
// Pre-Go 1.24 fallback.
// The Server.HTTP2 and Transport.HTTP2 config fields were added in Go 1.24.
func fillNetHTTPServerConfig(conf *http2Config, srv *http.Server) {}
func fillNetHTTPTransportConfig(conf *http2Config, tr *http.Transport) {}

View File

@ -1490,7 +1490,7 @@ func (mh *MetaHeadersFrame) checkPseudos() error {
pf := mh.PseudoFields() pf := mh.PseudoFields()
for i, hf := range pf { for i, hf := range pf {
switch hf.Name { switch hf.Name {
case ":method", ":path", ":scheme", ":authority": case ":method", ":path", ":scheme", ":authority", ":protocol":
isRequest = true isRequest = true
case ":status": case ":status":
isResponse = true isResponse = true
@ -1498,7 +1498,7 @@ func (mh *MetaHeadersFrame) checkPseudos() error {
return pseudoHeaderError(hf.Name) return pseudoHeaderError(hf.Name)
} }
// Check for duplicates. // Check for duplicates.
// This would be a bad algorithm, but N is 4. // This would be a bad algorithm, but N is 5.
// And this doesn't allocate. // And this doesn't allocate.
for _, hf2 := range pf[:i] { for _, hf2 := range pf[:i] {
if hf.Name == hf2.Name { if hf.Name == hf2.Name {

View File

@ -17,15 +17,18 @@ package http2 // import "golang.org/x/net/http2"
import ( import (
"bufio" "bufio"
"context"
"crypto/tls" "crypto/tls"
"errors"
"fmt" "fmt"
"io" "net"
"net/http" "net/http"
"os" "os"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time"
"golang.org/x/net/http/httpguts" "golang.org/x/net/http/httpguts"
) )
@ -35,6 +38,7 @@ var (
logFrameWrites bool logFrameWrites bool
logFrameReads bool logFrameReads bool
inTests bool inTests bool
disableExtendedConnectProtocol bool
) )
func init() { func init() {
@ -47,6 +51,9 @@ func init() {
logFrameWrites = true logFrameWrites = true
logFrameReads = true logFrameReads = true
} }
if strings.Contains(e, "http2xconnect=0") {
disableExtendedConnectProtocol = true
}
} }
const ( const (
@ -138,6 +145,10 @@ func (s Setting) Valid() error {
if s.Val < 16384 || s.Val > 1<<24-1 { if s.Val < 16384 || s.Val > 1<<24-1 {
return ConnectionError(ErrCodeProtocol) return ConnectionError(ErrCodeProtocol)
} }
case SettingEnableConnectProtocol:
if s.Val != 1 && s.Val != 0 {
return ConnectionError(ErrCodeProtocol)
}
} }
return nil return nil
} }
@ -153,6 +164,7 @@ const (
SettingInitialWindowSize SettingID = 0x4 SettingInitialWindowSize SettingID = 0x4
SettingMaxFrameSize SettingID = 0x5 SettingMaxFrameSize SettingID = 0x5
SettingMaxHeaderListSize SettingID = 0x6 SettingMaxHeaderListSize SettingID = 0x6
SettingEnableConnectProtocol SettingID = 0x8
) )
var settingName = map[SettingID]string{ var settingName = map[SettingID]string{
@ -162,6 +174,7 @@ var settingName = map[SettingID]string{
SettingInitialWindowSize: "INITIAL_WINDOW_SIZE", SettingInitialWindowSize: "INITIAL_WINDOW_SIZE",
SettingMaxFrameSize: "MAX_FRAME_SIZE", SettingMaxFrameSize: "MAX_FRAME_SIZE",
SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE", SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE",
SettingEnableConnectProtocol: "ENABLE_CONNECT_PROTOCOL",
} }
func (s SettingID) String() string { func (s SettingID) String() string {
@ -210,12 +223,6 @@ type stringWriter interface {
WriteString(s string) (n int, err error) WriteString(s string) (n int, err error)
} }
// A gate lets two goroutines coordinate their activities.
type gate chan struct{}
func (g gate) Done() { g <- struct{}{} }
func (g gate) Wait() { <-g }
// A closeWaiter is like a sync.WaitGroup but only goes 1 to 0 (open to closed). // A closeWaiter is like a sync.WaitGroup but only goes 1 to 0 (open to closed).
type closeWaiter chan struct{} type closeWaiter chan struct{}
@ -242,12 +249,18 @@ func (cw closeWaiter) Wait() {
// idle memory usage with many connections. // idle memory usage with many connections.
type bufferedWriter struct { type bufferedWriter struct {
_ incomparable _ incomparable
w io.Writer // immutable group synctestGroupInterface // immutable
conn net.Conn // immutable
bw *bufio.Writer // non-nil when data is buffered bw *bufio.Writer // non-nil when data is buffered
byteTimeout time.Duration // immutable, WriteByteTimeout
} }
func newBufferedWriter(w io.Writer) *bufferedWriter { func newBufferedWriter(group synctestGroupInterface, conn net.Conn, timeout time.Duration) *bufferedWriter {
return &bufferedWriter{w: w} return &bufferedWriter{
group: group,
conn: conn,
byteTimeout: timeout,
}
} }
// bufWriterPoolBufferSize is the size of bufio.Writer's // bufWriterPoolBufferSize is the size of bufio.Writer's
@ -274,7 +287,7 @@ func (w *bufferedWriter) Available() int {
func (w *bufferedWriter) Write(p []byte) (n int, err error) { func (w *bufferedWriter) Write(p []byte) (n int, err error) {
if w.bw == nil { if w.bw == nil {
bw := bufWriterPool.Get().(*bufio.Writer) bw := bufWriterPool.Get().(*bufio.Writer)
bw.Reset(w.w) bw.Reset((*bufferedWriterTimeoutWriter)(w))
w.bw = bw w.bw = bw
} }
return w.bw.Write(p) return w.bw.Write(p)
@ -292,6 +305,38 @@ func (w *bufferedWriter) Flush() error {
return err return err
} }
type bufferedWriterTimeoutWriter bufferedWriter
func (w *bufferedWriterTimeoutWriter) Write(p []byte) (n int, err error) {
return writeWithByteTimeout(w.group, w.conn, w.byteTimeout, p)
}
// writeWithByteTimeout writes to conn.
// If more than timeout passes without any bytes being written to the connection,
// the write fails.
func writeWithByteTimeout(group synctestGroupInterface, conn net.Conn, timeout time.Duration, p []byte) (n int, err error) {
if timeout <= 0 {
return conn.Write(p)
}
for {
var now time.Time
if group == nil {
now = time.Now()
} else {
now = group.Now()
}
conn.SetWriteDeadline(now.Add(timeout))
nn, err := conn.Write(p[n:])
n += nn
if n == len(p) || nn == 0 || !errors.Is(err, os.ErrDeadlineExceeded) {
// Either we finished the write, made no progress, or hit the deadline.
// Whichever it is, we're done now.
conn.SetWriteDeadline(time.Time{})
return n, err
}
}
}
func mustUint31(v int32) uint32 { func mustUint31(v int32) uint32 {
if v < 0 || v > 2147483647 { if v < 0 || v > 2147483647 {
panic("out of range") panic("out of range")
@ -383,3 +428,14 @@ func validPseudoPath(v string) bool {
// makes that struct also non-comparable, and generally doesn't add // makes that struct also non-comparable, and generally doesn't add
// any size (as long as it's first). // any size (as long as it's first).
type incomparable [0]func() type incomparable [0]func()
// synctestGroupInterface is the methods of synctestGroup used by Server and Transport.
// It's defined as an interface here to let us keep synctestGroup entirely test-only
// and not a part of non-test builds.
type synctestGroupInterface interface {
Join()
Now() time.Time
NewTimer(d time.Duration) timer
AfterFunc(d time.Duration, f func()) timer
ContextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc)
}

View File

@ -29,6 +29,7 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"context" "context"
"crypto/rand"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
@ -56,6 +57,10 @@ const (
firstSettingsTimeout = 2 * time.Second // should be in-flight with preface anyway firstSettingsTimeout = 2 * time.Second // should be in-flight with preface anyway
handlerChunkWriteSize = 4 << 10 handlerChunkWriteSize = 4 << 10
defaultMaxStreams = 250 // TODO: make this 100 as the GFE seems to? defaultMaxStreams = 250 // TODO: make this 100 as the GFE seems to?
// maxQueuedControlFrames is the maximum number of control frames like
// SETTINGS, PING and RST_STREAM that will be queued for writing before
// the connection is closed to prevent memory exhaustion attacks.
maxQueuedControlFrames = 10000 maxQueuedControlFrames = 10000
) )
@ -127,6 +132,22 @@ type Server struct {
// If zero or negative, there is no timeout. // If zero or negative, there is no timeout.
IdleTimeout time.Duration IdleTimeout time.Duration
// ReadIdleTimeout is the timeout after which a health check using a ping
// frame will be carried out if no frame is received on the connection.
// If zero, no health check is performed.
ReadIdleTimeout time.Duration
// PingTimeout is the timeout after which the connection will be closed
// if a response to a ping is not received.
// If zero, a default of 15 seconds is used.
PingTimeout time.Duration
// WriteByteTimeout is the timeout after which a connection will be
// closed if no data can be written to it. The timeout begins when data is
// available to write, and is extended whenever any bytes are written.
// If zero or negative, there is no timeout.
WriteByteTimeout time.Duration
// MaxUploadBufferPerConnection is the size of the initial flow // MaxUploadBufferPerConnection is the size of the initial flow
// control window for each connections. The HTTP/2 spec does not // control window for each connections. The HTTP/2 spec does not
// allow this to be smaller than 65535 or larger than 2^32-1. // allow this to be smaller than 65535 or larger than 2^32-1.
@ -154,57 +175,39 @@ type Server struct {
// so that we don't embed a Mutex in this struct, which will make the // so that we don't embed a Mutex in this struct, which will make the
// struct non-copyable, which might break some callers. // struct non-copyable, which might break some callers.
state *serverInternalState state *serverInternalState
// Synchronization group used for testing.
// Outside of tests, this is nil.
group synctestGroupInterface
} }
func (s *Server) initialConnRecvWindowSize() int32 { func (s *Server) markNewGoroutine() {
if s.MaxUploadBufferPerConnection >= initialWindowSize { if s.group != nil {
return s.MaxUploadBufferPerConnection s.group.Join()
} }
return 1 << 20
} }
func (s *Server) initialStreamRecvWindowSize() int32 { func (s *Server) now() time.Time {
if s.MaxUploadBufferPerStream > 0 { if s.group != nil {
return s.MaxUploadBufferPerStream return s.group.Now()
} }
return 1 << 20 return time.Now()
} }
func (s *Server) maxReadFrameSize() uint32 { // newTimer creates a new time.Timer, or a synthetic timer in tests.
if v := s.MaxReadFrameSize; v >= minMaxFrameSize && v <= maxFrameSize { func (s *Server) newTimer(d time.Duration) timer {
return v if s.group != nil {
return s.group.NewTimer(d)
} }
return defaultMaxReadFrameSize return timeTimer{time.NewTimer(d)}
} }
func (s *Server) maxConcurrentStreams() uint32 { // afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests.
if v := s.MaxConcurrentStreams; v > 0 { func (s *Server) afterFunc(d time.Duration, f func()) timer {
return v if s.group != nil {
return s.group.AfterFunc(d, f)
} }
return defaultMaxStreams return timeTimer{time.AfterFunc(d, f)}
}
func (s *Server) maxDecoderHeaderTableSize() uint32 {
if v := s.MaxDecoderHeaderTableSize; v > 0 {
return v
}
return initialHeaderTableSize
}
func (s *Server) maxEncoderHeaderTableSize() uint32 {
if v := s.MaxEncoderHeaderTableSize; v > 0 {
return v
}
return initialHeaderTableSize
}
// maxQueuedControlFrames is the maximum number of control frames like
// SETTINGS, PING and RST_STREAM that will be queued for writing before
// the connection is closed to prevent memory exhaustion attacks.
func (s *Server) maxQueuedControlFrames() int {
// TODO: if anybody asks, add a Server field, and remember to define the
// behavior of negative values.
return maxQueuedControlFrames
} }
type serverInternalState struct { type serverInternalState struct {
@ -303,7 +306,7 @@ func ConfigureServer(s *http.Server, conf *Server) error {
if s.TLSNextProto == nil { if s.TLSNextProto == nil {
s.TLSNextProto = map[string]func(*http.Server, *tls.Conn, http.Handler){} s.TLSNextProto = map[string]func(*http.Server, *tls.Conn, http.Handler){}
} }
protoHandler := func(hs *http.Server, c *tls.Conn, h http.Handler) { protoHandler := func(hs *http.Server, c net.Conn, h http.Handler, sawClientPreface bool) {
if testHookOnConn != nil { if testHookOnConn != nil {
testHookOnConn() testHookOnConn()
} }
@ -323,9 +326,28 @@ func ConfigureServer(s *http.Server, conf *Server) error {
Context: ctx, Context: ctx,
Handler: h, Handler: h,
BaseConfig: hs, BaseConfig: hs,
SawClientPreface: sawClientPreface,
}) })
} }
s.TLSNextProto[NextProtoTLS] = protoHandler s.TLSNextProto[NextProtoTLS] = func(hs *http.Server, c *tls.Conn, h http.Handler) {
protoHandler(hs, c, h, false)
}
// The "unencrypted_http2" TLSNextProto key is used to pass off non-TLS HTTP/2 conns.
//
// A connection passed in this method has already had the HTTP/2 preface read from it.
s.TLSNextProto[nextProtoUnencryptedHTTP2] = func(hs *http.Server, c *tls.Conn, h http.Handler) {
nc, err := unencryptedNetConnFromTLSConn(c)
if err != nil {
if lg := hs.ErrorLog; lg != nil {
lg.Print(err)
} else {
log.Print(err)
}
go c.Close()
return
}
protoHandler(hs, nc, h, true)
}
return nil return nil
} }
@ -400,16 +422,22 @@ func (o *ServeConnOpts) handler() http.Handler {
// //
// The opts parameter is optional. If nil, default values are used. // The opts parameter is optional. If nil, default values are used.
func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) { func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
s.serveConn(c, opts, nil)
}
func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverConn)) {
baseCtx, cancel := serverConnBaseContext(c, opts) baseCtx, cancel := serverConnBaseContext(c, opts)
defer cancel() defer cancel()
http1srv := opts.baseConfig()
conf := configFromServer(http1srv, s)
sc := &serverConn{ sc := &serverConn{
srv: s, srv: s,
hs: opts.baseConfig(), hs: http1srv,
conn: c, conn: c,
baseCtx: baseCtx, baseCtx: baseCtx,
remoteAddrStr: c.RemoteAddr().String(), remoteAddrStr: c.RemoteAddr().String(),
bw: newBufferedWriter(c), bw: newBufferedWriter(s.group, c, conf.WriteByteTimeout),
handler: opts.handler(), handler: opts.handler(),
streams: make(map[uint32]*stream), streams: make(map[uint32]*stream),
readFrameCh: make(chan readFrameResult), readFrameCh: make(chan readFrameResult),
@ -419,13 +447,19 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
bodyReadCh: make(chan bodyReadMsg), // buffering doesn't matter either way bodyReadCh: make(chan bodyReadMsg), // buffering doesn't matter either way
doneServing: make(chan struct{}), doneServing: make(chan struct{}),
clientMaxStreams: math.MaxUint32, // Section 6.5.2: "Initially, there is no limit to this value" clientMaxStreams: math.MaxUint32, // Section 6.5.2: "Initially, there is no limit to this value"
advMaxStreams: s.maxConcurrentStreams(), advMaxStreams: conf.MaxConcurrentStreams,
initialStreamSendWindowSize: initialWindowSize, initialStreamSendWindowSize: initialWindowSize,
initialStreamRecvWindowSize: conf.MaxUploadBufferPerStream,
maxFrameSize: initialMaxFrameSize, maxFrameSize: initialMaxFrameSize,
pingTimeout: conf.PingTimeout,
countErrorFunc: conf.CountError,
serveG: newGoroutineLock(), serveG: newGoroutineLock(),
pushEnabled: true, pushEnabled: true,
sawClientPreface: opts.SawClientPreface, sawClientPreface: opts.SawClientPreface,
} }
if newf != nil {
newf(sc)
}
s.state.registerConn(sc) s.state.registerConn(sc)
defer s.state.unregisterConn(sc) defer s.state.unregisterConn(sc)
@ -451,15 +485,15 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
sc.flow.add(initialWindowSize) sc.flow.add(initialWindowSize)
sc.inflow.init(initialWindowSize) sc.inflow.init(initialWindowSize)
sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
sc.hpackEncoder.SetMaxDynamicTableSizeLimit(s.maxEncoderHeaderTableSize()) sc.hpackEncoder.SetMaxDynamicTableSizeLimit(conf.MaxEncoderHeaderTableSize)
fr := NewFramer(sc.bw, c) fr := NewFramer(sc.bw, c)
if s.CountError != nil { if conf.CountError != nil {
fr.countError = s.CountError fr.countError = conf.CountError
} }
fr.ReadMetaHeaders = hpack.NewDecoder(s.maxDecoderHeaderTableSize(), nil) fr.ReadMetaHeaders = hpack.NewDecoder(conf.MaxDecoderHeaderTableSize, nil)
fr.MaxHeaderListSize = sc.maxHeaderListSize() fr.MaxHeaderListSize = sc.maxHeaderListSize()
fr.SetMaxReadFrameSize(s.maxReadFrameSize()) fr.SetMaxReadFrameSize(conf.MaxReadFrameSize)
sc.framer = fr sc.framer = fr
if tc, ok := c.(connectionStater); ok { if tc, ok := c.(connectionStater); ok {
@ -492,7 +526,7 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
// So for now, do nothing here again. // So for now, do nothing here again.
} }
if !s.PermitProhibitedCipherSuites && isBadCipher(sc.tlsState.CipherSuite) { if !conf.PermitProhibitedCipherSuites && isBadCipher(sc.tlsState.CipherSuite) {
// "Endpoints MAY choose to generate a connection error // "Endpoints MAY choose to generate a connection error
// (Section 5.4.1) of type INADEQUATE_SECURITY if one of // (Section 5.4.1) of type INADEQUATE_SECURITY if one of
// the prohibited cipher suites are negotiated." // the prohibited cipher suites are negotiated."
@ -529,7 +563,7 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
opts.UpgradeRequest = nil opts.UpgradeRequest = nil
} }
sc.serve() sc.serve(conf)
} }
func serverConnBaseContext(c net.Conn, opts *ServeConnOpts) (ctx context.Context, cancel func()) { func serverConnBaseContext(c net.Conn, opts *ServeConnOpts) (ctx context.Context, cancel func()) {
@ -569,6 +603,7 @@ type serverConn struct {
tlsState *tls.ConnectionState // shared by all handlers, like net/http tlsState *tls.ConnectionState // shared by all handlers, like net/http
remoteAddrStr string remoteAddrStr string
writeSched WriteScheduler writeSched WriteScheduler
countErrorFunc func(errType string)
// Everything following is owned by the serve loop; use serveG.check(): // Everything following is owned by the serve loop; use serveG.check():
serveG goroutineLock // used to verify funcs are on serve() serveG goroutineLock // used to verify funcs are on serve()
@ -588,6 +623,7 @@ type serverConn struct {
streams map[uint32]*stream streams map[uint32]*stream
unstartedHandlers []unstartedHandler unstartedHandlers []unstartedHandler
initialStreamSendWindowSize int32 initialStreamSendWindowSize int32
initialStreamRecvWindowSize int32
maxFrameSize int32 maxFrameSize int32
peerMaxHeaderListSize uint32 // zero means unknown (default) peerMaxHeaderListSize uint32 // zero means unknown (default)
canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case
@ -598,9 +634,14 @@ type serverConn struct {
inGoAway bool // we've started to or sent GOAWAY inGoAway bool // we've started to or sent GOAWAY
inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop
needToSendGoAway bool // we need to schedule a GOAWAY frame write needToSendGoAway bool // we need to schedule a GOAWAY frame write
pingSent bool
sentPingData [8]byte
goAwayCode ErrCode goAwayCode ErrCode
shutdownTimer *time.Timer // nil until used shutdownTimer timer // nil until used
idleTimer *time.Timer // nil if unused idleTimer timer // nil if unused
readIdleTimeout time.Duration
pingTimeout time.Duration
readIdleTimer timer // nil if unused
// Owned by the writeFrameAsync goroutine: // Owned by the writeFrameAsync goroutine:
headerWriteBuf bytes.Buffer headerWriteBuf bytes.Buffer
@ -615,11 +656,7 @@ func (sc *serverConn) maxHeaderListSize() uint32 {
if n <= 0 { if n <= 0 {
n = http.DefaultMaxHeaderBytes n = http.DefaultMaxHeaderBytes
} }
// http2's count is in a slightly different unit and includes 32 bytes per pair. return uint32(adjustHTTP1MaxHeaderSize(int64(n)))
// So, take the net/http.Server value and pad it up a bit, assuming 10 headers.
const perFieldOverhead = 32 // per http2 spec
const typicalHeaders = 10 // conservative
return uint32(n + typicalHeaders*perFieldOverhead)
} }
func (sc *serverConn) curOpenStreams() uint32 { func (sc *serverConn) curOpenStreams() uint32 {
@ -652,8 +689,8 @@ type stream struct {
resetQueued bool // RST_STREAM queued for write; set by sc.resetStream resetQueued bool // RST_STREAM queued for write; set by sc.resetStream
gotTrailerHeader bool // HEADER frame for trailers was seen gotTrailerHeader bool // HEADER frame for trailers was seen
wroteHeaders bool // whether we wrote headers (not status 100) wroteHeaders bool // whether we wrote headers (not status 100)
readDeadline *time.Timer // nil if unused readDeadline timer // nil if unused
writeDeadline *time.Timer // nil if unused writeDeadline timer // nil if unused
closeErr error // set before cw is closed closeErr error // set before cw is closed
trailer http.Header // accumulated trailers trailer http.Header // accumulated trailers
@ -811,8 +848,9 @@ type readFrameResult struct {
// consumer is done with the frame. // consumer is done with the frame.
// It's run on its own goroutine. // It's run on its own goroutine.
func (sc *serverConn) readFrames() { func (sc *serverConn) readFrames() {
gate := make(gate) sc.srv.markNewGoroutine()
gateDone := gate.Done gate := make(chan struct{})
gateDone := func() { gate <- struct{}{} }
for { for {
f, err := sc.framer.ReadFrame() f, err := sc.framer.ReadFrame()
select { select {
@ -843,6 +881,7 @@ type frameWriteResult struct {
// At most one goroutine can be running writeFrameAsync at a time per // At most one goroutine can be running writeFrameAsync at a time per
// serverConn. // serverConn.
func (sc *serverConn) writeFrameAsync(wr FrameWriteRequest, wd *writeData) { func (sc *serverConn) writeFrameAsync(wr FrameWriteRequest, wd *writeData) {
sc.srv.markNewGoroutine()
var err error var err error
if wd == nil { if wd == nil {
err = wr.write.writeFrame(sc) err = wr.write.writeFrame(sc)
@ -881,7 +920,7 @@ func (sc *serverConn) notePanic() {
} }
} }
func (sc *serverConn) serve() { func (sc *serverConn) serve(conf http2Config) {
sc.serveG.check() sc.serveG.check()
defer sc.notePanic() defer sc.notePanic()
defer sc.conn.Close() defer sc.conn.Close()
@ -893,20 +932,24 @@ func (sc *serverConn) serve() {
sc.vlogf("http2: server connection from %v on %p", sc.conn.RemoteAddr(), sc.hs) sc.vlogf("http2: server connection from %v on %p", sc.conn.RemoteAddr(), sc.hs)
} }
sc.writeFrame(FrameWriteRequest{ settings := writeSettings{
write: writeSettings{ {SettingMaxFrameSize, conf.MaxReadFrameSize},
{SettingMaxFrameSize, sc.srv.maxReadFrameSize()},
{SettingMaxConcurrentStreams, sc.advMaxStreams}, {SettingMaxConcurrentStreams, sc.advMaxStreams},
{SettingMaxHeaderListSize, sc.maxHeaderListSize()}, {SettingMaxHeaderListSize, sc.maxHeaderListSize()},
{SettingHeaderTableSize, sc.srv.maxDecoderHeaderTableSize()}, {SettingHeaderTableSize, conf.MaxDecoderHeaderTableSize},
{SettingInitialWindowSize, uint32(sc.srv.initialStreamRecvWindowSize())}, {SettingInitialWindowSize, uint32(sc.initialStreamRecvWindowSize)},
}, }
if !disableExtendedConnectProtocol {
settings = append(settings, Setting{SettingEnableConnectProtocol, 1})
}
sc.writeFrame(FrameWriteRequest{
write: settings,
}) })
sc.unackedSettings++ sc.unackedSettings++
// Each connection starts with initialWindowSize inflow tokens. // Each connection starts with initialWindowSize inflow tokens.
// If a higher value is configured, we add more tokens. // If a higher value is configured, we add more tokens.
if diff := sc.srv.initialConnRecvWindowSize() - initialWindowSize; diff > 0 { if diff := conf.MaxUploadBufferPerConnection - initialWindowSize; diff > 0 {
sc.sendWindowUpdate(nil, int(diff)) sc.sendWindowUpdate(nil, int(diff))
} }
@ -922,15 +965,22 @@ func (sc *serverConn) serve() {
sc.setConnState(http.StateIdle) sc.setConnState(http.StateIdle)
if sc.srv.IdleTimeout > 0 { if sc.srv.IdleTimeout > 0 {
sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer) sc.idleTimer = sc.srv.afterFunc(sc.srv.IdleTimeout, sc.onIdleTimer)
defer sc.idleTimer.Stop() defer sc.idleTimer.Stop()
} }
if conf.SendPingTimeout > 0 {
sc.readIdleTimeout = conf.SendPingTimeout
sc.readIdleTimer = sc.srv.afterFunc(conf.SendPingTimeout, sc.onReadIdleTimer)
defer sc.readIdleTimer.Stop()
}
go sc.readFrames() // closed by defer sc.conn.Close above go sc.readFrames() // closed by defer sc.conn.Close above
settingsTimer := time.AfterFunc(firstSettingsTimeout, sc.onSettingsTimer) settingsTimer := sc.srv.afterFunc(firstSettingsTimeout, sc.onSettingsTimer)
defer settingsTimer.Stop() defer settingsTimer.Stop()
lastFrameTime := sc.srv.now()
loopNum := 0 loopNum := 0
for { for {
loopNum++ loopNum++
@ -944,6 +994,7 @@ func (sc *serverConn) serve() {
case res := <-sc.wroteFrameCh: case res := <-sc.wroteFrameCh:
sc.wroteFrame(res) sc.wroteFrame(res)
case res := <-sc.readFrameCh: case res := <-sc.readFrameCh:
lastFrameTime = sc.srv.now()
// Process any written frames before reading new frames from the client since a // Process any written frames before reading new frames from the client since a
// written frame could have triggered a new stream to be started. // written frame could have triggered a new stream to be started.
if sc.writingFrameAsync { if sc.writingFrameAsync {
@ -975,6 +1026,8 @@ func (sc *serverConn) serve() {
case idleTimerMsg: case idleTimerMsg:
sc.vlogf("connection is idle") sc.vlogf("connection is idle")
sc.goAway(ErrCodeNo) sc.goAway(ErrCodeNo)
case readIdleTimerMsg:
sc.handlePingTimer(lastFrameTime)
case shutdownTimerMsg: case shutdownTimerMsg:
sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr()) sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr())
return return
@ -997,7 +1050,7 @@ func (sc *serverConn) serve() {
// If the peer is causing us to generate a lot of control frames, // If the peer is causing us to generate a lot of control frames,
// but not reading them from us, assume they are trying to make us // but not reading them from us, assume they are trying to make us
// run out of memory. // run out of memory.
if sc.queuedControlFrames > sc.srv.maxQueuedControlFrames() { if sc.queuedControlFrames > maxQueuedControlFrames {
sc.vlogf("http2: too many control frames in send queue, closing connection") sc.vlogf("http2: too many control frames in send queue, closing connection")
return return
} }
@ -1013,12 +1066,39 @@ func (sc *serverConn) serve() {
} }
} }
func (sc *serverConn) handlePingTimer(lastFrameReadTime time.Time) {
if sc.pingSent {
sc.vlogf("timeout waiting for PING response")
sc.conn.Close()
return
}
pingAt := lastFrameReadTime.Add(sc.readIdleTimeout)
now := sc.srv.now()
if pingAt.After(now) {
// We received frames since arming the ping timer.
// Reset it for the next possible timeout.
sc.readIdleTimer.Reset(pingAt.Sub(now))
return
}
sc.pingSent = true
// Ignore crypto/rand.Read errors: It generally can't fail, and worse case if it does
// is we send a PING frame containing 0s.
_, _ = rand.Read(sc.sentPingData[:])
sc.writeFrame(FrameWriteRequest{
write: &writePing{data: sc.sentPingData},
})
sc.readIdleTimer.Reset(sc.pingTimeout)
}
type serverMessage int type serverMessage int
// Message values sent to serveMsgCh. // Message values sent to serveMsgCh.
var ( var (
settingsTimerMsg = new(serverMessage) settingsTimerMsg = new(serverMessage)
idleTimerMsg = new(serverMessage) idleTimerMsg = new(serverMessage)
readIdleTimerMsg = new(serverMessage)
shutdownTimerMsg = new(serverMessage) shutdownTimerMsg = new(serverMessage)
gracefulShutdownMsg = new(serverMessage) gracefulShutdownMsg = new(serverMessage)
handlerDoneMsg = new(serverMessage) handlerDoneMsg = new(serverMessage)
@ -1026,6 +1106,7 @@ var (
func (sc *serverConn) onSettingsTimer() { sc.sendServeMsg(settingsTimerMsg) } func (sc *serverConn) onSettingsTimer() { sc.sendServeMsg(settingsTimerMsg) }
func (sc *serverConn) onIdleTimer() { sc.sendServeMsg(idleTimerMsg) } func (sc *serverConn) onIdleTimer() { sc.sendServeMsg(idleTimerMsg) }
func (sc *serverConn) onReadIdleTimer() { sc.sendServeMsg(readIdleTimerMsg) }
func (sc *serverConn) onShutdownTimer() { sc.sendServeMsg(shutdownTimerMsg) } func (sc *serverConn) onShutdownTimer() { sc.sendServeMsg(shutdownTimerMsg) }
func (sc *serverConn) sendServeMsg(msg interface{}) { func (sc *serverConn) sendServeMsg(msg interface{}) {
@ -1057,10 +1138,10 @@ func (sc *serverConn) readPreface() error {
errc <- nil errc <- nil
} }
}() }()
timer := time.NewTimer(prefaceTimeout) // TODO: configurable on *Server? timer := sc.srv.newTimer(prefaceTimeout) // TODO: configurable on *Server?
defer timer.Stop() defer timer.Stop()
select { select {
case <-timer.C: case <-timer.C():
return errPrefaceTimeout return errPrefaceTimeout
case err := <-errc: case err := <-errc:
if err == nil { if err == nil {
@ -1278,6 +1359,10 @@ func (sc *serverConn) wroteFrame(res frameWriteResult) {
sc.writingFrame = false sc.writingFrame = false
sc.writingFrameAsync = false sc.writingFrameAsync = false
if res.err != nil {
sc.conn.Close()
}
wr := res.wr wr := res.wr
if writeEndsStream(wr.write) { if writeEndsStream(wr.write) {
@ -1425,7 +1510,7 @@ func (sc *serverConn) goAway(code ErrCode) {
func (sc *serverConn) shutDownIn(d time.Duration) { func (sc *serverConn) shutDownIn(d time.Duration) {
sc.serveG.check() sc.serveG.check()
sc.shutdownTimer = time.AfterFunc(d, sc.onShutdownTimer) sc.shutdownTimer = sc.srv.afterFunc(d, sc.onShutdownTimer)
} }
func (sc *serverConn) resetStream(se StreamError) { func (sc *serverConn) resetStream(se StreamError) {
@ -1552,6 +1637,11 @@ func (sc *serverConn) processFrame(f Frame) error {
func (sc *serverConn) processPing(f *PingFrame) error { func (sc *serverConn) processPing(f *PingFrame) error {
sc.serveG.check() sc.serveG.check()
if f.IsAck() { if f.IsAck() {
if sc.pingSent && sc.sentPingData == f.Data {
// This is a response to a PING we sent.
sc.pingSent = false
sc.readIdleTimer.Reset(sc.readIdleTimeout)
}
// 6.7 PING: " An endpoint MUST NOT respond to PING frames // 6.7 PING: " An endpoint MUST NOT respond to PING frames
// containing this flag." // containing this flag."
return nil return nil
@ -1639,7 +1729,7 @@ func (sc *serverConn) closeStream(st *stream, err error) {
delete(sc.streams, st.id) delete(sc.streams, st.id)
if len(sc.streams) == 0 { if len(sc.streams) == 0 {
sc.setConnState(http.StateIdle) sc.setConnState(http.StateIdle)
if sc.srv.IdleTimeout > 0 { if sc.srv.IdleTimeout > 0 && sc.idleTimer != nil {
sc.idleTimer.Reset(sc.srv.IdleTimeout) sc.idleTimer.Reset(sc.srv.IdleTimeout)
} }
if h1ServerKeepAlivesDisabled(sc.hs) { if h1ServerKeepAlivesDisabled(sc.hs) {
@ -1661,6 +1751,7 @@ func (sc *serverConn) closeStream(st *stream, err error) {
} }
} }
st.closeErr = err st.closeErr = err
st.cancelCtx()
st.cw.Close() // signals Handler's CloseNotifier, unblocks writes, etc st.cw.Close() // signals Handler's CloseNotifier, unblocks writes, etc
sc.writeSched.CloseStream(st.id) sc.writeSched.CloseStream(st.id)
} }
@ -1714,6 +1805,9 @@ func (sc *serverConn) processSetting(s Setting) error {
sc.maxFrameSize = int32(s.Val) // the maximum valid s.Val is < 2^31 sc.maxFrameSize = int32(s.Val) // the maximum valid s.Val is < 2^31
case SettingMaxHeaderListSize: case SettingMaxHeaderListSize:
sc.peerMaxHeaderListSize = s.Val sc.peerMaxHeaderListSize = s.Val
case SettingEnableConnectProtocol:
// Receipt of this parameter by a server does not
// have any impact
default: default:
// Unknown setting: "An endpoint that receives a SETTINGS // Unknown setting: "An endpoint that receives a SETTINGS
// frame with any unknown or unsupported identifier MUST // frame with any unknown or unsupported identifier MUST
@ -2021,7 +2115,7 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
// (in Go 1.8), though. That's a more sane option anyway. // (in Go 1.8), though. That's a more sane option anyway.
if sc.hs.ReadTimeout > 0 { if sc.hs.ReadTimeout > 0 {
sc.conn.SetReadDeadline(time.Time{}) sc.conn.SetReadDeadline(time.Time{})
st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout) st.readDeadline = sc.srv.afterFunc(sc.hs.ReadTimeout, st.onReadTimeout)
} }
return sc.scheduleHandler(id, rw, req, handler) return sc.scheduleHandler(id, rw, req, handler)
@ -2117,9 +2211,9 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream
st.cw.Init() st.cw.Init()
st.flow.conn = &sc.flow // link to conn-level counter st.flow.conn = &sc.flow // link to conn-level counter
st.flow.add(sc.initialStreamSendWindowSize) st.flow.add(sc.initialStreamSendWindowSize)
st.inflow.init(sc.srv.initialStreamRecvWindowSize()) st.inflow.init(sc.initialStreamRecvWindowSize)
if sc.hs.WriteTimeout > 0 { if sc.hs.WriteTimeout > 0 {
st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout) st.writeDeadline = sc.srv.afterFunc(sc.hs.WriteTimeout, st.onWriteTimeout)
} }
sc.streams[id] = st sc.streams[id] = st
@ -2144,11 +2238,17 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res
scheme: f.PseudoValue("scheme"), scheme: f.PseudoValue("scheme"),
authority: f.PseudoValue("authority"), authority: f.PseudoValue("authority"),
path: f.PseudoValue("path"), path: f.PseudoValue("path"),
protocol: f.PseudoValue("protocol"),
}
// extended connect is disabled, so we should not see :protocol
if disableExtendedConnectProtocol && rp.protocol != "" {
return nil, nil, sc.countError("bad_connect", streamError(f.StreamID, ErrCodeProtocol))
} }
isConnect := rp.method == "CONNECT" isConnect := rp.method == "CONNECT"
if isConnect { if isConnect {
if rp.path != "" || rp.scheme != "" || rp.authority == "" { if rp.protocol == "" && (rp.path != "" || rp.scheme != "" || rp.authority == "") {
return nil, nil, sc.countError("bad_connect", streamError(f.StreamID, ErrCodeProtocol)) return nil, nil, sc.countError("bad_connect", streamError(f.StreamID, ErrCodeProtocol))
} }
} else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") { } else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") {
@ -2172,6 +2272,9 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res
if rp.authority == "" { if rp.authority == "" {
rp.authority = rp.header.Get("Host") rp.authority = rp.header.Get("Host")
} }
if rp.protocol != "" {
rp.header.Set(":protocol", rp.protocol)
}
rw, req, err := sc.newWriterAndRequestNoBody(st, rp) rw, req, err := sc.newWriterAndRequestNoBody(st, rp)
if err != nil { if err != nil {
@ -2198,6 +2301,7 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res
type requestParam struct { type requestParam struct {
method string method string
scheme, authority, path string scheme, authority, path string
protocol string
header http.Header header http.Header
} }
@ -2239,7 +2343,7 @@ func (sc *serverConn) newWriterAndRequestNoBody(st *stream, rp requestParam) (*r
var url_ *url.URL var url_ *url.URL
var requestURI string var requestURI string
if rp.method == "CONNECT" { if rp.method == "CONNECT" && rp.protocol == "" {
url_ = &url.URL{Host: rp.authority} url_ = &url.URL{Host: rp.authority}
requestURI = rp.authority // mimic HTTP/1 server behavior requestURI = rp.authority // mimic HTTP/1 server behavior
} else { } else {
@ -2343,6 +2447,7 @@ func (sc *serverConn) handlerDone() {
// Run on its own goroutine. // Run on its own goroutine.
func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) { func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) {
sc.srv.markNewGoroutine()
defer sc.sendServeMsg(handlerDoneMsg) defer sc.sendServeMsg(handlerDoneMsg)
didPanic := true didPanic := true
defer func() { defer func() {
@ -2639,7 +2744,7 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
var date string var date string
if _, ok := rws.snapHeader["Date"]; !ok { if _, ok := rws.snapHeader["Date"]; !ok {
// TODO(bradfitz): be faster here, like net/http? measure. // TODO(bradfitz): be faster here, like net/http? measure.
date = time.Now().UTC().Format(http.TimeFormat) date = rws.conn.srv.now().UTC().Format(http.TimeFormat)
} }
for _, v := range rws.snapHeader["Trailer"] { for _, v := range rws.snapHeader["Trailer"] {
@ -2761,7 +2866,7 @@ func (rws *responseWriterState) promoteUndeclaredTrailers() {
func (w *responseWriter) SetReadDeadline(deadline time.Time) error { func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
st := w.rws.stream st := w.rws.stream
if !deadline.IsZero() && deadline.Before(time.Now()) { if !deadline.IsZero() && deadline.Before(w.rws.conn.srv.now()) {
// If we're setting a deadline in the past, reset the stream immediately // If we're setting a deadline in the past, reset the stream immediately
// so writes after SetWriteDeadline returns will fail. // so writes after SetWriteDeadline returns will fail.
st.onReadTimeout() st.onReadTimeout()
@ -2777,9 +2882,9 @@ func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
if deadline.IsZero() { if deadline.IsZero() {
st.readDeadline = nil st.readDeadline = nil
} else if st.readDeadline == nil { } else if st.readDeadline == nil {
st.readDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onReadTimeout) st.readDeadline = sc.srv.afterFunc(deadline.Sub(sc.srv.now()), st.onReadTimeout)
} else { } else {
st.readDeadline.Reset(deadline.Sub(time.Now())) st.readDeadline.Reset(deadline.Sub(sc.srv.now()))
} }
}) })
return nil return nil
@ -2787,7 +2892,7 @@ func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
func (w *responseWriter) SetWriteDeadline(deadline time.Time) error { func (w *responseWriter) SetWriteDeadline(deadline time.Time) error {
st := w.rws.stream st := w.rws.stream
if !deadline.IsZero() && deadline.Before(time.Now()) { if !deadline.IsZero() && deadline.Before(w.rws.conn.srv.now()) {
// If we're setting a deadline in the past, reset the stream immediately // If we're setting a deadline in the past, reset the stream immediately
// so writes after SetWriteDeadline returns will fail. // so writes after SetWriteDeadline returns will fail.
st.onWriteTimeout() st.onWriteTimeout()
@ -2803,14 +2908,19 @@ func (w *responseWriter) SetWriteDeadline(deadline time.Time) error {
if deadline.IsZero() { if deadline.IsZero() {
st.writeDeadline = nil st.writeDeadline = nil
} else if st.writeDeadline == nil { } else if st.writeDeadline == nil {
st.writeDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onWriteTimeout) st.writeDeadline = sc.srv.afterFunc(deadline.Sub(sc.srv.now()), st.onWriteTimeout)
} else { } else {
st.writeDeadline.Reset(deadline.Sub(time.Now())) st.writeDeadline.Reset(deadline.Sub(sc.srv.now()))
} }
}) })
return nil return nil
} }
func (w *responseWriter) EnableFullDuplex() error {
// We always support full duplex responses, so this is a no-op.
return nil
}
func (w *responseWriter) Flush() { func (w *responseWriter) Flush() {
w.FlushError() w.FlushError()
} }
@ -3257,7 +3367,7 @@ func (sc *serverConn) countError(name string, err error) error {
if sc == nil || sc.srv == nil { if sc == nil || sc.srv == nil {
return err return err
} }
f := sc.srv.CountError f := sc.countErrorFunc
if f == nil { if f == nil {
return err return err
} }

View File

@ -1,331 +0,0 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http2
import (
"context"
"sync"
"time"
)
// testSyncHooks coordinates goroutines in tests.
//
// For example, a call to ClientConn.RoundTrip involves several goroutines, including:
// - the goroutine running RoundTrip;
// - the clientStream.doRequest goroutine, which writes the request; and
// - the clientStream.readLoop goroutine, which reads the response.
//
// Using testSyncHooks, a test can start a RoundTrip and identify when all these goroutines
// are blocked waiting for some condition such as reading the Request.Body or waiting for
// flow control to become available.
//
// The testSyncHooks also manage timers and synthetic time in tests.
// This permits us to, for example, start a request and cause it to time out waiting for
// response headers without resorting to time.Sleep calls.
type testSyncHooks struct {
// active/inactive act as a mutex and condition variable.
//
// - neither chan contains a value: testSyncHooks is locked.
// - active contains a value: unlocked, and at least one goroutine is not blocked
// - inactive contains a value: unlocked, and all goroutines are blocked
active chan struct{}
inactive chan struct{}
// goroutine counts
total int // total goroutines
condwait map[*sync.Cond]int // blocked in sync.Cond.Wait
blocked []*testBlockedGoroutine // otherwise blocked
// fake time
now time.Time
timers []*fakeTimer
// Transport testing: Report various events.
newclientconn func(*ClientConn)
newstream func(*clientStream)
}
// testBlockedGoroutine is a blocked goroutine.
type testBlockedGoroutine struct {
f func() bool // blocked until f returns true
ch chan struct{} // closed when unblocked
}
func newTestSyncHooks() *testSyncHooks {
h := &testSyncHooks{
active: make(chan struct{}, 1),
inactive: make(chan struct{}, 1),
condwait: map[*sync.Cond]int{},
}
h.inactive <- struct{}{}
h.now = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)
return h
}
// lock acquires the testSyncHooks mutex.
func (h *testSyncHooks) lock() {
select {
case <-h.active:
case <-h.inactive:
}
}
// waitInactive waits for all goroutines to become inactive.
func (h *testSyncHooks) waitInactive() {
for {
<-h.inactive
if !h.unlock() {
break
}
}
}
// unlock releases the testSyncHooks mutex.
// It reports whether any goroutines are active.
func (h *testSyncHooks) unlock() (active bool) {
// Look for a blocked goroutine which can be unblocked.
blocked := h.blocked[:0]
unblocked := false
for _, b := range h.blocked {
if !unblocked && b.f() {
unblocked = true
close(b.ch)
} else {
blocked = append(blocked, b)
}
}
h.blocked = blocked
// Count goroutines blocked on condition variables.
condwait := 0
for _, count := range h.condwait {
condwait += count
}
if h.total > condwait+len(blocked) {
h.active <- struct{}{}
return true
} else {
h.inactive <- struct{}{}
return false
}
}
// goRun starts a new goroutine.
func (h *testSyncHooks) goRun(f func()) {
h.lock()
h.total++
h.unlock()
go func() {
defer func() {
h.lock()
h.total--
h.unlock()
}()
f()
}()
}
// blockUntil indicates that a goroutine is blocked waiting for some condition to become true.
// It waits until f returns true before proceeding.
//
// Example usage:
//
// h.blockUntil(func() bool {
// // Is the context done yet?
// select {
// case <-ctx.Done():
// default:
// return false
// }
// return true
// })
// // Wait for the context to become done.
// <-ctx.Done()
//
// The function f passed to blockUntil must be non-blocking and idempotent.
func (h *testSyncHooks) blockUntil(f func() bool) {
if f() {
return
}
ch := make(chan struct{})
h.lock()
h.blocked = append(h.blocked, &testBlockedGoroutine{
f: f,
ch: ch,
})
h.unlock()
<-ch
}
// broadcast is sync.Cond.Broadcast.
func (h *testSyncHooks) condBroadcast(cond *sync.Cond) {
h.lock()
delete(h.condwait, cond)
h.unlock()
cond.Broadcast()
}
// broadcast is sync.Cond.Wait.
func (h *testSyncHooks) condWait(cond *sync.Cond) {
h.lock()
h.condwait[cond]++
h.unlock()
}
// newTimer creates a new fake timer.
func (h *testSyncHooks) newTimer(d time.Duration) timer {
h.lock()
defer h.unlock()
t := &fakeTimer{
hooks: h,
when: h.now.Add(d),
c: make(chan time.Time),
}
h.timers = append(h.timers, t)
return t
}
// afterFunc creates a new fake AfterFunc timer.
func (h *testSyncHooks) afterFunc(d time.Duration, f func()) timer {
h.lock()
defer h.unlock()
t := &fakeTimer{
hooks: h,
when: h.now.Add(d),
f: f,
}
h.timers = append(h.timers, t)
return t
}
func (h *testSyncHooks) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithCancel(ctx)
t := h.afterFunc(d, cancel)
return ctx, func() {
t.Stop()
cancel()
}
}
func (h *testSyncHooks) timeUntilEvent() time.Duration {
h.lock()
defer h.unlock()
var next time.Time
for _, t := range h.timers {
if next.IsZero() || t.when.Before(next) {
next = t.when
}
}
if d := next.Sub(h.now); d > 0 {
return d
}
return 0
}
// advance advances time and causes synthetic timers to fire.
func (h *testSyncHooks) advance(d time.Duration) {
h.lock()
defer h.unlock()
h.now = h.now.Add(d)
timers := h.timers[:0]
for _, t := range h.timers {
t := t // remove after go.mod depends on go1.22
t.mu.Lock()
switch {
case t.when.After(h.now):
timers = append(timers, t)
case t.when.IsZero():
// stopped timer
default:
t.when = time.Time{}
if t.c != nil {
close(t.c)
}
if t.f != nil {
h.total++
go func() {
defer func() {
h.lock()
h.total--
h.unlock()
}()
t.f()
}()
}
}
t.mu.Unlock()
}
h.timers = timers
}
// A timer wraps a time.Timer, or a synthetic equivalent in tests.
// Unlike time.Timer, timer is single-use: The timer channel is closed when the timer expires.
type timer interface {
C() <-chan time.Time
Stop() bool
Reset(d time.Duration) bool
}
// timeTimer implements timer using real time.
type timeTimer struct {
t *time.Timer
c chan time.Time
}
// newTimeTimer creates a new timer using real time.
func newTimeTimer(d time.Duration) timer {
ch := make(chan time.Time)
t := time.AfterFunc(d, func() {
close(ch)
})
return &timeTimer{t, ch}
}
// newTimeAfterFunc creates an AfterFunc timer using real time.
func newTimeAfterFunc(d time.Duration, f func()) timer {
return &timeTimer{
t: time.AfterFunc(d, f),
}
}
func (t timeTimer) C() <-chan time.Time { return t.c }
func (t timeTimer) Stop() bool { return t.t.Stop() }
func (t timeTimer) Reset(d time.Duration) bool { return t.t.Reset(d) }
// fakeTimer implements timer using fake time.
type fakeTimer struct {
hooks *testSyncHooks
mu sync.Mutex
when time.Time // when the timer will fire
c chan time.Time // closed when the timer fires; mutually exclusive with f
f func() // called when the timer fires; mutually exclusive with c
}
func (t *fakeTimer) C() <-chan time.Time { return t.c }
func (t *fakeTimer) Stop() bool {
t.mu.Lock()
defer t.mu.Unlock()
stopped := t.when.IsZero()
t.when = time.Time{}
return stopped
}
func (t *fakeTimer) Reset(d time.Duration) bool {
if t.c != nil || t.f == nil {
panic("fakeTimer only supports Reset on AfterFunc timers")
}
t.mu.Lock()
defer t.mu.Unlock()
t.hooks.lock()
defer t.hooks.unlock()
active := !t.when.IsZero()
t.when = t.hooks.now.Add(d)
if !active {
t.hooks.timers = append(t.hooks.timers, t)
}
return active
}

20
vendor/golang.org/x/net/http2/timer.go generated vendored Normal file
View File

@ -0,0 +1,20 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http2
import "time"
// A timer is a time.Timer, as an interface which can be replaced in tests.
type timer = interface {
C() <-chan time.Time
Reset(d time.Duration) bool
Stop() bool
}
// timeTimer adapts a time.Timer to the timer interface.
type timeTimer struct {
*time.Timer
}
func (t timeTimer) C() <-chan time.Time { return t.Timer.C }

File diff suppressed because it is too large Load Diff

32
vendor/golang.org/x/net/http2/unencrypted.go generated vendored Normal file
View File

@ -0,0 +1,32 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http2
import (
"crypto/tls"
"errors"
"net"
)
const nextProtoUnencryptedHTTP2 = "unencrypted_http2"
// unencryptedNetConnFromTLSConn retrieves a net.Conn wrapped in a *tls.Conn.
//
// TLSNextProto functions accept a *tls.Conn.
//
// When passing an unencrypted HTTP/2 connection to a TLSNextProto function,
// we pass a *tls.Conn with an underlying net.Conn containing the unencrypted connection.
// To be extra careful about mistakes (accidentally dropping TLS encryption in a place
// where we want it), the tls.Conn contains a net.Conn with an UnencryptedNetConn method
// that returns the actual connection we want to use.
func unencryptedNetConnFromTLSConn(tc *tls.Conn) (net.Conn, error) {
conner, ok := tc.NetConn().(interface {
UnencryptedNetConn() net.Conn
})
if !ok {
return nil, errors.New("http2: TLS conn unexpectedly found in unencrypted handoff")
}
return conner.UnencryptedNetConn(), nil
}

View File

@ -131,6 +131,16 @@ func (se StreamError) writeFrame(ctx writeContext) error {
func (se StreamError) staysWithinBuffer(max int) bool { return frameHeaderLen+4 <= max } func (se StreamError) staysWithinBuffer(max int) bool { return frameHeaderLen+4 <= max }
type writePing struct {
data [8]byte
}
func (w writePing) writeFrame(ctx writeContext) error {
return ctx.Framer().WritePing(false, w.data)
}
func (w writePing) staysWithinBuffer(max int) bool { return frameHeaderLen+len(w.data) <= max }
type writePingAck struct{ pf *PingFrame } type writePingAck struct{ pf *PingFrame }
func (w writePingAck) writeFrame(ctx writeContext) error { func (w writePingAck) writeFrame(ctx writeContext) error {

View File

@ -443,8 +443,8 @@ func (ws *priorityWriteScheduler) addClosedOrIdleNode(list *[]*priorityNode, max
} }
func (ws *priorityWriteScheduler) removeNode(n *priorityNode) { func (ws *priorityWriteScheduler) removeNode(n *priorityNode) {
for k := n.kids; k != nil; k = k.next { for n.kids != nil {
k.setParent(n.parent) n.kids.setParent(n.parent)
} }
n.setParent(nil) n.setParent(nil)
delete(ws.nodes, n.id) delete(ws.nodes, n.id)

View File

@ -137,9 +137,7 @@ func (p *PerHost) AddNetwork(net *net.IPNet) {
// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of // AddZone specifies a DNS suffix that will use the bypass proxy. A zone of
// "example.com" matches "example.com" and all of its subdomains. // "example.com" matches "example.com" and all of its subdomains.
func (p *PerHost) AddZone(zone string) { func (p *PerHost) AddZone(zone string) {
if strings.HasSuffix(zone, ".") { zone = strings.TrimSuffix(zone, ".")
zone = zone[:len(zone)-1]
}
if !strings.HasPrefix(zone, ".") { if !strings.HasPrefix(zone, ".") {
zone = "." + zone zone = "." + zone
} }
@ -148,8 +146,6 @@ func (p *PerHost) AddZone(zone string) {
// AddHost specifies a host name that will use the bypass proxy. // AddHost specifies a host name that will use the bypass proxy.
func (p *PerHost) AddHost(host string) { func (p *PerHost) AddHost(host string) {
if strings.HasSuffix(host, ".") { host = strings.TrimSuffix(host, ".")
host = host[:len(host)-1]
}
p.bypassHosts = append(p.bypassHosts, host) p.bypassHosts = append(p.bypassHosts, host)
} }

View File

@ -16,7 +16,6 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -279,7 +278,7 @@ func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (frameReader, er
} }
} }
if header := frame.HeaderReader(); header != nil { if header := frame.HeaderReader(); header != nil {
io.Copy(ioutil.Discard, header) io.Copy(io.Discard, header)
} }
switch frame.PayloadType() { switch frame.PayloadType() {
case ContinuationFrame: case ContinuationFrame:
@ -294,7 +293,7 @@ func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (frameReader, er
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
return nil, err return nil, err
} }
io.Copy(ioutil.Discard, frame) io.Copy(io.Discard, frame)
if frame.PayloadType() == PingFrame { if frame.PayloadType() == PingFrame {
if _, err := handler.WritePong(b[:n]); err != nil { if _, err := handler.WritePong(b[:n]); err != nil {
return nil, err return nil, err

View File

@ -8,7 +8,7 @@
// This package currently lacks some features found in an alternative // This package currently lacks some features found in an alternative
// and more actively maintained WebSocket package: // and more actively maintained WebSocket package:
// //
// https://pkg.go.dev/nhooyr.io/websocket // https://pkg.go.dev/github.com/coder/websocket
package websocket // import "golang.org/x/net/websocket" package websocket // import "golang.org/x/net/websocket"
import ( import (
@ -17,7 +17,6 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"io" "io"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
@ -208,7 +207,7 @@ again:
n, err = ws.frameReader.Read(msg) n, err = ws.frameReader.Read(msg)
if err == io.EOF { if err == io.EOF {
if trailer := ws.frameReader.TrailerReader(); trailer != nil { if trailer := ws.frameReader.TrailerReader(); trailer != nil {
io.Copy(ioutil.Discard, trailer) io.Copy(io.Discard, trailer)
} }
ws.frameReader = nil ws.frameReader = nil
goto again goto again
@ -330,7 +329,7 @@ func (cd Codec) Receive(ws *Conn, v interface{}) (err error) {
ws.rio.Lock() ws.rio.Lock()
defer ws.rio.Unlock() defer ws.rio.Unlock()
if ws.frameReader != nil { if ws.frameReader != nil {
_, err = io.Copy(ioutil.Discard, ws.frameReader) _, err = io.Copy(io.Discard, ws.frameReader)
if err != nil { if err != nil {
return err return err
} }
@ -362,7 +361,7 @@ again:
return ErrFrameTooLarge return ErrFrameTooLarge
} }
payloadType := frame.PayloadType() payloadType := frame.PayloadType()
data, err := ioutil.ReadAll(frame) data, err := io.ReadAll(frame)
if err != nil { if err != nil {
return err return err
} }

2
vendor/modules.txt vendored
View File

@ -734,7 +734,7 @@ golang.org/x/crypto/ssh/internal/bcrypt_pbkdf
golang.org/x/exp/constraints golang.org/x/exp/constraints
golang.org/x/exp/maps golang.org/x/exp/maps
golang.org/x/exp/slices golang.org/x/exp/slices
# golang.org/x/net v0.25.0 # golang.org/x/net v0.34.0
## explicit; go 1.18 ## explicit; go 1.18
golang.org/x/net/context golang.org/x/net/context
golang.org/x/net/html golang.org/x/net/html