diff options
author | Christopher Speller <crspeller@gmail.com> | 2016-11-16 19:28:52 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-11-16 19:28:52 -0500 |
commit | 0135904f7d3e1c0e763adaefe267c736616e3d26 (patch) | |
tree | c27be7588f98eaea62e0bd0c0087f2b348da9738 /vendor/golang.org/x/net/http2 | |
parent | 0b296dd8c2aefefe89787be5cc627d44cf431150 (diff) | |
download | chat-0135904f7d3e1c0e763adaefe267c736616e3d26.tar.gz chat-0135904f7d3e1c0e763adaefe267c736616e3d26.tar.bz2 chat-0135904f7d3e1c0e763adaefe267c736616e3d26.zip |
Upgrading server dependancies (#4566)
Diffstat (limited to 'vendor/golang.org/x/net/http2')
22 files changed, 3386 insertions, 790 deletions
diff --git a/vendor/golang.org/x/net/http2/go17.go b/vendor/golang.org/x/net/http2/go17.go index 730319dd5..47b7fae08 100644 --- a/vendor/golang.org/x/net/http2/go17.go +++ b/vendor/golang.org/x/net/http2/go17.go @@ -39,6 +39,13 @@ type clientTrace httptrace.ClientTrace func reqContext(r *http.Request) context.Context { return r.Context() } +func (t *Transport) idleConnTimeout() time.Duration { + if t.t1 != nil { + return t.t1.IdleConnTimeout + } + return 0 +} + func setResponseUncompressed(res *http.Response) { res.Uncompressed = true } func traceGotConn(req *http.Request, cc *ClientConn) { @@ -92,3 +99,8 @@ func requestTrace(req *http.Request) *clientTrace { trace := httptrace.ContextClientTrace(req.Context()) return (*clientTrace)(trace) } + +// Ping sends a PING frame to the server and waits for the ack. +func (cc *ClientConn) Ping(ctx context.Context) error { + return cc.ping(ctx) +} diff --git a/vendor/golang.org/x/net/http2/go18.go b/vendor/golang.org/x/net/http2/go18.go index c2ae16731..8c0dd2508 100644 --- a/vendor/golang.org/x/net/http2/go18.go +++ b/vendor/golang.org/x/net/http2/go18.go @@ -6,6 +6,36 @@ package http2 -import "crypto/tls" +import ( + "crypto/tls" + "net/http" +) func cloneTLSConfig(c *tls.Config) *tls.Config { return c.Clone() } + +var _ http.Pusher = (*responseWriter)(nil) + +// Push implements http.Pusher. +func (w *responseWriter) Push(target string, opts *http.PushOptions) error { + internalOpts := pushOptions{} + if opts != nil { + internalOpts.Method = opts.Method + internalOpts.Header = opts.Header + } + return w.push(target, internalOpts) +} + +func configureServer18(h1 *http.Server, h2 *Server) error { + if h2.IdleTimeout == 0 { + if h1.IdleTimeout != 0 { + h2.IdleTimeout = h1.IdleTimeout + } else { + h2.IdleTimeout = h1.ReadTimeout + } + } + return nil +} + +func shouldLogPanic(panicValue interface{}) bool { + return panicValue != nil && panicValue != http.ErrAbortHandler +} diff --git a/vendor/golang.org/x/net/http2/go18_test.go b/vendor/golang.org/x/net/http2/go18_test.go new file mode 100644 index 000000000..836550597 --- /dev/null +++ b/vendor/golang.org/x/net/http2/go18_test.go @@ -0,0 +1,66 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.8 + +package http2 + +import ( + "net/http" + "testing" + "time" +) + +// Tests that http2.Server.IdleTimeout is initialized from +// http.Server.{Idle,Read}Timeout. http.Server.IdleTimeout was +// added in Go 1.8. +func TestConfigureServerIdleTimeout_Go18(t *testing.T) { + const timeout = 5 * time.Second + const notThisOne = 1 * time.Second + + // With a zero http2.Server, verify that it copies IdleTimeout: + { + s1 := &http.Server{ + IdleTimeout: timeout, + ReadTimeout: notThisOne, + } + s2 := &Server{} + if err := ConfigureServer(s1, s2); err != nil { + t.Fatal(err) + } + if s2.IdleTimeout != timeout { + t.Errorf("s2.IdleTimeout = %v; want %v", s2.IdleTimeout, timeout) + } + } + + // And that it falls back to ReadTimeout: + { + s1 := &http.Server{ + ReadTimeout: timeout, + } + s2 := &Server{} + if err := ConfigureServer(s1, s2); err != nil { + t.Fatal(err) + } + if s2.IdleTimeout != timeout { + t.Errorf("s2.IdleTimeout = %v; want %v", s2.IdleTimeout, timeout) + } + } + + // Verify that s1's IdleTimeout doesn't overwrite an existing setting: + { + s1 := &http.Server{ + IdleTimeout: notThisOne, + } + s2 := &Server{ + IdleTimeout: timeout, + } + if err := ConfigureServer(s1, s2); err != nil { + t.Fatal(err) + } + if s2.IdleTimeout != timeout { + t.Errorf("s2.IdleTimeout = %v; want %v", s2.IdleTimeout, timeout) + } + } +} diff --git a/vendor/golang.org/x/net/http2/h2demo/h2demo.go b/vendor/golang.org/x/net/http2/h2demo/h2demo.go index a248d479e..980b6d67d 100644 --- a/vendor/golang.org/x/net/http2/h2demo/h2demo.go +++ b/vendor/golang.org/x/net/http2/h2demo/h2demo.go @@ -19,6 +19,7 @@ import ( "log" "net" "net/http" + "os" "path" "regexp" "runtime" @@ -27,8 +28,8 @@ import ( "sync" "time" - "camlistore.org/pkg/googlestorage" "go4.org/syncutil/singleflight" + "golang.org/x/crypto/acme/autocert" "golang.org/x/net/http2" ) @@ -378,37 +379,18 @@ func httpHost() string { } func serveProdTLS() error { - c, err := googlestorage.NewServiceClient() - if err != nil { + const cacheDir = "/var/cache/autocert" + if err := os.MkdirAll(cacheDir, 0700); err != nil { return err } - slurp := func(key string) ([]byte, error) { - const bucket = "http2-demo-server-tls" - rc, _, err := c.GetObject(&googlestorage.Object{ - Bucket: bucket, - Key: key, - }) - if err != nil { - return nil, fmt.Errorf("Error fetching GCS object %q in bucket %q: %v", key, bucket, err) - } - defer rc.Close() - return ioutil.ReadAll(rc) - } - certPem, err := slurp("http2.golang.org.chained.pem") - if err != nil { - return err - } - keyPem, err := slurp("http2.golang.org.key") - if err != nil { - return err - } - cert, err := tls.X509KeyPair(certPem, keyPem) - if err != nil { - return err + m := autocert.Manager{ + Cache: autocert.DirCache(cacheDir), + Prompt: autocert.AcceptTOS, + HostPolicy: autocert.HostWhitelist("http2.golang.org"), } srv := &http.Server{ TLSConfig: &tls.Config{ - Certificates: []tls.Certificate{cert}, + GetCertificate: m.GetCertificate, }, } http2.ConfigureServer(srv, &http2.Server{}) diff --git a/vendor/golang.org/x/net/http2/h2i/h2i.go b/vendor/golang.org/x/net/http2/h2i/h2i.go index b70976f77..228edf8a4 100644 --- a/vendor/golang.org/x/net/http2/h2i/h2i.go +++ b/vendor/golang.org/x/net/http2/h2i/h2i.go @@ -168,7 +168,7 @@ func (app *h2i) Main() error { app.framer = http2.NewFramer(tc, tc) - oldState, err := terminal.MakeRaw(0) + oldState, err := terminal.MakeRaw(int(os.Stdin.Fd())) if err != nil { return err } @@ -238,7 +238,7 @@ func (app *h2i) Main() error { } func (app *h2i) logf(format string, args ...interface{}) { - fmt.Fprintf(app.term, format+"\n", args...) + fmt.Fprintf(app.term, format+"\r\n", args...) } func (app *h2i) readConsole() error { @@ -435,9 +435,9 @@ func (app *h2i) readFrames() error { return nil }) case *http2.WindowUpdateFrame: - app.logf(" Window-Increment = %v\n", f.Increment) + app.logf(" Window-Increment = %v", f.Increment) case *http2.GoAwayFrame: - app.logf(" Last-Stream-ID = %d; Error-Code = %v (%d)\n", f.LastStreamID, f.ErrCode, f.ErrCode) + app.logf(" Last-Stream-ID = %d; Error-Code = %v (%d)", f.LastStreamID, f.ErrCode, f.ErrCode) case *http2.DataFrame: app.logf(" %q", f.Data()) case *http2.HeadersFrame: diff --git a/vendor/golang.org/x/net/http2/http2.go b/vendor/golang.org/x/net/http2/http2.go index 2e27b093c..b6b0f9ad1 100644 --- a/vendor/golang.org/x/net/http2/http2.go +++ b/vendor/golang.org/x/net/http2/http2.go @@ -36,6 +36,7 @@ var ( VerboseLogs bool logFrameWrites bool logFrameReads bool + inTests bool ) func init() { @@ -77,13 +78,23 @@ var ( type streamState int +// HTTP/2 stream states. +// +// See http://tools.ietf.org/html/rfc7540#section-5.1. +// +// For simplicity, the server code merges "reserved (local)" into +// "half-closed (remote)". This is one less state transition to track. +// The only downside is that we send PUSH_PROMISEs slightly less +// liberally than allowable. More discussion here: +// https://lists.w3.org/Archives/Public/ietf-http-wg/2016JulSep/0599.html +// +// "reserved (remote)" is omitted since the client code does not +// support server push. const ( stateIdle streamState = iota stateOpen stateHalfClosedLocal stateHalfClosedRemote - stateResvLocal - stateResvRemote stateClosed ) @@ -92,8 +103,6 @@ var stateName = [...]string{ stateOpen: "Open", stateHalfClosedLocal: "HalfClosedLocal", stateHalfClosedRemote: "HalfClosedRemote", - stateResvLocal: "ResvLocal", - stateResvRemote: "ResvRemote", stateClosed: "Closed", } @@ -253,14 +262,27 @@ func newBufferedWriter(w io.Writer) *bufferedWriter { return &bufferedWriter{w: w} } +// bufWriterPoolBufferSize is the size of bufio.Writer's +// buffers created using bufWriterPool. +// +// TODO: pick a less arbitrary value? this is a bit under +// (3 x typical 1500 byte MTU) at least. Other than that, +// not much thought went into it. +const bufWriterPoolBufferSize = 4 << 10 + var bufWriterPool = sync.Pool{ New: func() interface{} { - // TODO: pick something better? this is a bit under - // (3 x typical 1500 byte MTU) at least. - return bufio.NewWriterSize(nil, 4<<10) + return bufio.NewWriterSize(nil, bufWriterPoolBufferSize) }, } +func (w *bufferedWriter) Available() int { + if w.bw == nil { + return bufWriterPoolBufferSize + } + return w.bw.Available() +} + func (w *bufferedWriter) Write(p []byte) (n int, err error) { if w.bw == nil { bw := bufWriterPool.Get().(*bufio.Writer) diff --git a/vendor/golang.org/x/net/http2/http2_test.go b/vendor/golang.org/x/net/http2/http2_test.go index 22c2ace82..524877647 100644 --- a/vendor/golang.org/x/net/http2/http2_test.go +++ b/vendor/golang.org/x/net/http2/http2_test.go @@ -27,6 +27,7 @@ func condSkipFailingTest(t *testing.T) { } func init() { + inTests = true DebugGoroutines = true flag.BoolVar(&VerboseLogs, "verboseh2", VerboseLogs, "Verbose HTTP/2 debug logging") } diff --git a/vendor/golang.org/x/net/http2/not_go17.go b/vendor/golang.org/x/net/http2/not_go17.go index 667867f4d..140434a79 100644 --- a/vendor/golang.org/x/net/http2/not_go17.go +++ b/vendor/golang.org/x/net/http2/not_go17.go @@ -10,9 +10,13 @@ import ( "crypto/tls" "net" "net/http" + "time" ) -type contextContext interface{} +type contextContext interface { + Done() <-chan struct{} + Err() error +} type fakeContext struct{} @@ -75,3 +79,9 @@ func cloneTLSConfig(c *tls.Config) *tls.Config { CurvePreferences: c.CurvePreferences, } } + +func (cc *ClientConn) Ping(ctx contextContext) error { + return cc.ping(ctx) +} + +func (t *Transport) idleConnTimeout() time.Duration { return 0 } diff --git a/vendor/golang.org/x/net/http2/not_go18.go b/vendor/golang.org/x/net/http2/not_go18.go new file mode 100644 index 000000000..2e600dc35 --- /dev/null +++ b/vendor/golang.org/x/net/http2/not_go18.go @@ -0,0 +1,18 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !go1.8 + +package http2 + +import "net/http" + +func configureServer18(h1 *http.Server, h2 *Server) error { + // No IdleTimeout to sync prior to Go 1.8. + return nil +} + +func shouldLogPanic(panicValue interface{}) bool { + return panicValue != nil +} diff --git a/vendor/golang.org/x/net/http2/priority_test.go b/vendor/golang.org/x/net/http2/priority_test.go deleted file mode 100644 index a3fe2bb49..000000000 --- a/vendor/golang.org/x/net/http2/priority_test.go +++ /dev/null @@ -1,118 +0,0 @@ -// Copyright 2014 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package http2 - -import ( - "testing" -) - -func TestPriority(t *testing.T) { - // A -> B - // move A's parent to B - streams := make(map[uint32]*stream) - a := &stream{ - parent: nil, - weight: 16, - } - streams[1] = a - b := &stream{ - parent: a, - weight: 16, - } - streams[2] = b - adjustStreamPriority(streams, 1, PriorityParam{ - Weight: 20, - StreamDep: 2, - }) - if a.parent != b { - t.Errorf("Expected A's parent to be B") - } - if a.weight != 20 { - t.Errorf("Expected A's weight to be 20; got %d", a.weight) - } - if b.parent != nil { - t.Errorf("Expected B to have no parent") - } - if b.weight != 16 { - t.Errorf("Expected B's weight to be 16; got %d", b.weight) - } -} - -func TestPriorityExclusiveZero(t *testing.T) { - // A B and C are all children of the 0 stream. - // Exclusive reprioritization to any of the streams - // should bring the rest of the streams under the - // reprioritized stream - streams := make(map[uint32]*stream) - a := &stream{ - parent: nil, - weight: 16, - } - streams[1] = a - b := &stream{ - parent: nil, - weight: 16, - } - streams[2] = b - c := &stream{ - parent: nil, - weight: 16, - } - streams[3] = c - adjustStreamPriority(streams, 3, PriorityParam{ - Weight: 20, - StreamDep: 0, - Exclusive: true, - }) - if a.parent != c { - t.Errorf("Expected A's parent to be C") - } - if a.weight != 16 { - t.Errorf("Expected A's weight to be 16; got %d", a.weight) - } - if b.parent != c { - t.Errorf("Expected B's parent to be C") - } - if b.weight != 16 { - t.Errorf("Expected B's weight to be 16; got %d", b.weight) - } - if c.parent != nil { - t.Errorf("Expected C to have no parent") - } - if c.weight != 20 { - t.Errorf("Expected C's weight to be 20; got %d", b.weight) - } -} - -func TestPriorityOwnParent(t *testing.T) { - streams := make(map[uint32]*stream) - a := &stream{ - parent: nil, - weight: 16, - } - streams[1] = a - b := &stream{ - parent: a, - weight: 16, - } - streams[2] = b - adjustStreamPriority(streams, 1, PriorityParam{ - Weight: 20, - StreamDep: 1, - }) - if a.parent != nil { - t.Errorf("Expected A's parent to be nil") - } - if a.weight != 20 { - t.Errorf("Expected A's weight to be 20; got %d", a.weight) - } - if b.parent != a { - t.Errorf("Expected B's parent to be A") - } - if b.weight != 16 { - t.Errorf("Expected B's weight to be 16; got %d", b.weight) - } - -} diff --git a/vendor/golang.org/x/net/http2/server.go b/vendor/golang.org/x/net/http2/server.go index 8206fa79d..0b6b4b08d 100644 --- a/vendor/golang.org/x/net/http2/server.go +++ b/vendor/golang.org/x/net/http2/server.go @@ -2,17 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// TODO: replace all <-sc.doneServing with reads from the stream's cw -// instead, and make sure that on close we close all open -// streams. then remove doneServing? - -// TODO: re-audit GOAWAY support. Consider each incoming frame type and -// whether it should be ignored during graceful shutdown. - -// TODO: disconnect idle clients. GFE seems to do 4 minutes. make -// configurable? or maximum number of idle clients and remove the -// oldest? - // TODO: turn off the serve goroutine when idle, so // an idle conn only has the readFrames goroutine active. (which could // also be optimized probably to pin less memory in crypto/tls). This @@ -44,6 +33,7 @@ import ( "fmt" "io" "log" + "math" "net" "net/http" "net/textproto" @@ -114,6 +104,15 @@ type Server struct { // PermitProhibitedCipherSuites, if true, permits the use of // cipher suites prohibited by the HTTP/2 spec. PermitProhibitedCipherSuites bool + + // IdleTimeout specifies how long until idle clients should be + // closed with a GOAWAY frame. PING frames are not considered + // activity for the purposes of IdleTimeout. + IdleTimeout time.Duration + + // NewWriteScheduler constructs a write scheduler for a connection. + // If nil, a default scheduler is chosen. + NewWriteScheduler func() WriteScheduler } func (s *Server) maxReadFrameSize() uint32 { @@ -136,9 +135,15 @@ func (s *Server) maxConcurrentStreams() uint32 { // // ConfigureServer must be called before s begins serving. func ConfigureServer(s *http.Server, conf *Server) error { + if s == nil { + panic("nil *http.Server") + } if conf == nil { conf = new(Server) } + if err := configureServer18(s, conf); err != nil { + return err + } if s.TLSConfig == nil { s.TLSConfig = new(tls.Config) @@ -183,9 +188,6 @@ func ConfigureServer(s *http.Server, conf *Server) error { if !haveNPN { s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, NextProtoTLS) } - // h2-14 is temporary (as of 2015-03-05) while we wait for all browsers - // to switch to "h2". - s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "h2-14") if s.TLSNextProto == nil { s.TLSNextProto = map[string]func(*http.Server, *tls.Conn, http.Handler){} @@ -200,7 +202,6 @@ func ConfigureServer(s *http.Server, conf *Server) error { }) } s.TLSNextProto[NextProtoTLS] = protoHandler - s.TLSNextProto["h2-14"] = protoHandler // temporary; see above. return nil } @@ -254,29 +255,35 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) { defer cancel() sc := &serverConn{ - srv: s, - hs: opts.baseConfig(), - conn: c, - baseCtx: baseCtx, - remoteAddrStr: c.RemoteAddr().String(), - bw: newBufferedWriter(c), - handler: opts.handler(), - streams: make(map[uint32]*stream), - readFrameCh: make(chan readFrameResult), - wantWriteFrameCh: make(chan frameWriteMsg, 8), - wroteFrameCh: make(chan frameWriteResult, 1), // buffered; one send in writeFrameAsync - bodyReadCh: make(chan bodyReadMsg), // buffering doesn't matter either way - doneServing: make(chan struct{}), - advMaxStreams: s.maxConcurrentStreams(), - writeSched: writeScheduler{ - maxFrameSize: initialMaxFrameSize, - }, + srv: s, + hs: opts.baseConfig(), + conn: c, + baseCtx: baseCtx, + remoteAddrStr: c.RemoteAddr().String(), + bw: newBufferedWriter(c), + handler: opts.handler(), + streams: make(map[uint32]*stream), + readFrameCh: make(chan readFrameResult), + wantWriteFrameCh: make(chan FrameWriteRequest, 8), + wantStartPushCh: make(chan startPushRequest, 8), + wroteFrameCh: make(chan frameWriteResult, 1), // buffered; one send in writeFrameAsync + bodyReadCh: make(chan bodyReadMsg), // buffering doesn't matter either way + doneServing: make(chan struct{}), + clientMaxStreams: math.MaxUint32, // Section 6.5.2: "Initially, there is no limit to this value" + advMaxStreams: s.maxConcurrentStreams(), initialWindowSize: initialWindowSize, + maxFrameSize: initialMaxFrameSize, headerTableSize: initialHeaderTableSize, serveG: newGoroutineLock(), pushEnabled: true, } + if s.NewWriteScheduler != nil { + sc.writeSched = s.NewWriteScheduler() + } else { + sc.writeSched = NewRandomWriteScheduler() + } + sc.flow.add(initialWindowSize) sc.inflow.add(initialWindowSize) sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) @@ -356,16 +363,18 @@ type serverConn struct { handler http.Handler baseCtx contextContext framer *Framer - doneServing chan struct{} // closed when serverConn.serve ends - readFrameCh chan readFrameResult // written by serverConn.readFrames - wantWriteFrameCh chan frameWriteMsg // from handlers -> serve - wroteFrameCh chan frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes - bodyReadCh chan bodyReadMsg // from handlers -> serve - testHookCh chan func(int) // code to run on the serve loop - flow flow // conn-wide (not stream-specific) outbound flow control - inflow flow // conn-wide inbound flow control - tlsState *tls.ConnectionState // shared by all handlers, like net/http + doneServing chan struct{} // closed when serverConn.serve ends + readFrameCh chan readFrameResult // written by serverConn.readFrames + wantWriteFrameCh chan FrameWriteRequest // from handlers -> serve + wantStartPushCh chan startPushRequest // from handlers -> serve + wroteFrameCh chan frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes + bodyReadCh chan bodyReadMsg // from handlers -> serve + testHookCh chan func(int) // code to run on the serve loop + flow flow // conn-wide (not stream-specific) outbound flow control + inflow flow // conn-wide inbound flow control + tlsState *tls.ConnectionState // shared by all handlers, like net/http remoteAddrStr string + writeSched WriteScheduler // Everything following is owned by the serve loop; use serveG.check(): serveG goroutineLock // used to verify funcs are on serve() @@ -375,22 +384,27 @@ type serverConn struct { unackedSettings int // how many SETTINGS have we sent without ACKs? clientMaxStreams uint32 // SETTINGS_MAX_CONCURRENT_STREAMS from client (our PUSH_PROMISE limit) advMaxStreams uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised the client - curOpenStreams uint32 // client's number of open streams - maxStreamID uint32 // max ever seen + curClientStreams uint32 // number of open streams initiated by the client + curPushedStreams uint32 // number of open streams initiated by server push + maxClientStreamID uint32 // max ever seen from client (odd), or 0 if there have been no client requests + maxPushPromiseID uint32 // ID of the last push promise (even), or 0 if there have been no pushes streams map[uint32]*stream initialWindowSize int32 + maxFrameSize int32 headerTableSize uint32 peerMaxHeaderListSize uint32 // zero means unknown (default) canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case - writingFrame bool // started write goroutine but haven't heard back on wroteFrameCh + writingFrame bool // started writing a frame (on serve goroutine or separate) + writingFrameAsync bool // started a frame on its own goroutine but haven't heard back on wroteFrameCh needsFrameFlush bool // last frame write wasn't a flush - writeSched writeScheduler - inGoAway bool // we've started to or sent GOAWAY - needToSendGoAway bool // we need to schedule a GOAWAY frame write + inGoAway bool // we've started to or sent GOAWAY + inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop + needToSendGoAway bool // we need to schedule a GOAWAY frame write goAwayCode ErrCode shutdownTimerCh <-chan time.Time // nil until used shutdownTimer *time.Timer // nil until used - freeRequestBodyBuf []byte // if non-nil, a free initialWindowSize buffer for getRequestBodyBuf + idleTimer *time.Timer // nil if unused + idleTimerCh <-chan time.Time // nil if unused // Owned by the writeFrameAsync goroutine: headerWriteBuf bytes.Buffer @@ -434,11 +448,11 @@ type stream struct { numTrailerValues int64 weight uint8 state streamState - sentReset bool // only true once detached from streams map - gotReset bool // only true once detacted from streams map - gotTrailerHeader bool // HEADER frame for trailers was seen - wroteHeaders bool // whether we wrote headers (not status 100) - reqBuf []byte + sentReset bool // only true once detached from streams map + gotReset bool // only true once detacted from streams map + gotTrailerHeader bool // HEADER frame for trailers was seen + wroteHeaders bool // whether we wrote headers (not status 100) + reqBuf []byte // if non-nil, body pipe buffer to return later at EOF trailer http.Header // accumulated trailers reqTrailer http.Header // handler's Request.Trailer @@ -453,7 +467,7 @@ func (sc *serverConn) HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) { func (sc *serverConn) state(streamID uint32) (streamState, *stream) { sc.serveG.check() - // http://http2.github.io/http2-spec/#rfc.section.5.1 + // http://tools.ietf.org/html/rfc7540#section-5.1 if st, ok := sc.streams[streamID]; ok { return st.state, st } @@ -463,8 +477,14 @@ func (sc *serverConn) state(streamID uint32) (streamState, *stream) { // a client sends a HEADERS frame on stream 7 without ever sending a // frame on stream 5, then stream 5 transitions to the "closed" // state when the first frame for stream 7 is sent or received." - if streamID <= sc.maxStreamID { - return stateClosed, nil + if streamID%2 == 1 { + if streamID <= sc.maxClientStreamID { + return stateClosed, nil + } + } else { + if streamID <= sc.maxPushPromiseID { + return stateClosed, nil + } } return stateIdle, nil } @@ -603,17 +623,17 @@ func (sc *serverConn) readFrames() { // frameWriteResult is the message passed from writeFrameAsync to the serve goroutine. type frameWriteResult struct { - wm frameWriteMsg // what was written (or attempted) - err error // result of the writeFrame call + wr FrameWriteRequest // what was written (or attempted) + err error // result of the writeFrame call } // writeFrameAsync runs in its own goroutine and writes a single frame // and then reports when it's done. // At most one goroutine can be running writeFrameAsync at a time per // serverConn. -func (sc *serverConn) writeFrameAsync(wm frameWriteMsg) { - err := wm.write.writeFrame(sc) - sc.wroteFrameCh <- frameWriteResult{wm, err} +func (sc *serverConn) writeFrameAsync(wr FrameWriteRequest) { + err := wr.write.writeFrame(sc) + sc.wroteFrameCh <- frameWriteResult{wr, err} } func (sc *serverConn) closeAllStreamsOnConnClose() { @@ -657,7 +677,7 @@ func (sc *serverConn) serve() { sc.vlogf("http2: server connection from %v on %p", sc.conn.RemoteAddr(), sc.hs) } - sc.writeFrame(frameWriteMsg{ + sc.writeFrame(FrameWriteRequest{ write: writeSettings{ {SettingMaxFrameSize, sc.srv.maxReadFrameSize()}, {SettingMaxConcurrentStreams, sc.advMaxStreams}, @@ -682,6 +702,17 @@ func (sc *serverConn) serve() { sc.setConnState(http.StateActive) sc.setConnState(http.StateIdle) + if sc.srv.IdleTimeout != 0 { + sc.idleTimer = time.NewTimer(sc.srv.IdleTimeout) + defer sc.idleTimer.Stop() + sc.idleTimerCh = sc.idleTimer.C + } + + var gracefulShutdownCh <-chan struct{} + if sc.hs != nil { + gracefulShutdownCh = h1ServerShutdownChan(sc.hs) + } + go sc.readFrames() // closed by defer sc.conn.Close above settingsTimer := time.NewTimer(firstSettingsTimeout) @@ -689,8 +720,10 @@ func (sc *serverConn) serve() { for { loopNum++ select { - case wm := <-sc.wantWriteFrameCh: - sc.writeFrame(wm) + case wr := <-sc.wantWriteFrameCh: + sc.writeFrame(wr) + case spr := <-sc.wantStartPushCh: + sc.startPush(spr) case res := <-sc.wroteFrameCh: sc.wroteFrame(res) case res := <-sc.readFrameCh: @@ -707,12 +740,22 @@ func (sc *serverConn) serve() { case <-settingsTimer.C: sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr()) return + case <-gracefulShutdownCh: + gracefulShutdownCh = nil + sc.startGracefulShutdown() case <-sc.shutdownTimerCh: sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr()) return + case <-sc.idleTimerCh: + sc.vlogf("connection is idle") + sc.goAway(ErrCodeNo) case fn := <-sc.testHookCh: fn(loopNum) } + + if sc.inGoAway && sc.curClientStreams == 0 && !sc.needToSendGoAway && !sc.writingFrame { + return + } } } @@ -760,7 +803,7 @@ func (sc *serverConn) writeDataFromHandler(stream *stream, data []byte, endStrea ch := errChanPool.Get().(chan error) writeArg := writeDataPool.Get().(*writeData) *writeArg = writeData{stream.id, data, endStream} - err := sc.writeFrameFromHandler(frameWriteMsg{ + err := sc.writeFrameFromHandler(FrameWriteRequest{ write: writeArg, stream: stream, done: ch, @@ -796,17 +839,17 @@ func (sc *serverConn) writeDataFromHandler(stream *stream, data []byte, endStrea return err } -// writeFrameFromHandler sends wm to sc.wantWriteFrameCh, but aborts +// writeFrameFromHandler sends wr to sc.wantWriteFrameCh, but aborts // if the connection has gone away. // // This must not be run from the serve goroutine itself, else it might // deadlock writing to sc.wantWriteFrameCh (which is only mildly // buffered and is read by serve itself). If you're on the serve // goroutine, call writeFrame instead. -func (sc *serverConn) writeFrameFromHandler(wm frameWriteMsg) error { +func (sc *serverConn) writeFrameFromHandler(wr FrameWriteRequest) error { sc.serveG.checkNotOn() // NOT select { - case sc.wantWriteFrameCh <- wm: + case sc.wantWriteFrameCh <- wr: return nil case <-sc.doneServing: // Serve loop is gone. @@ -823,38 +866,38 @@ func (sc *serverConn) writeFrameFromHandler(wm frameWriteMsg) error { // make it onto the wire // // If you're not on the serve goroutine, use writeFrameFromHandler instead. -func (sc *serverConn) writeFrame(wm frameWriteMsg) { +func (sc *serverConn) writeFrame(wr FrameWriteRequest) { sc.serveG.check() var ignoreWrite bool // Don't send a 100-continue response if we've already sent headers. // See golang.org/issue/14030. - switch wm.write.(type) { + switch wr.write.(type) { case *writeResHeaders: - wm.stream.wroteHeaders = true + wr.stream.wroteHeaders = true case write100ContinueHeadersFrame: - if wm.stream.wroteHeaders { + if wr.stream.wroteHeaders { ignoreWrite = true } } if !ignoreWrite { - sc.writeSched.add(wm) + sc.writeSched.Push(wr) } sc.scheduleFrameWrite() } -// startFrameWrite starts a goroutine to write wm (in a separate +// startFrameWrite starts a goroutine to write wr (in a separate // goroutine since that might block on the network), and updates the -// serve goroutine's state about the world, updated from info in wm. -func (sc *serverConn) startFrameWrite(wm frameWriteMsg) { +// serve goroutine's state about the world, updated from info in wr. +func (sc *serverConn) startFrameWrite(wr FrameWriteRequest) { sc.serveG.check() if sc.writingFrame { panic("internal error: can only be writing one frame at a time") } - st := wm.stream + st := wr.stream if st != nil { switch st.state { case stateHalfClosedLocal: @@ -865,13 +908,31 @@ func (sc *serverConn) startFrameWrite(wm frameWriteMsg) { sc.scheduleFrameWrite() return } - panic(fmt.Sprintf("internal error: attempt to send a write %v on a closed stream", wm)) + panic(fmt.Sprintf("internal error: attempt to send a write %v on a closed stream", wr)) + } + } + if wpp, ok := wr.write.(*writePushPromise); ok { + var err error + wpp.promisedID, err = wpp.allocatePromisedID() + if err != nil { + sc.writingFrameAsync = false + if wr.done != nil { + wr.done <- err + } + return } } sc.writingFrame = true sc.needsFrameFlush = true - go sc.writeFrameAsync(wm) + if wr.write.staysWithinBuffer(sc.bw.Available()) { + sc.writingFrameAsync = false + err := wr.write.writeFrame(sc) + sc.wroteFrame(frameWriteResult{wr, err}) + } else { + sc.writingFrameAsync = true + go sc.writeFrameAsync(wr) + } } // errHandlerPanicked is the error given to any callers blocked in a read from @@ -887,25 +948,26 @@ func (sc *serverConn) wroteFrame(res frameWriteResult) { panic("internal error: expected to be already writing a frame") } sc.writingFrame = false + sc.writingFrameAsync = false - wm := res.wm - st := wm.stream + wr := res.wr + st := wr.stream - closeStream := endsStream(wm.write) + closeStream := endsStream(wr.write) - if _, ok := wm.write.(handlerPanicRST); ok { + if _, ok := wr.write.(handlerPanicRST); ok { sc.closeStream(st, errHandlerPanicked) } // Reply (if requested) to the blocked ServeHTTP goroutine. - if ch := wm.done; ch != nil { + if ch := wr.done; ch != nil { select { case ch <- res.err: default: - panic(fmt.Sprintf("unbuffered done channel passed in for type %T", wm.write)) + panic(fmt.Sprintf("unbuffered done channel passed in for type %T", wr.write)) } } - wm.write = nil // prevent use (assume it's tainted after wm.done send) + wr.write = nil // prevent use (assume it's tainted after wr.done send) if closeStream { if st == nil { @@ -916,11 +978,11 @@ func (sc *serverConn) wroteFrame(res frameWriteResult) { // Here we would go to stateHalfClosedLocal in // theory, but since our handler is done and // the net/http package provides no mechanism - // for finishing writing to a ResponseWriter - // while still reading data (see possible TODO - // at top of this file), we go into closed - // state here anyway, after telling the peer - // we're hanging up on them. + // for closing a ResponseWriter while still + // reading data (see possible TODO at top of + // this file), we go into closed state here + // anyway, after telling the peer we're + // hanging up on them. st.state = stateHalfClosedLocal // won't last long, but necessary for closeStream via resetStream errCancel := streamError(st.id, ErrCodeCancel) sc.resetStream(errCancel) @@ -946,47 +1008,68 @@ func (sc *serverConn) wroteFrame(res frameWriteResult) { // flush the write buffer. func (sc *serverConn) scheduleFrameWrite() { sc.serveG.check() - if sc.writingFrame { - return - } - if sc.needToSendGoAway { - sc.needToSendGoAway = false - sc.startFrameWrite(frameWriteMsg{ - write: &writeGoAway{ - maxStreamID: sc.maxStreamID, - code: sc.goAwayCode, - }, - }) - return - } - if sc.needToSendSettingsAck { - sc.needToSendSettingsAck = false - sc.startFrameWrite(frameWriteMsg{write: writeSettingsAck{}}) + if sc.writingFrame || sc.inFrameScheduleLoop { return } - if !sc.inGoAway { - if wm, ok := sc.writeSched.take(); ok { - sc.startFrameWrite(wm) - return + sc.inFrameScheduleLoop = true + for !sc.writingFrameAsync { + if sc.needToSendGoAway { + sc.needToSendGoAway = false + sc.startFrameWrite(FrameWriteRequest{ + write: &writeGoAway{ + maxStreamID: sc.maxClientStreamID, + code: sc.goAwayCode, + }, + }) + continue } + if sc.needToSendSettingsAck { + sc.needToSendSettingsAck = false + sc.startFrameWrite(FrameWriteRequest{write: writeSettingsAck{}}) + continue + } + if !sc.inGoAway || sc.goAwayCode == ErrCodeNo { + if wr, ok := sc.writeSched.Pop(); ok { + sc.startFrameWrite(wr) + continue + } + } + if sc.needsFrameFlush { + sc.startFrameWrite(FrameWriteRequest{write: flushFrameWriter{}}) + sc.needsFrameFlush = false // after startFrameWrite, since it sets this true + continue + } + break } - if sc.needsFrameFlush { - sc.startFrameWrite(frameWriteMsg{write: flushFrameWriter{}}) - sc.needsFrameFlush = false // after startFrameWrite, since it sets this true - return - } + sc.inFrameScheduleLoop = false +} + +// startGracefulShutdown sends a GOAWAY with ErrCodeNo to tell the +// client we're gracefully shutting down. The connection isn't closed +// until all current streams are done. +func (sc *serverConn) startGracefulShutdown() { + sc.goAwayIn(ErrCodeNo, 0) } func (sc *serverConn) goAway(code ErrCode) { sc.serveG.check() - if sc.inGoAway { - return - } + var forceCloseIn time.Duration if code != ErrCodeNo { - sc.shutDownIn(250 * time.Millisecond) + forceCloseIn = 250 * time.Millisecond } else { // TODO: configurable - sc.shutDownIn(1 * time.Second) + forceCloseIn = 1 * time.Second + } + sc.goAwayIn(code, forceCloseIn) +} + +func (sc *serverConn) goAwayIn(code ErrCode, forceCloseIn time.Duration) { + sc.serveG.check() + if sc.inGoAway { + return + } + if forceCloseIn != 0 { + sc.shutDownIn(forceCloseIn) } sc.inGoAway = true sc.needToSendGoAway = true @@ -1002,7 +1085,7 @@ func (sc *serverConn) shutDownIn(d time.Duration) { func (sc *serverConn) resetStream(se StreamError) { sc.serveG.check() - sc.writeFrame(frameWriteMsg{write: se}) + sc.writeFrame(FrameWriteRequest{write: se}) if st, ok := sc.streams[se.StreamID]; ok { st.sentReset = true sc.closeStream(st, se) @@ -1090,6 +1173,8 @@ func (sc *serverConn) processFrame(f Frame) error { return sc.processResetStream(f) case *PriorityFrame: return sc.processPriority(f) + case *GoAwayFrame: + return sc.processGoAway(f) case *PushPromiseFrame: // A client cannot push. Thus, servers MUST treat the receipt of a PUSH_PROMISE // frame as a connection error (Section 5.4.1) of type PROTOCOL_ERROR. @@ -1115,7 +1200,10 @@ func (sc *serverConn) processPing(f *PingFrame) error { // PROTOCOL_ERROR." return ConnectionError(ErrCodeProtocol) } - sc.writeFrame(frameWriteMsg{write: writePingAck{f}}) + if sc.inGoAway && sc.goAwayCode != ErrCodeNo { + return nil + } + sc.writeFrame(FrameWriteRequest{write: writePingAck{f}}) return nil } @@ -1123,7 +1211,14 @@ func (sc *serverConn) processWindowUpdate(f *WindowUpdateFrame) error { sc.serveG.check() switch { case f.StreamID != 0: // stream-level flow control - st := sc.streams[f.StreamID] + state, st := sc.state(f.StreamID) + if state == stateIdle { + // Section 5.1: "Receiving any frame other than HEADERS + // or PRIORITY on a stream in this state MUST be + // treated as a connection error (Section 5.4.1) of + // type PROTOCOL_ERROR." + return ConnectionError(ErrCodeProtocol) + } if st == nil { // "WINDOW_UPDATE can be sent by a peer that has sent a // frame bearing the END_STREAM flag. This means that a @@ -1170,11 +1265,21 @@ func (sc *serverConn) closeStream(st *stream, err error) { panic(fmt.Sprintf("invariant; can't close stream in state %v", st.state)) } st.state = stateClosed - sc.curOpenStreams-- - if sc.curOpenStreams == 0 { - sc.setConnState(http.StateIdle) + if st.isPushed() { + sc.curPushedStreams-- + } else { + sc.curClientStreams-- } delete(sc.streams, st.id) + if len(sc.streams) == 0 { + sc.setConnState(http.StateIdle) + if sc.srv.IdleTimeout != 0 { + sc.idleTimer.Reset(sc.srv.IdleTimeout) + } + if h1ServerKeepAlivesDisabled(sc.hs) { + sc.startGracefulShutdown() + } + } if p := st.body; p != nil { // Return any buffered unread bytes worth of conn-level flow control. // See golang.org/issue/16481 @@ -1183,19 +1288,7 @@ func (sc *serverConn) closeStream(st *stream, err error) { p.CloseWithError(err) } st.cw.Close() // signals Handler's CloseNotifier, unblocks writes, etc - sc.writeSched.forgetStream(st.id) - if st.reqBuf != nil { - // Stash this request body buffer (64k) away for reuse - // by a future POST/PUT/etc. - // - // TODO(bradfitz): share on the server? sync.Pool? - // Server requires locks and might hurt contention. - // sync.Pool might work, or might be worse, depending - // on goroutine CPU migrations. (get and put on - // separate CPUs). Maybe a mix of strategies. But - // this is an easy win for now. - sc.freeRequestBodyBuf = st.reqBuf - } + sc.writeSched.CloseStream(st.id) } func (sc *serverConn) processSettings(f *SettingsFrame) error { @@ -1237,7 +1330,7 @@ func (sc *serverConn) processSetting(s Setting) error { case SettingInitialWindowSize: return sc.processSettingInitialWindowSize(s.Val) case SettingMaxFrameSize: - sc.writeSched.maxFrameSize = s.Val + sc.maxFrameSize = int32(s.Val) // the maximum valid s.Val is < 2^31 case SettingMaxHeaderListSize: sc.peerMaxHeaderListSize = s.Val default: @@ -1281,14 +1374,24 @@ func (sc *serverConn) processSettingInitialWindowSize(val uint32) error { func (sc *serverConn) processData(f *DataFrame) error { sc.serveG.check() + if sc.inGoAway && sc.goAwayCode != ErrCodeNo { + return nil + } data := f.Data() // "If a DATA frame is received whose stream is not in "open" // or "half closed (local)" state, the recipient MUST respond // with a stream error (Section 5.4.2) of type STREAM_CLOSED." id := f.Header().StreamID - st, ok := sc.streams[id] - if !ok || st.state != stateOpen || st.gotTrailerHeader { + state, st := sc.state(id) + if id == 0 || state == stateIdle { + // Section 5.1: "Receiving any frame other than HEADERS + // or PRIORITY on a stream in this state MUST be + // treated as a connection error (Section 5.4.1) of + // type PROTOCOL_ERROR." + return ConnectionError(ErrCodeProtocol) + } + if st == nil || state != stateOpen || st.gotTrailerHeader { // This includes sending a RST_STREAM if the stream is // in stateHalfClosedLocal (which currently means that // the http.Handler returned, so it's done reading & @@ -1350,6 +1453,25 @@ func (sc *serverConn) processData(f *DataFrame) error { return nil } +func (sc *serverConn) processGoAway(f *GoAwayFrame) error { + sc.serveG.check() + if f.ErrCode != ErrCodeNo { + sc.logf("http2: received GOAWAY %+v, starting graceful shutdown", f) + } else { + sc.vlogf("http2: received GOAWAY %+v, starting graceful shutdown", f) + } + sc.startGracefulShutdown() + // http://tools.ietf.org/html/rfc7540#section-6.8 + // We should not create any new streams, which means we should disable push. + sc.pushEnabled = false + return nil +} + +// isPushed reports whether the stream is server-initiated. +func (st *stream) isPushed() bool { + return st.id%2 == 0 +} + // endStream closes a Request.Body's pipe. It is called when a DATA // frame says a request body is over (or after trailers). func (st *stream) endStream() { @@ -1379,12 +1501,12 @@ func (st *stream) copyTrailersToHandlerRequest() { func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error { sc.serveG.check() - id := f.Header().StreamID + id := f.StreamID if sc.inGoAway { // Ignore. return nil } - // http://http2.github.io/http2-spec/#rfc.section.5.1.1 + // http://tools.ietf.org/html/rfc7540#section-5.1.1 // Streams initiated by a client MUST use odd-numbered stream // identifiers. [...] An endpoint that receives an unexpected // stream identifier MUST respond with a connection error @@ -1396,8 +1518,7 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error { // send a trailer for an open one. If we already have a stream // open, let it process its own HEADERS frame (trailers at this // point, if it's valid). - st := sc.streams[f.Header().StreamID] - if st != nil { + if st := sc.streams[f.StreamID]; st != nil { return st.processTrailerHeaders(f) } @@ -1406,54 +1527,45 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error { // endpoint has opened or reserved. [...] An endpoint that // receives an unexpected stream identifier MUST respond with // a connection error (Section 5.4.1) of type PROTOCOL_ERROR. - if id <= sc.maxStreamID { + if id <= sc.maxClientStreamID { return ConnectionError(ErrCodeProtocol) } - sc.maxStreamID = id + sc.maxClientStreamID = id - ctx, cancelCtx := contextWithCancel(sc.baseCtx) - st = &stream{ - sc: sc, - id: id, - state: stateOpen, - ctx: ctx, - cancelCtx: cancelCtx, + if sc.idleTimer != nil { + sc.idleTimer.Stop() } - if f.StreamEnded() { - st.state = stateHalfClosedRemote - } - st.cw.Init() - st.flow.conn = &sc.flow // link to conn-level counter - st.flow.add(sc.initialWindowSize) - st.inflow.conn = &sc.inflow // link to conn-level counter - st.inflow.add(initialWindowSize) // TODO: update this when we send a higher initial window size in the initial settings - - sc.streams[id] = st - if f.HasPriority() { - adjustStreamPriority(sc.streams, st.id, f.Priority) - } - sc.curOpenStreams++ - if sc.curOpenStreams == 1 { - sc.setConnState(http.StateActive) - } - if sc.curOpenStreams > sc.advMaxStreams { - // "Endpoints MUST NOT exceed the limit set by their - // peer. An endpoint that receives a HEADERS frame - // that causes their advertised concurrent stream - // limit to be exceeded MUST treat this as a stream - // error (Section 5.4.2) of type PROTOCOL_ERROR or - // REFUSED_STREAM." + // http://tools.ietf.org/html/rfc7540#section-5.1.2 + // [...] Endpoints MUST NOT exceed the limit set by their peer. An + // endpoint that receives a HEADERS frame that causes their + // advertised concurrent stream limit to be exceeded MUST treat + // this as a stream error (Section 5.4.2) of type PROTOCOL_ERROR + // or REFUSED_STREAM. + if sc.curClientStreams+1 > sc.advMaxStreams { if sc.unackedSettings == 0 { // They should know better. - return streamError(st.id, ErrCodeProtocol) + return streamError(id, ErrCodeProtocol) } // Assume it's a network race, where they just haven't // received our last SETTINGS update. But actually // this can't happen yet, because we don't yet provide // a way for users to adjust server parameters at // runtime. - return streamError(st.id, ErrCodeRefusedStream) + return streamError(id, ErrCodeRefusedStream) + } + + initialState := stateOpen + if f.StreamEnded() { + initialState = stateHalfClosedRemote + } + st := sc.newStream(id, 0, initialState) + + if f.HasPriority() { + if err := checkPriority(f.StreamID, f.Priority); err != nil { + return err + } + sc.writeSched.AdjustStream(st.id, f.Priority) } rw, req, err := sc.newWriterAndRequest(st, f) @@ -1471,10 +1583,21 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error { if f.Truncated { // Their header list was too long. Send a 431 error. handler = handleHeaderListTooLong - } else if err := checkValidHTTP2Request(req); err != nil { + } else if err := checkValidHTTP2RequestHeaders(req.Header); err != nil { handler = new400Handler(err) } + // The net/http package sets the read deadline from the + // http.Server.ReadTimeout during the TLS handshake, but then + // passes the connection off to us with the deadline already + // set. Disarm it here after the request headers are read, + // similar to how the http1 server works. Here it's + // technically more like the http1 Server's ReadHeaderTimeout + // (in Go 1.8), though. That's a more sane option anyway. + if sc.hs.ReadTimeout != 0 { + sc.conn.SetReadDeadline(time.Time{}) + } + go sc.runHandler(rw, req, handler) return nil } @@ -1509,62 +1632,78 @@ func (st *stream) processTrailerHeaders(f *MetaHeadersFrame) error { return nil } +func checkPriority(streamID uint32, p PriorityParam) error { + if streamID == p.StreamDep { + // Section 5.3.1: "A stream cannot depend on itself. An endpoint MUST treat + // this as a stream error (Section 5.4.2) of type PROTOCOL_ERROR." + // Section 5.3.3 says that a stream can depend on one of its dependencies, + // so it's only self-dependencies that are forbidden. + return streamError(streamID, ErrCodeProtocol) + } + return nil +} + func (sc *serverConn) processPriority(f *PriorityFrame) error { - adjustStreamPriority(sc.streams, f.StreamID, f.PriorityParam) + if sc.inGoAway { + return nil + } + if err := checkPriority(f.StreamID, f.PriorityParam); err != nil { + return err + } + sc.writeSched.AdjustStream(f.StreamID, f.PriorityParam) return nil } -func adjustStreamPriority(streams map[uint32]*stream, streamID uint32, priority PriorityParam) { - st, ok := streams[streamID] - if !ok { - // TODO: not quite correct (this streamID might - // already exist in the dep tree, but be closed), but - // close enough for now. - return +func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream { + sc.serveG.check() + if id == 0 { + panic("internal error: cannot create stream with id 0") } - st.weight = priority.Weight - parent := streams[priority.StreamDep] // might be nil - if parent == st { - // if client tries to set this stream to be the parent of itself - // ignore and keep going - return + + ctx, cancelCtx := contextWithCancel(sc.baseCtx) + st := &stream{ + sc: sc, + id: id, + state: state, + ctx: ctx, + cancelCtx: cancelCtx, } + st.cw.Init() + st.flow.conn = &sc.flow // link to conn-level counter + st.flow.add(sc.initialWindowSize) + st.inflow.conn = &sc.inflow // link to conn-level counter + st.inflow.add(initialWindowSize) // TODO: update this when we send a higher initial window size in the initial settings - // section 5.3.3: If a stream is made dependent on one of its - // own dependencies, the formerly dependent stream is first - // moved to be dependent on the reprioritized stream's previous - // parent. The moved dependency retains its weight. - for piter := parent; piter != nil; piter = piter.parent { - if piter == st { - parent.parent = st.parent - break - } + sc.streams[id] = st + sc.writeSched.OpenStream(st.id, OpenStreamOptions{PusherID: pusherID}) + if st.isPushed() { + sc.curPushedStreams++ + } else { + sc.curClientStreams++ } - st.parent = parent - if priority.Exclusive && (st.parent != nil || priority.StreamDep == 0) { - for _, openStream := range streams { - if openStream != st && openStream.parent == st.parent { - openStream.parent = st - } - } + if sc.curClientStreams+sc.curPushedStreams == 1 { + sc.setConnState(http.StateActive) } + + return st } func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*responseWriter, *http.Request, error) { sc.serveG.check() - method := f.PseudoValue("method") - path := f.PseudoValue("path") - scheme := f.PseudoValue("scheme") - authority := f.PseudoValue("authority") + rp := requestParam{ + method: f.PseudoValue("method"), + scheme: f.PseudoValue("scheme"), + authority: f.PseudoValue("authority"), + path: f.PseudoValue("path"), + } - isConnect := method == "CONNECT" + isConnect := rp.method == "CONNECT" if isConnect { - if path != "" || scheme != "" || authority == "" { + if rp.path != "" || rp.scheme != "" || rp.authority == "" { return nil, nil, streamError(f.StreamID, ErrCodeProtocol) } - } else if method == "" || path == "" || - (scheme != "https" && scheme != "http") { + } else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") { // See 8.1.2.6 Malformed Requests and Responses: // // Malformed requests or responses that are detected @@ -1579,36 +1718,64 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res } bodyOpen := !f.StreamEnded() - if method == "HEAD" && bodyOpen { + if rp.method == "HEAD" && bodyOpen { // HEAD requests can't have bodies return nil, nil, streamError(f.StreamID, ErrCodeProtocol) } - var tlsState *tls.ConnectionState // nil if not scheme https - if scheme == "https" { - tlsState = sc.tlsState + rp.header = make(http.Header) + for _, hf := range f.RegularFields() { + rp.header.Add(sc.canonicalHeader(hf.Name), hf.Value) + } + if rp.authority == "" { + rp.authority = rp.header.Get("Host") } - header := make(http.Header) - for _, hf := range f.RegularFields() { - header.Add(sc.canonicalHeader(hf.Name), hf.Value) + rw, req, err := sc.newWriterAndRequestNoBody(st, rp) + if err != nil { + return nil, nil, err + } + if bodyOpen { + st.reqBuf = getRequestBodyBuf() + req.Body.(*requestBody).pipe = &pipe{ + b: &fixedBuffer{buf: st.reqBuf}, + } + + if vv, ok := rp.header["Content-Length"]; ok { + req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64) + } else { + req.ContentLength = -1 + } } + return rw, req, nil +} - if authority == "" { - authority = header.Get("Host") +type requestParam struct { + method string + scheme, authority, path string + header http.Header +} + +func (sc *serverConn) newWriterAndRequestNoBody(st *stream, rp requestParam) (*responseWriter, *http.Request, error) { + sc.serveG.check() + + var tlsState *tls.ConnectionState // nil if not scheme https + if rp.scheme == "https" { + tlsState = sc.tlsState } - needsContinue := header.Get("Expect") == "100-continue" + + needsContinue := rp.header.Get("Expect") == "100-continue" if needsContinue { - header.Del("Expect") + rp.header.Del("Expect") } // Merge Cookie headers into one "; "-delimited value. - if cookies := header["Cookie"]; len(cookies) > 1 { - header.Set("Cookie", strings.Join(cookies, "; ")) + if cookies := rp.header["Cookie"]; len(cookies) > 1 { + rp.header.Set("Cookie", strings.Join(cookies, "; ")) } // Setup Trailers var trailer http.Header - for _, v := range header["Trailer"] { + for _, v := range rp.header["Trailer"] { for _, key := range strings.Split(v, ",") { key = http.CanonicalHeaderKey(strings.TrimSpace(key)) switch key { @@ -1623,57 +1790,42 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res } } } - delete(header, "Trailer") + delete(rp.header, "Trailer") - body := &requestBody{ - conn: sc, - stream: st, - needsContinue: needsContinue, - } var url_ *url.URL var requestURI string - if isConnect { - url_ = &url.URL{Host: authority} - requestURI = authority // mimic HTTP/1 server behavior + if rp.method == "CONNECT" { + url_ = &url.URL{Host: rp.authority} + requestURI = rp.authority // mimic HTTP/1 server behavior } else { var err error - url_, err = url.ParseRequestURI(path) + url_, err = url.ParseRequestURI(rp.path) if err != nil { - return nil, nil, streamError(f.StreamID, ErrCodeProtocol) + return nil, nil, streamError(st.id, ErrCodeProtocol) } - requestURI = path + requestURI = rp.path + } + + body := &requestBody{ + conn: sc, + stream: st, + needsContinue: needsContinue, } req := &http.Request{ - Method: method, + Method: rp.method, URL: url_, RemoteAddr: sc.remoteAddrStr, - Header: header, + Header: rp.header, RequestURI: requestURI, Proto: "HTTP/2.0", ProtoMajor: 2, ProtoMinor: 0, TLS: tlsState, - Host: authority, + Host: rp.authority, Body: body, Trailer: trailer, } req = requestWithContext(req, st.ctx) - if bodyOpen { - // Disabled, per golang.org/issue/14960: - // st.reqBuf = sc.getRequestBodyBuf() - // TODO: remove this 64k of garbage per request (again, but without a data race): - buf := make([]byte, initialWindowSize) - - body.pipe = &pipe{ - b: &fixedBuffer{buf: buf}, - } - - if vv, ok := header["Content-Length"]; ok { - req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64) - } else { - req.ContentLength = -1 - } - } rws := responseWriterStatePool.Get().(*responseWriterState) bwSave := rws.bw @@ -1689,13 +1841,22 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res return rw, req, nil } -func (sc *serverConn) getRequestBodyBuf() []byte { - sc.serveG.check() - if buf := sc.freeRequestBodyBuf; buf != nil { - sc.freeRequestBodyBuf = nil - return buf +var reqBodyCache = make(chan []byte, 8) + +func getRequestBodyBuf() []byte { + select { + case b := <-reqBodyCache: + return b + default: + return make([]byte, initialWindowSize) + } +} + +func putRequestBodyBuf(b []byte) { + select { + case reqBodyCache <- b: + default: } - return make([]byte, initialWindowSize) } // Run on its own goroutine. @@ -1705,15 +1866,17 @@ func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler rw.rws.stream.cancelCtx() if didPanic { e := recover() - // Same as net/http: - const size = 64 << 10 - buf := make([]byte, size) - buf = buf[:runtime.Stack(buf, false)] - sc.writeFrameFromHandler(frameWriteMsg{ + sc.writeFrameFromHandler(FrameWriteRequest{ write: handlerPanicRST{rw.rws.stream.id}, stream: rw.rws.stream, }) - sc.logf("http2: panic serving %v: %v\n%s", sc.conn.RemoteAddr(), e, buf) + // Same as net/http: + if shouldLogPanic(e) { + const size = 64 << 10 + buf := make([]byte, size) + buf = buf[:runtime.Stack(buf, false)] + sc.logf("http2: panic serving %v: %v\n%s", sc.conn.RemoteAddr(), e, buf) + } return } rw.handlerDone() @@ -1744,7 +1907,7 @@ func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders) erro // mutates it. errc = errChanPool.Get().(chan error) } - if err := sc.writeFrameFromHandler(frameWriteMsg{ + if err := sc.writeFrameFromHandler(FrameWriteRequest{ write: headerData, stream: st, done: errc, @@ -1767,7 +1930,7 @@ func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders) erro // called from handler goroutines. func (sc *serverConn) write100ContinueHeaders(st *stream) { - sc.writeFrameFromHandler(frameWriteMsg{ + sc.writeFrameFromHandler(FrameWriteRequest{ write: write100ContinueHeadersFrame{st.id}, stream: st, }) @@ -1783,11 +1946,19 @@ type bodyReadMsg struct { // called from handler goroutines. // Notes that the handler for the given stream ID read n bytes of its body // and schedules flow control tokens to be sent. -func (sc *serverConn) noteBodyReadFromHandler(st *stream, n int) { +func (sc *serverConn) noteBodyReadFromHandler(st *stream, n int, err error) { sc.serveG.checkNotOn() // NOT on - select { - case sc.bodyReadCh <- bodyReadMsg{st, n}: - case <-sc.doneServing: + if n > 0 { + select { + case sc.bodyReadCh <- bodyReadMsg{st, n}: + case <-sc.doneServing: + } + } + if err == io.EOF { + if buf := st.reqBuf; buf != nil { + st.reqBuf = nil // shouldn't matter; field unused by other + putRequestBodyBuf(buf) + } } } @@ -1830,7 +2001,7 @@ func (sc *serverConn) sendWindowUpdate32(st *stream, n int32) { if st != nil { streamID = st.id } - sc.writeFrame(frameWriteMsg{ + sc.writeFrame(FrameWriteRequest{ write: writeWindowUpdate{streamID: streamID, n: uint32(n)}, stream: st, }) @@ -1845,16 +2016,19 @@ func (sc *serverConn) sendWindowUpdate32(st *stream, n int32) { } } +// requestBody is the Handler's Request.Body type. +// Read and Close may be called concurrently. type requestBody struct { stream *stream conn *serverConn - closed bool + closed bool // for use by Close only + sawEOF bool // for use by Read only pipe *pipe // non-nil if we have a HTTP entity message body needsContinue bool // need to send a 100-continue } func (b *requestBody) Close() error { - if b.pipe != nil { + if b.pipe != nil && !b.closed { b.pipe.BreakWithError(errClosedBody) } b.closed = true @@ -1866,13 +2040,17 @@ func (b *requestBody) Read(p []byte) (n int, err error) { b.needsContinue = false b.conn.write100ContinueHeaders(b.stream) } - if b.pipe == nil { + if b.pipe == nil || b.sawEOF { return 0, io.EOF } n, err = b.pipe.Read(p) - if n > 0 { - b.conn.noteBodyReadFromHandler(b.stream, n) + if err == io.EOF { + b.sawEOF = true } + if b.conn == nil && inTests { + return + } + b.conn.noteBodyReadFromHandler(b.stream, n, err) return } @@ -2110,8 +2288,9 @@ func (w *responseWriter) CloseNotify() <-chan bool { if ch == nil { ch = make(chan bool, 1) rws.closeNotifierCh = ch + cw := rws.stream.cw go func() { - rws.stream.cw.Wait() // wait for close + cw.Wait() // wait for close ch <- true }() } @@ -2207,6 +2386,200 @@ func (w *responseWriter) handlerDone() { responseWriterStatePool.Put(rws) } +// Push errors. +var ( + ErrRecursivePush = errors.New("http2: recursive push not allowed") + ErrPushLimitReached = errors.New("http2: push would exceed peer's SETTINGS_MAX_CONCURRENT_STREAMS") +) + +// pushOptions is the internal version of http.PushOptions, which we +// cannot include here because it's only defined in Go 1.8 and later. +type pushOptions struct { + Method string + Header http.Header +} + +func (w *responseWriter) push(target string, opts pushOptions) error { + st := w.rws.stream + sc := st.sc + sc.serveG.checkNotOn() + + // No recursive pushes: "PUSH_PROMISE frames MUST only be sent on a peer-initiated stream." + // http://tools.ietf.org/html/rfc7540#section-6.6 + if st.isPushed() { + return ErrRecursivePush + } + + // Default options. + if opts.Method == "" { + opts.Method = "GET" + } + if opts.Header == nil { + opts.Header = http.Header{} + } + wantScheme := "http" + if w.rws.req.TLS != nil { + wantScheme = "https" + } + + // Validate the request. + u, err := url.Parse(target) + if err != nil { + return err + } + if u.Scheme == "" { + if !strings.HasPrefix(target, "/") { + return fmt.Errorf("target must be an absolute URL or an absolute path: %q", target) + } + u.Scheme = wantScheme + u.Host = w.rws.req.Host + } else { + if u.Scheme != wantScheme { + return fmt.Errorf("cannot push URL with scheme %q from request with scheme %q", u.Scheme, wantScheme) + } + if u.Host == "" { + return errors.New("URL must have a host") + } + } + for k := range opts.Header { + if strings.HasPrefix(k, ":") { + return fmt.Errorf("promised request headers cannot include pseudo header %q", k) + } + // These headers are meaningful only if the request has a body, + // but PUSH_PROMISE requests cannot have a body. + // http://tools.ietf.org/html/rfc7540#section-8.2 + // Also disallow Host, since the promised URL must be absolute. + switch strings.ToLower(k) { + case "content-length", "content-encoding", "trailer", "te", "expect", "host": + return fmt.Errorf("promised request headers cannot include %q", k) + } + } + if err := checkValidHTTP2RequestHeaders(opts.Header); err != nil { + return err + } + + // The RFC effectively limits promised requests to GET and HEAD: + // "Promised requests MUST be cacheable [GET, HEAD, or POST], and MUST be safe [GET or HEAD]" + // http://tools.ietf.org/html/rfc7540#section-8.2 + if opts.Method != "GET" && opts.Method != "HEAD" { + return fmt.Errorf("method %q must be GET or HEAD", opts.Method) + } + + msg := startPushRequest{ + parent: st, + method: opts.Method, + url: u, + header: cloneHeader(opts.Header), + done: errChanPool.Get().(chan error), + } + + select { + case <-sc.doneServing: + return errClientDisconnected + case <-st.cw: + return errStreamClosed + case sc.wantStartPushCh <- msg: + } + + select { + case <-sc.doneServing: + return errClientDisconnected + case <-st.cw: + return errStreamClosed + case err := <-msg.done: + errChanPool.Put(msg.done) + return err + } +} + +type startPushRequest struct { + parent *stream + method string + url *url.URL + header http.Header + done chan error +} + +func (sc *serverConn) startPush(msg startPushRequest) { + sc.serveG.check() + + // http://tools.ietf.org/html/rfc7540#section-6.6. + // PUSH_PROMISE frames MUST only be sent on a peer-initiated stream that + // is in either the "open" or "half-closed (remote)" state. + if msg.parent.state != stateOpen && msg.parent.state != stateHalfClosedRemote { + // responseWriter.Push checks that the stream is peer-initiaed. + msg.done <- errStreamClosed + return + } + + // http://tools.ietf.org/html/rfc7540#section-6.6. + if !sc.pushEnabled { + msg.done <- http.ErrNotSupported + return + } + + // PUSH_PROMISE frames must be sent in increasing order by stream ID, so + // we allocate an ID for the promised stream lazily, when the PUSH_PROMISE + // is written. Once the ID is allocated, we start the request handler. + allocatePromisedID := func() (uint32, error) { + sc.serveG.check() + + // Check this again, just in case. Technically, we might have received + // an updated SETTINGS by the time we got around to writing this frame. + if !sc.pushEnabled { + return 0, http.ErrNotSupported + } + // http://tools.ietf.org/html/rfc7540#section-6.5.2. + if sc.curPushedStreams+1 > sc.clientMaxStreams { + return 0, ErrPushLimitReached + } + + // http://tools.ietf.org/html/rfc7540#section-5.1.1. + // Streams initiated by the server MUST use even-numbered identifiers. + // A server that is unable to establish a new stream identifier can send a GOAWAY + // frame so that the client is forced to open a new connection for new streams. + if sc.maxPushPromiseID+2 >= 1<<31 { + sc.startGracefulShutdown() + return 0, ErrPushLimitReached + } + sc.maxPushPromiseID += 2 + promisedID := sc.maxPushPromiseID + + // http://tools.ietf.org/html/rfc7540#section-8.2. + // Strictly speaking, the new stream should start in "reserved (local)", then + // transition to "half closed (remote)" after sending the initial HEADERS, but + // we start in "half closed (remote)" for simplicity. + // See further comments at the definition of stateHalfClosedRemote. + promised := sc.newStream(promisedID, msg.parent.id, stateHalfClosedRemote) + rw, req, err := sc.newWriterAndRequestNoBody(promised, requestParam{ + method: msg.method, + scheme: msg.url.Scheme, + authority: msg.url.Host, + path: msg.url.RequestURI(), + header: msg.header, + }) + if err != nil { + // Should not happen, since we've already validated msg.url. + panic(fmt.Sprintf("newWriterAndRequestNoBody(%+v): %v", msg.url, err)) + } + + go sc.runHandler(rw, req, sc.handler.ServeHTTP) + return promisedID, nil + } + + sc.writeFrame(FrameWriteRequest{ + write: &writePushPromise{ + streamID: msg.parent.id, + method: msg.method, + url: msg.url, + h: msg.header, + allocatePromisedID: allocatePromisedID, + }, + stream: msg.parent, + done: msg.done, + }) +} + // foreachHeaderElement splits v according to the "#rule" construction // in RFC 2616 section 2.1 and calls fn for each non-empty element. func foreachHeaderElement(v string, fn func(string)) { @@ -2234,16 +2607,16 @@ var connHeaders = []string{ "Upgrade", } -// checkValidHTTP2Request checks whether req is a valid HTTP/2 request, +// checkValidHTTP2RequestHeaders checks whether h is a valid HTTP/2 request, // per RFC 7540 Section 8.1.2.2. // The returned error is reported to users. -func checkValidHTTP2Request(req *http.Request) error { - for _, h := range connHeaders { - if _, ok := req.Header[h]; ok { - return fmt.Errorf("request header %q is not valid in HTTP/2", h) +func checkValidHTTP2RequestHeaders(h http.Header) error { + for _, k := range connHeaders { + if _, ok := h[k]; ok { + return fmt.Errorf("request header %q is not valid in HTTP/2", k) } } - te := req.Header["Te"] + te := h["Te"] if len(te) > 0 && (len(te) > 1 || (te[0] != "trailers" && te[0] != "")) { return errors.New(`request header "TE" may only be "trailers" in HTTP/2`) } @@ -2290,3 +2663,42 @@ var badTrailer = map[string]bool{ "Transfer-Encoding": true, "Www-Authenticate": true, } + +// h1ServerShutdownChan returns a channel that will be closed when the +// provided *http.Server wants to shut down. +// +// This is a somewhat hacky way to get at http1 innards. It works +// when the http2 code is bundled into the net/http package in the +// standard library. The alternatives ended up making the cmd/go tool +// depend on http Servers. This is the lightest option for now. +// This is tested via the TestServeShutdown* tests in net/http. +func h1ServerShutdownChan(hs *http.Server) <-chan struct{} { + if fn := testh1ServerShutdownChan; fn != nil { + return fn(hs) + } + var x interface{} = hs + type I interface { + getDoneChan() <-chan struct{} + } + if hs, ok := x.(I); ok { + return hs.getDoneChan() + } + return nil +} + +// optional test hook for h1ServerShutdownChan. +var testh1ServerShutdownChan func(hs *http.Server) <-chan struct{} + +// h1ServerKeepAlivesDisabled reports whether hs has its keep-alives +// disabled. See comments on h1ServerShutdownChan above for why +// the code is written this way. +func h1ServerKeepAlivesDisabled(hs *http.Server) bool { + var x interface{} = hs + type I interface { + doKeepAlives() bool + } + if hs, ok := x.(I); ok { + return !hs.doKeepAlives() + } + return false +} diff --git a/vendor/golang.org/x/net/http2/server_push_test.go b/vendor/golang.org/x/net/http2/server_push_test.go new file mode 100644 index 000000000..3fea20870 --- /dev/null +++ b/vendor/golang.org/x/net/http2/server_push_test.go @@ -0,0 +1,470 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.8 + +package http2 + +import ( + "errors" + "fmt" + "io" + "io/ioutil" + "net/http" + "reflect" + "strconv" + "sync" + "testing" + "time" +) + +func TestServer_Push_Success(t *testing.T) { + const ( + mainBody = "<html>index page</html>" + pushedBody = "<html>pushed page</html>" + userAgent = "testagent" + cookie = "testcookie" + ) + + var stURL string + checkPromisedReq := func(r *http.Request, wantMethod string, wantH http.Header) error { + if got, want := r.Method, wantMethod; got != want { + return fmt.Errorf("promised Req.Method=%q, want %q", got, want) + } + if got, want := r.Header, wantH; !reflect.DeepEqual(got, want) { + return fmt.Errorf("promised Req.Header=%q, want %q", got, want) + } + if got, want := "https://"+r.Host, stURL; got != want { + return fmt.Errorf("promised Req.Host=%q, want %q", got, want) + } + if r.Body == nil { + return fmt.Errorf("nil Body") + } + if buf, err := ioutil.ReadAll(r.Body); err != nil || len(buf) != 0 { + return fmt.Errorf("ReadAll(Body)=%q,%v, want '',nil", buf, err) + } + return nil + } + + errc := make(chan error, 3) + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.RequestURI() { + case "/": + // Push "/pushed?get" as a GET request, using an absolute URL. + opt := &http.PushOptions{ + Header: http.Header{ + "User-Agent": {userAgent}, + }, + } + if err := w.(http.Pusher).Push(stURL+"/pushed?get", opt); err != nil { + errc <- fmt.Errorf("error pushing /pushed?get: %v", err) + return + } + // Push "/pushed?head" as a HEAD request, using a path. + opt = &http.PushOptions{ + Method: "HEAD", + Header: http.Header{ + "User-Agent": {userAgent}, + "Cookie": {cookie}, + }, + } + if err := w.(http.Pusher).Push("/pushed?head", opt); err != nil { + errc <- fmt.Errorf("error pushing /pushed?head: %v", err) + return + } + w.Header().Set("Content-Type", "text/html") + w.Header().Set("Content-Length", strconv.Itoa(len(mainBody))) + w.WriteHeader(200) + io.WriteString(w, mainBody) + errc <- nil + + case "/pushed?get": + wantH := http.Header{} + wantH.Set("User-Agent", userAgent) + if err := checkPromisedReq(r, "GET", wantH); err != nil { + errc <- fmt.Errorf("/pushed?get: %v", err) + return + } + w.Header().Set("Content-Type", "text/html") + w.Header().Set("Content-Length", strconv.Itoa(len(pushedBody))) + w.WriteHeader(200) + io.WriteString(w, pushedBody) + errc <- nil + + case "/pushed?head": + wantH := http.Header{} + wantH.Set("User-Agent", userAgent) + wantH.Set("Cookie", cookie) + if err := checkPromisedReq(r, "HEAD", wantH); err != nil { + errc <- fmt.Errorf("/pushed?head: %v", err) + return + } + w.WriteHeader(204) + errc <- nil + + default: + errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI()) + } + }) + stURL = st.ts.URL + + // Send one request, which should push two responses. + st.greet() + getSlash(st) + for k := 0; k < 3; k++ { + select { + case <-time.After(2 * time.Second): + t.Errorf("timeout waiting for handler %d to finish", k) + case err := <-errc: + if err != nil { + t.Fatal(err) + } + } + } + + checkPushPromise := func(f Frame, promiseID uint32, wantH [][2]string) error { + pp, ok := f.(*PushPromiseFrame) + if !ok { + return fmt.Errorf("got a %T; want *PushPromiseFrame", f) + } + if !pp.HeadersEnded() { + return fmt.Errorf("want END_HEADERS flag in PushPromiseFrame") + } + if got, want := pp.PromiseID, promiseID; got != want { + return fmt.Errorf("got PromiseID %v; want %v", got, want) + } + gotH := st.decodeHeader(pp.HeaderBlockFragment()) + if !reflect.DeepEqual(gotH, wantH) { + return fmt.Errorf("got promised headers %v; want %v", gotH, wantH) + } + return nil + } + checkHeaders := func(f Frame, wantH [][2]string) error { + hf, ok := f.(*HeadersFrame) + if !ok { + return fmt.Errorf("got a %T; want *HeadersFrame", f) + } + gotH := st.decodeHeader(hf.HeaderBlockFragment()) + if !reflect.DeepEqual(gotH, wantH) { + return fmt.Errorf("got response headers %v; want %v", gotH, wantH) + } + return nil + } + checkData := func(f Frame, wantData string) error { + df, ok := f.(*DataFrame) + if !ok { + return fmt.Errorf("got a %T; want *DataFrame", f) + } + if gotData := string(df.Data()); gotData != wantData { + return fmt.Errorf("got response data %q; want %q", gotData, wantData) + } + return nil + } + + // Stream 1 has 2 PUSH_PROMISE + HEADERS + DATA + // Stream 2 has HEADERS + DATA + // Stream 4 has HEADERS + expected := map[uint32][]func(Frame) error{ + 1: { + func(f Frame) error { + return checkPushPromise(f, 2, [][2]string{ + {":method", "GET"}, + {":scheme", "https"}, + {":authority", st.ts.Listener.Addr().String()}, + {":path", "/pushed?get"}, + {"user-agent", userAgent}, + }) + }, + func(f Frame) error { + return checkPushPromise(f, 4, [][2]string{ + {":method", "HEAD"}, + {":scheme", "https"}, + {":authority", st.ts.Listener.Addr().String()}, + {":path", "/pushed?head"}, + {"cookie", cookie}, + {"user-agent", userAgent}, + }) + }, + func(f Frame) error { + return checkHeaders(f, [][2]string{ + {":status", "200"}, + {"content-type", "text/html"}, + {"content-length", strconv.Itoa(len(mainBody))}, + }) + }, + func(f Frame) error { + return checkData(f, mainBody) + }, + }, + 2: { + func(f Frame) error { + return checkHeaders(f, [][2]string{ + {":status", "200"}, + {"content-type", "text/html"}, + {"content-length", strconv.Itoa(len(pushedBody))}, + }) + }, + func(f Frame) error { + return checkData(f, pushedBody) + }, + }, + 4: { + func(f Frame) error { + return checkHeaders(f, [][2]string{ + {":status", "204"}, + }) + }, + }, + } + + consumed := map[uint32]int{} + for k := 0; len(expected) > 0; k++ { + f, err := st.readFrame() + if err != nil { + for id, left := range expected { + t.Errorf("stream %d: missing %d frames", id, len(left)) + } + t.Fatalf("readFrame %d: %v", k, err) + } + id := f.Header().StreamID + label := fmt.Sprintf("stream %d, frame %d", id, consumed[id]) + if len(expected[id]) == 0 { + t.Fatalf("%s: unexpected frame %#+v", label, f) + } + check := expected[id][0] + expected[id] = expected[id][1:] + if len(expected[id]) == 0 { + delete(expected, id) + } + if err := check(f); err != nil { + t.Fatalf("%s: %v", label, err) + } + consumed[id]++ + } +} + +func TestServer_Push_RejectRecursivePush(t *testing.T) { + // Expect two requests, but might get three if there's a bug and the second push succeeds. + errc := make(chan error, 3) + handler := func(w http.ResponseWriter, r *http.Request) error { + baseURL := "https://" + r.Host + switch r.URL.Path { + case "/": + if err := w.(http.Pusher).Push(baseURL+"/push1", nil); err != nil { + return fmt.Errorf("first Push()=%v, want nil", err) + } + return nil + + case "/push1": + if got, want := w.(http.Pusher).Push(baseURL+"/push2", nil), ErrRecursivePush; got != want { + return fmt.Errorf("Push()=%v, want %v", got, want) + } + return nil + + default: + return fmt.Errorf("unexpected path: %q", r.URL.Path) + } + } + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + errc <- handler(w, r) + }) + defer st.Close() + st.greet() + getSlash(st) + if err := <-errc; err != nil { + t.Errorf("First request failed: %v", err) + } + if err := <-errc; err != nil { + t.Errorf("Second request failed: %v", err) + } +} + +func testServer_Push_RejectSingleRequest(t *testing.T, doPush func(http.Pusher, *http.Request) error, settings ...Setting) { + // Expect one request, but might get two if there's a bug and the push succeeds. + errc := make(chan error, 2) + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + errc <- doPush(w.(http.Pusher), r) + }) + defer st.Close() + st.greet() + if err := st.fr.WriteSettings(settings...); err != nil { + st.t.Fatalf("WriteSettings: %v", err) + } + st.wantSettingsAck() + getSlash(st) + if err := <-errc; err != nil { + t.Error(err) + } + // Should not get a PUSH_PROMISE frame. + hf := st.wantHeaders() + if !hf.StreamEnded() { + t.Error("stream should end after headers") + } +} + +func TestServer_Push_RejectIfDisabled(t *testing.T) { + testServer_Push_RejectSingleRequest(t, + func(p http.Pusher, r *http.Request) error { + if got, want := p.Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want { + return fmt.Errorf("Push()=%v, want %v", got, want) + } + return nil + }, + Setting{SettingEnablePush, 0}) +} + +func TestServer_Push_RejectWhenNoConcurrentStreams(t *testing.T) { + testServer_Push_RejectSingleRequest(t, + func(p http.Pusher, r *http.Request) error { + if got, want := p.Push("https://"+r.Host+"/pushed", nil), ErrPushLimitReached; got != want { + return fmt.Errorf("Push()=%v, want %v", got, want) + } + return nil + }, + Setting{SettingMaxConcurrentStreams, 0}) +} + +func TestServer_Push_RejectWrongScheme(t *testing.T) { + testServer_Push_RejectSingleRequest(t, + func(p http.Pusher, r *http.Request) error { + if err := p.Push("http://"+r.Host+"/pushed", nil); err == nil { + return errors.New("Push() should have failed (push target URL is http)") + } + return nil + }) +} + +func TestServer_Push_RejectMissingHost(t *testing.T) { + testServer_Push_RejectSingleRequest(t, + func(p http.Pusher, r *http.Request) error { + if err := p.Push("https:pushed", nil); err == nil { + return errors.New("Push() should have failed (push target URL missing host)") + } + return nil + }) +} + +func TestServer_Push_RejectRelativePath(t *testing.T) { + testServer_Push_RejectSingleRequest(t, + func(p http.Pusher, r *http.Request) error { + if err := p.Push("../test", nil); err == nil { + return errors.New("Push() should have failed (push target is a relative path)") + } + return nil + }) +} + +func TestServer_Push_RejectForbiddenMethod(t *testing.T) { + testServer_Push_RejectSingleRequest(t, + func(p http.Pusher, r *http.Request) error { + if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Method: "POST"}); err == nil { + return errors.New("Push() should have failed (cannot promise a POST)") + } + return nil + }) +} + +func TestServer_Push_RejectForbiddenHeader(t *testing.T) { + testServer_Push_RejectSingleRequest(t, + func(p http.Pusher, r *http.Request) error { + header := http.Header{ + "Content-Length": {"10"}, + "Content-Encoding": {"gzip"}, + "Trailer": {"Foo"}, + "Te": {"trailers"}, + "Host": {"test.com"}, + ":authority": {"test.com"}, + } + if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Header: header}); err == nil { + return errors.New("Push() should have failed (forbidden headers)") + } + return nil + }) +} + +func TestServer_Push_StateTransitions(t *testing.T) { + const body = "foo" + + startedPromise := make(chan bool) + finishedPush := make(chan bool) + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.RequestURI() { + case "/": + if err := w.(http.Pusher).Push("/pushed", nil); err != nil { + t.Errorf("Push error: %v", err) + } + close(startedPromise) + // Don't finish this request until the push finishes so we don't + // nondeterministically interleave output frames with the push. + <-finishedPush + } + w.Header().Set("Content-Type", "text/html") + w.Header().Set("Content-Length", strconv.Itoa(len(body))) + w.WriteHeader(200) + io.WriteString(w, body) + }) + defer st.Close() + + st.greet() + if st.stream(2) != nil { + t.Fatal("stream 2 should be empty") + } + if got, want := st.streamState(2), stateIdle; got != want { + t.Fatalf("streamState(2)=%v, want %v", got, want) + } + getSlash(st) + <-startedPromise + if got, want := st.streamState(2), stateHalfClosedRemote; got != want { + t.Fatalf("streamState(2)=%v, want %v", got, want) + } + st.wantPushPromise() + st.wantHeaders() + if df := st.wantData(); !df.StreamEnded() { + t.Fatal("expected END_STREAM flag on DATA") + } + if got, want := st.streamState(2), stateClosed; got != want { + t.Fatalf("streamState(2)=%v, want %v", got, want) + } + close(finishedPush) +} + +func TestServer_Push_RejectAfterGoAway(t *testing.T) { + var readyOnce sync.Once + ready := make(chan struct{}) + errc := make(chan error, 2) + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + select { + case <-ready: + case <-time.After(5 * time.Second): + errc <- fmt.Errorf("timeout waiting for GOAWAY to be processed") + } + if got, want := w.(http.Pusher).Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want { + errc <- fmt.Errorf("Push()=%v, want %v", got, want) + } + errc <- nil + }) + defer st.Close() + st.greet() + getSlash(st) + + // Send GOAWAY and wait for it to be processed. + st.fr.WriteGoAway(1, ErrCodeNo, nil) + go func() { + for { + select { + case <-ready: + return + default: + } + st.sc.testHookCh <- func(loopNum int) { + if !st.sc.pushEnabled { + readyOnce.Do(func() { close(ready) }) + } + } + } + }() + if err := <-errc; err != nil { + t.Error(err) + } +} diff --git a/vendor/golang.org/x/net/http2/server_test.go b/vendor/golang.org/x/net/http2/server_test.go index 879e82135..2e6146b67 100644 --- a/vendor/golang.org/x/net/http2/server_test.go +++ b/vendor/golang.org/x/net/http2/server_test.go @@ -80,18 +80,19 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{} tlsConfig := &tls.Config{ InsecureSkipVerify: true, - // The h2-14 is temporary, until curl is updated. (as used by unit tests - // in Docker) - NextProtos: []string{NextProtoTLS, "h2-14"}, + NextProtos: []string{NextProtoTLS}, } var onlyServer, quiet bool + h2server := new(Server) for _, opt := range opts { switch v := opt.(type) { case func(*tls.Config): v(tlsConfig) case func(*httptest.Server): v(ts) + case func(*Server): + v(h2server) case serverTesterOpt: switch v { case optOnlyServer: @@ -106,7 +107,7 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{} } } - ConfigureServer(ts.Config, &Server{}) + ConfigureServer(ts.Config, h2server) st := &serverTester{ t: t, @@ -253,6 +254,12 @@ func (st *serverTester) writeHeaders(p HeadersFrameParam) { } } +func (st *serverTester) writePriority(id uint32, p PriorityParam) { + if err := st.fr.WritePriority(id, p); err != nil { + st.t.Fatalf("Error writing PRIORITY: %v", err) + } +} + func (st *serverTester) encodeHeaderField(k, v string) { err := st.hpackEnc.WriteField(hpack.HeaderField{Name: k, Value: v}) if err != nil { @@ -278,37 +285,42 @@ func (st *serverTester) encodeHeaderRaw(headers ...string) []byte { // encodeHeader encodes headers and returns their HPACK bytes. headers // must contain an even number of key/value pairs. There may be // multiple pairs for keys (e.g. "cookie"). The :method, :path, and -// :scheme headers default to GET, / and https. +// :scheme headers default to GET, / and https. The :authority header +// defaults to st.ts.Listener.Addr(). func (st *serverTester) encodeHeader(headers ...string) []byte { if len(headers)%2 == 1 { panic("odd number of kv args") } st.headerBuf.Reset() + defaultAuthority := st.ts.Listener.Addr().String() if len(headers) == 0 { // Fast path, mostly for benchmarks, so test code doesn't pollute // profiles when we're looking to improve server allocations. st.encodeHeaderField(":method", "GET") - st.encodeHeaderField(":path", "/") st.encodeHeaderField(":scheme", "https") + st.encodeHeaderField(":authority", defaultAuthority) + st.encodeHeaderField(":path", "/") return st.headerBuf.Bytes() } if len(headers) == 2 && headers[0] == ":method" { // Another fast path for benchmarks. st.encodeHeaderField(":method", headers[1]) - st.encodeHeaderField(":path", "/") st.encodeHeaderField(":scheme", "https") + st.encodeHeaderField(":authority", defaultAuthority) + st.encodeHeaderField(":path", "/") return st.headerBuf.Bytes() } pseudoCount := map[string]int{} - keys := []string{":method", ":path", ":scheme"} + keys := []string{":method", ":scheme", ":authority", ":path"} vals := map[string][]string{ - ":method": {"GET"}, - ":path": {"/"}, - ":scheme": {"https"}, + ":method": {"GET"}, + ":scheme": {"https"}, + ":authority": {defaultAuthority}, + ":path": {"/"}, } for len(headers) > 0 { k, v := headers[0], headers[1] @@ -503,7 +515,18 @@ func (st *serverTester) wantSettingsAck() { if !sf.Header().Flags.Has(FlagSettingsAck) { st.t.Fatal("Settings Frame didn't have ACK set") } +} +func (st *serverTester) wantPushPromise() *PushPromiseFrame { + f, err := st.readFrame() + if err != nil { + st.t.Fatal(err) + } + ppf, ok := f.(*PushPromiseFrame) + if !ok { + st.t.Fatalf("Wanted PushPromise, received %T", ppf) + } + return ppf } func TestServer(t *testing.T) { @@ -758,7 +781,7 @@ func TestServer_Request_Get_Host(t *testing.T) { testServerRequest(t, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers - BlockFragment: st.encodeHeader("host", host), + BlockFragment: st.encodeHeader(":authority", "", "host", host), EndStream: true, EndHeaders: true, }) @@ -937,7 +960,7 @@ func TestServer_Request_Reject_Pseudo_Unknown(t *testing.T) { func testRejectRequest(t *testing.T, send func(*serverTester)) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - t.Fatal("server request made it to handler; should've been rejected") + t.Error("server request made it to handler; should've been rejected") }) defer st.Close() @@ -946,6 +969,39 @@ func testRejectRequest(t *testing.T, send func(*serverTester)) { st.wantRSTStream(1, ErrCodeProtocol) } +func testRejectRequestWithProtocolError(t *testing.T, send func(*serverTester)) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + t.Error("server request made it to handler; should've been rejected") + }, optQuiet) + defer st.Close() + + st.greet() + send(st) + gf := st.wantGoAway() + if gf.ErrCode != ErrCodeProtocol { + t.Errorf("err code = %v; want %v", gf.ErrCode, ErrCodeProtocol) + } +} + +// Section 5.1, on idle connections: "Receiving any frame other than +// HEADERS or PRIORITY on a stream in this state MUST be treated as a +// connection error (Section 5.4.1) of type PROTOCOL_ERROR." +func TestRejectFrameOnIdle_WindowUpdate(t *testing.T) { + testRejectRequestWithProtocolError(t, func(st *serverTester) { + st.fr.WriteWindowUpdate(123, 456) + }) +} +func TestRejectFrameOnIdle_Data(t *testing.T) { + testRejectRequestWithProtocolError(t, func(st *serverTester) { + st.fr.WriteData(123, true, nil) + }) +} +func TestRejectFrameOnIdle_RSTStream(t *testing.T) { + testRejectRequestWithProtocolError(t, func(st *serverTester) { + st.fr.WriteRSTStream(123, ErrCodeCancel) + }) +} + func TestServer_Request_Connect(t *testing.T) { testServerRequest(t, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ @@ -1445,6 +1501,36 @@ func TestServer_Rejects_Continuation0(t *testing.T) { }) } +// No PRIORITY on stream 0. +func TestServer_Rejects_Priority0(t *testing.T) { + testServerRejectsConn(t, func(st *serverTester) { + st.fr.AllowIllegalWrites = true + st.writePriority(0, PriorityParam{StreamDep: 1}) + }) +} + +// No HEADERS frame with a self-dependence. +func TestServer_Rejects_HeadersSelfDependence(t *testing.T) { + testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) { + st.fr.AllowIllegalWrites = true + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, + BlockFragment: st.encodeHeader(), + EndStream: true, + EndHeaders: true, + Priority: PriorityParam{StreamDep: 1}, + }) + }) +} + +// No PRIORTY frame with a self-dependence. +func TestServer_Rejects_PrioritySelfDependence(t *testing.T) { + testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) { + st.fr.AllowIllegalWrites = true + st.writePriority(1, PriorityParam{StreamDep: 1}) + }) +} + func TestServer_Rejects_PushPromise(t *testing.T) { testServerRejectsConn(t, func(st *serverTester) { pp := PushPromiseParam{ @@ -2840,6 +2926,12 @@ func BenchmarkServerPosts(b *testing.B) { const msg = "Hello, world" st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) { + // Consume the (empty) body from th peer before replying, otherwise + // the server will sometimes (depending on scheduling) send the peer a + // a RST_STREAM with the CANCEL error code. + if n, err := io.Copy(ioutil.Discard, r.Body); n != 0 || err != nil { + b.Errorf("Copy error; got %v, %v; want 0, nil", n, err) + } io.WriteString(w, msg) }) defer st.Close() @@ -3236,40 +3328,40 @@ func (he *hpackEncoder) encodeHeaderRaw(t *testing.T, headers ...string) []byte func TestCheckValidHTTP2Request(t *testing.T) { tests := []struct { - req *http.Request + h http.Header want error }{ { - req: &http.Request{Header: http.Header{"Te": {"trailers"}}}, + h: http.Header{"Te": {"trailers"}}, want: nil, }, { - req: &http.Request{Header: http.Header{"Te": {"trailers", "bogus"}}}, + h: http.Header{"Te": {"trailers", "bogus"}}, want: errors.New(`request header "TE" may only be "trailers" in HTTP/2`), }, { - req: &http.Request{Header: http.Header{"Foo": {""}}}, + h: http.Header{"Foo": {""}}, want: nil, }, { - req: &http.Request{Header: http.Header{"Connection": {""}}}, + h: http.Header{"Connection": {""}}, want: errors.New(`request header "Connection" is not valid in HTTP/2`), }, { - req: &http.Request{Header: http.Header{"Proxy-Connection": {""}}}, + h: http.Header{"Proxy-Connection": {""}}, want: errors.New(`request header "Proxy-Connection" is not valid in HTTP/2`), }, { - req: &http.Request{Header: http.Header{"Keep-Alive": {""}}}, + h: http.Header{"Keep-Alive": {""}}, want: errors.New(`request header "Keep-Alive" is not valid in HTTP/2`), }, { - req: &http.Request{Header: http.Header{"Upgrade": {""}}}, + h: http.Header{"Upgrade": {""}}, want: errors.New(`request header "Upgrade" is not valid in HTTP/2`), }, } for i, tt := range tests { - got := checkValidHTTP2Request(tt.req) + got := checkValidHTTP2RequestHeaders(tt.h) if !reflect.DeepEqual(got, tt.want) { t.Errorf("%d. checkValidHTTP2Request = %v; want %v", i, got, tt.want) } @@ -3366,3 +3458,118 @@ func TestUnreadFlowControlReturned_Server(t *testing.T) { } } + +func TestServerIdleTimeout(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + }, func(h2s *Server) { + h2s.IdleTimeout = 500 * time.Millisecond + }) + defer st.Close() + + st.greet() + ga := st.wantGoAway() + if ga.ErrCode != ErrCodeNo { + t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode) + } +} + +func TestServerIdleTimeout_AfterRequest(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + const timeout = 250 * time.Millisecond + + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + time.Sleep(timeout * 2) + }, func(h2s *Server) { + h2s.IdleTimeout = timeout + }) + defer st.Close() + + st.greet() + + // Send a request which takes twice the timeout. Verifies the + // idle timeout doesn't fire while we're in a request: + st.bodylessReq1() + st.wantHeaders() + + // But the idle timeout should be rearmed after the request + // is done: + ga := st.wantGoAway() + if ga.ErrCode != ErrCodeNo { + t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode) + } +} + +// grpc-go closes the Request.Body currently with a Read. +// Verify that it doesn't race. +// See https://github.com/grpc/grpc-go/pull/938 +func TestRequestBodyReadCloseRace(t *testing.T) { + for i := 0; i < 100; i++ { + body := &requestBody{ + pipe: &pipe{ + b: new(bytes.Buffer), + }, + } + body.pipe.CloseWithError(io.EOF) + + done := make(chan bool, 1) + buf := make([]byte, 10) + go func() { + time.Sleep(1 * time.Millisecond) + body.Close() + done <- true + }() + body.Read(buf) + <-done + } +} + +func TestServerGracefulShutdown(t *testing.T) { + shutdownCh := make(chan struct{}) + defer func() { testh1ServerShutdownChan = nil }() + testh1ServerShutdownChan = func(*http.Server) <-chan struct{} { return shutdownCh } + + var st *serverTester + handlerDone := make(chan struct{}) + st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + defer close(handlerDone) + close(shutdownCh) + + ga := st.wantGoAway() + if ga.ErrCode != ErrCodeNo { + t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode) + } + if ga.LastStreamID != 1 { + t.Errorf("GOAWAY LastStreamID = %v; want 1", ga.LastStreamID) + } + + w.Header().Set("x-foo", "bar") + }) + defer st.Close() + + st.greet() + st.bodylessReq1() + + <-handlerDone + hf := st.wantHeaders() + goth := st.decodeHeader(hf.HeaderBlockFragment()) + wanth := [][2]string{ + {":status", "200"}, + {"x-foo", "bar"}, + {"content-type", "text/plain; charset=utf-8"}, + {"content-length", "0"}, + } + if !reflect.DeepEqual(goth, wanth) { + t.Errorf("Got headers %v; want %v", goth, wanth) + } + + n, err := st.cc.Read([]byte{0}) + if n != 0 || err == nil { + t.Errorf("Read = %v, %v; want 0, non-nil", n, err) + } +} diff --git a/vendor/golang.org/x/net/http2/transport.go b/vendor/golang.org/x/net/http2/transport.go index 42c73bd1e..8f5f84412 100644 --- a/vendor/golang.org/x/net/http2/transport.go +++ b/vendor/golang.org/x/net/http2/transport.go @@ -10,6 +10,7 @@ import ( "bufio" "bytes" "compress/gzip" + "crypto/rand" "crypto/tls" "errors" "fmt" @@ -150,6 +151,9 @@ type ClientConn struct { readerDone chan struct{} // closed on error readerErr error // set before readerDone is closed + idleTimeout time.Duration // or 0 for never + idleTimer *time.Timer + mu sync.Mutex // guards following cond *sync.Cond // hold mu; broadcast on flow/closed changes flow flow // our conn-level flow control quota (cs.flow is per stream) @@ -160,6 +164,7 @@ type ClientConn struct { goAwayDebug string // goAway frame's debug data, retained as a string streams map[uint32]*clientStream // client-initiated nextStreamID uint32 + pings map[[8]byte]chan struct{} // in flight ping data to notification channel bw *bufio.Writer br *bufio.Reader fr *Framer @@ -194,6 +199,7 @@ type clientStream struct { bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read readErr error // sticky read error; owned by transportResponseBody.Read stopReqBody error // if non-nil, stop writing req body; guarded by cc.mu + didReset bool // whether we sent a RST_STREAM to the server; guarded by cc.mu peerReset chan struct{} // closed on peer reset resetErr error // populated before peerReset is closed @@ -221,15 +227,26 @@ func (cs *clientStream) awaitRequestCancel(req *http.Request) { } select { case <-req.Cancel: + cs.cancelStream() cs.bufPipe.CloseWithError(errRequestCanceled) - cs.cc.writeStreamReset(cs.ID, ErrCodeCancel, nil) case <-ctx.Done(): + cs.cancelStream() cs.bufPipe.CloseWithError(ctx.Err()) - cs.cc.writeStreamReset(cs.ID, ErrCodeCancel, nil) case <-cs.done: } } +func (cs *clientStream) cancelStream() { + cs.cc.mu.Lock() + didReset := cs.didReset + cs.didReset = true + cs.cc.mu.Unlock() + + if !didReset { + cs.cc.writeStreamReset(cs.ID, ErrCodeCancel, nil) + } +} + // checkResetOrDone reports any error sent in a RST_STREAM frame by the // server, or errStreamClosed if the stream is complete. func (cs *clientStream) checkResetOrDone() error { @@ -431,6 +448,11 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro streams: make(map[uint32]*clientStream), singleUse: singleUse, wantSettingsAck: true, + pings: make(map[[8]byte]chan struct{}), + } + if d := t.idleConnTimeout(); d != 0 { + cc.idleTimeout = d + cc.idleTimer = time.AfterFunc(d, cc.onIdleTimeout) } if VerboseLogs { t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr()) @@ -508,6 +530,16 @@ func (cc *ClientConn) canTakeNewRequestLocked() bool { cc.nextStreamID < math.MaxInt32 } +// onIdleTimeout is called from a time.AfterFunc goroutine. It will +// only be called when we're idle, but because we're coming from a new +// goroutine, there could be a new request coming in at the same time, +// so this simply calls the synchronized closeIfIdle to shut down this +// connection. The timer could just call closeIfIdle, but this is more +// clear. +func (cc *ClientConn) onIdleTimeout() { + cc.closeIfIdle() +} + func (cc *ClientConn) closeIfIdle() { cc.mu.Lock() if len(cc.streams) > 0 { @@ -604,51 +636,37 @@ func (cc *ClientConn) responseHeaderTimeout() time.Duration { // Certain headers are special-cased as okay but not transmitted later. func checkConnHeaders(req *http.Request) error { if v := req.Header.Get("Upgrade"); v != "" { - return errors.New("http2: invalid Upgrade request header") + return fmt.Errorf("http2: invalid Upgrade request header: %q", req.Header["Upgrade"]) } - if v := req.Header.Get("Transfer-Encoding"); (v != "" && v != "chunked") || len(req.Header["Transfer-Encoding"]) > 1 { - return errors.New("http2: invalid Transfer-Encoding request header") + if vv := req.Header["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") { + return fmt.Errorf("http2: invalid Transfer-Encoding request header: %q", vv) } - if v := req.Header.Get("Connection"); (v != "" && v != "close" && v != "keep-alive") || len(req.Header["Connection"]) > 1 { - return errors.New("http2: invalid Connection request header") + if vv := req.Header["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "close" && vv[0] != "keep-alive") { + return fmt.Errorf("http2: invalid Connection request header: %q", vv) } return nil } -func bodyAndLength(req *http.Request) (body io.Reader, contentLen int64) { - body = req.Body - if body == nil { - return nil, 0 +// actualContentLength returns a sanitized version of +// req.ContentLength, where 0 actually means zero (not unknown) and -1 +// means unknown. +func actualContentLength(req *http.Request) int64 { + if req.Body == nil { + return 0 } if req.ContentLength != 0 { - return req.Body, req.ContentLength - } - - // We have a body but a zero content length. Test to see if - // it's actually zero or just unset. - var buf [1]byte - n, rerr := body.Read(buf[:]) - if rerr != nil && rerr != io.EOF { - return errorReader{rerr}, -1 + return req.ContentLength } - if n == 1 { - // Oh, guess there is data in this Body Reader after all. - // The ContentLength field just wasn't set. - // Stitch the Body back together again, re-attaching our - // consumed byte. - if rerr == io.EOF { - return bytes.NewReader(buf[:]), 1 - } - return io.MultiReader(bytes.NewReader(buf[:]), body), -1 - } - // Body is actually zero bytes. - return nil, 0 + return -1 } func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { if err := checkConnHeaders(req); err != nil { return nil, err } + if cc.idleTimer != nil { + cc.idleTimer.Stop() + } trailers, err := commaSeparatedTrailers(req) if err != nil { @@ -663,8 +681,9 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { return nil, errClientConnUnusable } - body, contentLen := bodyAndLength(req) + body := req.Body hasBody := body != nil + contentLen := actualContentLength(req) // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? var requestedGzip bool @@ -1046,7 +1065,7 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail cc.writeHeader(":method", req.Method) if req.Method != "CONNECT" { cc.writeHeader(":path", path) - cc.writeHeader(":scheme", "https") + cc.writeHeader(":scheme", req.URL.Scheme) } if trailers != "" { cc.writeHeader("trailer", trailers) @@ -1173,6 +1192,9 @@ func (cc *ClientConn) streamByID(id uint32, andRemove bool) *clientStream { if andRemove && cs != nil && !cc.closed { cc.lastActive = time.Now() delete(cc.streams, id) + if len(cc.streams) == 0 && cc.idleTimer != nil { + cc.idleTimer.Reset(cc.idleTimeout) + } close(cs.done) cc.cond.Broadcast() // wake up checkResetOrDone via clientStream.awaitFlowControl } @@ -1229,6 +1251,10 @@ func (rl *clientConnReadLoop) cleanup() { defer cc.t.connPool().MarkDead(cc) defer close(cc.readerDone) + if cc.idleTimer != nil { + cc.idleTimer.Stop() + } + // Close any response bodies if the server closes prematurely. // TODO: also do this if we've written the headers but not // gotten a response yet. @@ -1652,9 +1678,10 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error { cc.bw.Flush() cc.wmu.Unlock() } + didReset := cs.didReset cc.mu.Unlock() - if len(data) > 0 { + if len(data) > 0 && !didReset { if _, err := cs.bufPipe.Write(data); err != nil { rl.endStreamError(cs, err) return err @@ -1815,10 +1842,56 @@ func (rl *clientConnReadLoop) processResetStream(f *RSTStreamFrame) error { return nil } +// Ping sends a PING frame to the server and waits for the ack. +// Public implementation is in go17.go and not_go17.go +func (cc *ClientConn) ping(ctx contextContext) error { + c := make(chan struct{}) + // Generate a random payload + var p [8]byte + for { + if _, err := rand.Read(p[:]); err != nil { + return err + } + cc.mu.Lock() + // check for dup before insert + if _, found := cc.pings[p]; !found { + cc.pings[p] = c + cc.mu.Unlock() + break + } + cc.mu.Unlock() + } + cc.wmu.Lock() + if err := cc.fr.WritePing(false, p); err != nil { + cc.wmu.Unlock() + return err + } + if err := cc.bw.Flush(); err != nil { + cc.wmu.Unlock() + return err + } + cc.wmu.Unlock() + select { + case <-c: + return nil + case <-ctx.Done(): + return ctx.Err() + case <-cc.readerDone: + // connection closed + return cc.readerErr + } +} + func (rl *clientConnReadLoop) processPing(f *PingFrame) error { if f.IsAck() { - // 6.7 PING: " An endpoint MUST NOT respond to PING frames - // containing this flag." + cc := rl.cc + cc.mu.Lock() + defer cc.mu.Unlock() + // If ack, notify listener if any + if c, ok := cc.pings[f.Data]; ok { + close(c) + delete(cc.pings, f.Data) + } return nil } cc := rl.cc diff --git a/vendor/golang.org/x/net/http2/transport_test.go b/vendor/golang.org/x/net/http2/transport_test.go index 96d0a0867..f9287e575 100644 --- a/vendor/golang.org/x/net/http2/transport_test.go +++ b/vendor/golang.org/x/net/http2/transport_test.go @@ -39,6 +39,13 @@ var ( var tlsConfigInsecure = &tls.Config{InsecureSkipVerify: true} +type testContext struct{} + +func (testContext) Done() <-chan struct{} { return make(chan struct{}) } +func (testContext) Err() error { panic("should not be called") } +func (testContext) Deadline() (deadline time.Time, ok bool) { return time.Time{}, false } +func (testContext) Value(key interface{}) interface{} { return nil } + func TestTransportExternal(t *testing.T) { if !*extNet { t.Skip("skipping external network test") @@ -52,6 +59,16 @@ func TestTransportExternal(t *testing.T) { res.Write(os.Stdout) } +type fakeTLSConn struct { + net.Conn +} + +func (c *fakeTLSConn) ConnectionState() tls.ConnectionState { + return tls.ConnectionState{ + Version: tls.VersionTLS12, + } +} + func startH2cServer(t *testing.T) net.Listener { h2Server := &Server{} l := newLocalListener(t) @@ -61,8 +78,8 @@ func startH2cServer(t *testing.T) net.Listener { t.Error(err) return } - h2Server.ServeConn(conn, &ServeConnOpts{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "Hello, %v", r.URL.Path) + h2Server.ServeConn(&fakeTLSConn{conn}, &ServeConnOpts{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "Hello, %v, http: %v", r.URL.Path, r.TLS == nil) })}) }() return l @@ -92,7 +109,7 @@ func TestTransportH2c(t *testing.T) { if err != nil { t.Fatal(err) } - if got, want := string(body), "Hello, /foobar"; got != want { + if got, want := string(body), "Hello, /foobar, http: true"; got != want { t.Fatalf("response got %v, want %v", got, want) } } @@ -374,6 +391,40 @@ func randString(n int) string { return string(b) } +type panicReader struct{} + +func (panicReader) Read([]byte) (int, error) { panic("unexpected Read") } +func (panicReader) Close() error { panic("unexpected Close") } + +func TestActualContentLength(t *testing.T) { + tests := []struct { + req *http.Request + want int64 + }{ + // Verify we don't read from Body: + 0: { + req: &http.Request{Body: panicReader{}}, + want: -1, + }, + // nil Body means 0, regardless of ContentLength: + 1: { + req: &http.Request{Body: nil, ContentLength: 5}, + want: 0, + }, + // ContentLength is used if set. + 2: { + req: &http.Request{Body: panicReader{}, ContentLength: 5}, + want: 5, + }, + } + for i, tt := range tests { + got := actualContentLength(tt.req) + if got != tt.want { + t.Errorf("test[%d]: got %d; want %d", i, got, tt.want) + } + } +} + func TestTransportBody(t *testing.T) { bodyTests := []struct { body string @@ -381,8 +432,6 @@ func TestTransportBody(t *testing.T) { }{ {body: "some message"}, {body: "some message", noContentLen: true}, - {body: ""}, - {body: "", noContentLen: true}, {body: strings.Repeat("a", 1<<20), noContentLen: true}, {body: strings.Repeat("a", 1<<20)}, {body: randString(16<<10 - 1)}, @@ -1690,12 +1739,12 @@ func TestTransportRejectsConnHeaders(t *testing.T) { { key: "Upgrade", value: []string{"anything"}, - want: "ERROR: http2: invalid Upgrade request header", + want: "ERROR: http2: invalid Upgrade request header: [\"anything\"]", }, { key: "Connection", value: []string{"foo"}, - want: "ERROR: http2: invalid Connection request header", + want: "ERROR: http2: invalid Connection request header: [\"foo\"]", }, { key: "Connection", @@ -1705,7 +1754,7 @@ func TestTransportRejectsConnHeaders(t *testing.T) { { key: "Connection", value: []string{"close", "something-else"}, - want: "ERROR: http2: invalid Connection request header", + want: "ERROR: http2: invalid Connection request header: [\"close\" \"something-else\"]", }, { key: "Connection", @@ -1725,7 +1774,7 @@ func TestTransportRejectsConnHeaders(t *testing.T) { { key: "Transfer-Encoding", value: []string{"foo"}, - want: "ERROR: http2: invalid Transfer-Encoding request header", + want: "ERROR: http2: invalid Transfer-Encoding request header: [\"foo\"]", }, { key: "Transfer-Encoding", @@ -1735,7 +1784,7 @@ func TestTransportRejectsConnHeaders(t *testing.T) { { key: "Transfer-Encoding", value: []string{"chunked", "other"}, - want: "ERROR: http2: invalid Transfer-Encoding request header", + want: "ERROR: http2: invalid Transfer-Encoding request header: [\"chunked\" \"other\"]", }, { key: "Content-Length", @@ -1898,8 +1947,17 @@ func TestTransportNewTLSConfig(t *testing.T) { }, } for i, tt := range tests { + // Ignore the session ticket keys part, which ends up populating + // unexported fields in the Config: + if tt.conf != nil { + tt.conf.SessionTicketsDisabled = true + } + tr := &Transport{TLSClientConfig: tt.conf} got := tr.newTLSConfig(tt.host) + + got.SessionTicketsDisabled = false + if !reflect.DeepEqual(got, tt.want) { t.Errorf("%d. got %#v; want %#v", i, got, tt.want) } @@ -2618,3 +2676,72 @@ func TestRoundTripDoesntConsumeRequestBodyEarly(t *testing.T) { t.Errorf("Body = %q; want %q", slurp, body) } } + +func TestClientConnPing(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optOnlyServer) + defer st.Close() + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false) + if err != nil { + t.Fatal(err) + } + if err = cc.Ping(testContext{}); err != nil { + t.Fatal(err) + } +} + +// Issue 16974: if the server sent a DATA frame after the user +// canceled the Transport's Request, the Transport previously wrote to a +// closed pipe, got an error, and ended up closing the whole TCP +// connection. +func TestTransportCancelDataResponseRace(t *testing.T) { + cancel := make(chan struct{}) + clientGotError := make(chan bool, 1) + + const msg = "Hello." + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/hello") { + time.Sleep(50 * time.Millisecond) + io.WriteString(w, msg) + return + } + for i := 0; i < 50; i++ { + io.WriteString(w, "Some data.") + w.(http.Flusher).Flush() + if i == 2 { + close(cancel) + <-clientGotError + } + time.Sleep(10 * time.Millisecond) + } + }, optOnlyServer) + defer st.Close() + + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + + c := &http.Client{Transport: tr} + req, _ := http.NewRequest("GET", st.ts.URL, nil) + req.Cancel = cancel + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + if _, err = io.Copy(ioutil.Discard, res.Body); err == nil { + t.Fatal("unexpected success") + } + clientGotError <- true + + res, err = c.Get(st.ts.URL + "/hello") + if err != nil { + t.Fatal(err) + } + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if string(slurp) != msg { + t.Errorf("Got = %q; want %q", slurp, msg) + } +} diff --git a/vendor/golang.org/x/net/http2/write.go b/vendor/golang.org/x/net/http2/write.go index 27ef0dd4d..1c135fdf7 100644 --- a/vendor/golang.org/x/net/http2/write.go +++ b/vendor/golang.org/x/net/http2/write.go @@ -9,6 +9,7 @@ import ( "fmt" "log" "net/http" + "net/url" "time" "golang.org/x/net/http2/hpack" @@ -18,6 +19,11 @@ import ( // writeFramer is implemented by any type that is used to write frames. type writeFramer interface { writeFrame(writeContext) error + + // staysWithinBuffer reports whether this writer promises that + // it will only write less than or equal to size bytes, and it + // won't Flush the write context. + staysWithinBuffer(size int) bool } // writeContext is the interface needed by the various frame writer @@ -62,8 +68,16 @@ func (flushFrameWriter) writeFrame(ctx writeContext) error { return ctx.Flush() } +func (flushFrameWriter) staysWithinBuffer(max int) bool { return false } + type writeSettings []Setting +func (s writeSettings) staysWithinBuffer(max int) bool { + const settingSize = 6 // uint16 + uint32 + return frameHeaderLen+settingSize*len(s) <= max + +} + func (s writeSettings) writeFrame(ctx writeContext) error { return ctx.Framer().WriteSettings([]Setting(s)...) } @@ -83,6 +97,8 @@ func (p *writeGoAway) writeFrame(ctx writeContext) error { return err } +func (*writeGoAway) staysWithinBuffer(max int) bool { return false } // flushes + type writeData struct { streamID uint32 p []byte @@ -97,6 +113,10 @@ func (w *writeData) writeFrame(ctx writeContext) error { return ctx.Framer().WriteData(w.streamID, w.endStream, w.p) } +func (w *writeData) staysWithinBuffer(max int) bool { + return frameHeaderLen+len(w.p) <= max +} + // handlerPanicRST is the message sent from handler goroutines when // the handler panics. type handlerPanicRST struct { @@ -107,22 +127,57 @@ func (hp handlerPanicRST) writeFrame(ctx writeContext) error { return ctx.Framer().WriteRSTStream(hp.StreamID, ErrCodeInternal) } +func (hp handlerPanicRST) staysWithinBuffer(max int) bool { return frameHeaderLen+4 <= max } + func (se StreamError) writeFrame(ctx writeContext) error { return ctx.Framer().WriteRSTStream(se.StreamID, se.Code) } +func (se StreamError) staysWithinBuffer(max int) bool { return frameHeaderLen+4 <= max } + type writePingAck struct{ pf *PingFrame } func (w writePingAck) writeFrame(ctx writeContext) error { return ctx.Framer().WritePing(true, w.pf.Data) } +func (w writePingAck) staysWithinBuffer(max int) bool { return frameHeaderLen+len(w.pf.Data) <= max } + type writeSettingsAck struct{} func (writeSettingsAck) writeFrame(ctx writeContext) error { return ctx.Framer().WriteSettingsAck() } +func (writeSettingsAck) staysWithinBuffer(max int) bool { return frameHeaderLen <= max } + +// splitHeaderBlock splits headerBlock into fragments so that each fragment fits +// in a single frame, then calls fn for each fragment. firstFrag/lastFrag are true +// for the first/last fragment, respectively. +func splitHeaderBlock(ctx writeContext, headerBlock []byte, fn func(ctx writeContext, frag []byte, firstFrag, lastFrag bool) error) error { + // For now we're lazy and just pick the minimum MAX_FRAME_SIZE + // that all peers must support (16KB). Later we could care + // more and send larger frames if the peer advertised it, but + // there's little point. Most headers are small anyway (so we + // generally won't have CONTINUATION frames), and extra frames + // only waste 9 bytes anyway. + const maxFrameSize = 16384 + + first := true + for len(headerBlock) > 0 { + frag := headerBlock + if len(frag) > maxFrameSize { + frag = frag[:maxFrameSize] + } + headerBlock = headerBlock[len(frag):] + if err := fn(ctx, frag, first, len(headerBlock) == 0); err != nil { + return err + } + first = false + } + return nil +} + // writeResHeaders is a request to write a HEADERS and 0+ CONTINUATION frames // for HTTP response headers or trailers from a server handler. type writeResHeaders struct { @@ -144,6 +199,17 @@ func encKV(enc *hpack.Encoder, k, v string) { enc.WriteField(hpack.HeaderField{Name: k, Value: v}) } +func (w *writeResHeaders) staysWithinBuffer(max int) bool { + // TODO: this is a common one. It'd be nice to return true + // here and get into the fast path if we could be clever and + // calculate the size fast enough, or at least a conservative + // uppper bound that usually fires. (Maybe if w.h and + // w.trailers are nil, so we don't need to enumerate it.) + // Otherwise I'm afraid that just calculating the length to + // answer this question would be slower than the ~2µs benefit. + return false +} + func (w *writeResHeaders) writeFrame(ctx writeContext) error { enc, buf := ctx.HeaderEncoder() buf.Reset() @@ -169,39 +235,69 @@ func (w *writeResHeaders) writeFrame(ctx writeContext) error { panic("unexpected empty hpack") } - // For now we're lazy and just pick the minimum MAX_FRAME_SIZE - // that all peers must support (16KB). Later we could care - // more and send larger frames if the peer advertised it, but - // there's little point. Most headers are small anyway (so we - // generally won't have CONTINUATION frames), and extra frames - // only waste 9 bytes anyway. - const maxFrameSize = 16384 + return splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock) +} - first := true - for len(headerBlock) > 0 { - frag := headerBlock - if len(frag) > maxFrameSize { - frag = frag[:maxFrameSize] - } - headerBlock = headerBlock[len(frag):] - endHeaders := len(headerBlock) == 0 - var err error - if first { - first = false - err = ctx.Framer().WriteHeaders(HeadersFrameParam{ - StreamID: w.streamID, - BlockFragment: frag, - EndStream: w.endStream, - EndHeaders: endHeaders, - }) - } else { - err = ctx.Framer().WriteContinuation(w.streamID, endHeaders, frag) - } - if err != nil { - return err - } +func (w *writeResHeaders) writeHeaderBlock(ctx writeContext, frag []byte, firstFrag, lastFrag bool) error { + if firstFrag { + return ctx.Framer().WriteHeaders(HeadersFrameParam{ + StreamID: w.streamID, + BlockFragment: frag, + EndStream: w.endStream, + EndHeaders: lastFrag, + }) + } else { + return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag) + } +} + +// writePushPromise is a request to write a PUSH_PROMISE and 0+ CONTINUATION frames. +type writePushPromise struct { + streamID uint32 // pusher stream + method string // for :method + url *url.URL // for :scheme, :authority, :path + h http.Header + + // Creates an ID for a pushed stream. This runs on serveG just before + // the frame is written. The returned ID is copied to promisedID. + allocatePromisedID func() (uint32, error) + promisedID uint32 +} + +func (w *writePushPromise) staysWithinBuffer(max int) bool { + // TODO: see writeResHeaders.staysWithinBuffer + return false +} + +func (w *writePushPromise) writeFrame(ctx writeContext) error { + enc, buf := ctx.HeaderEncoder() + buf.Reset() + + encKV(enc, ":method", w.method) + encKV(enc, ":scheme", w.url.Scheme) + encKV(enc, ":authority", w.url.Host) + encKV(enc, ":path", w.url.RequestURI()) + encodeHeaders(enc, w.h, nil) + + headerBlock := buf.Bytes() + if len(headerBlock) == 0 { + panic("unexpected empty hpack") + } + + return splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock) +} + +func (w *writePushPromise) writeHeaderBlock(ctx writeContext, frag []byte, firstFrag, lastFrag bool) error { + if firstFrag { + return ctx.Framer().WritePushPromise(PushPromiseParam{ + StreamID: w.streamID, + PromiseID: w.promisedID, + BlockFragment: frag, + EndHeaders: lastFrag, + }) + } else { + return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag) } - return nil } type write100ContinueHeadersFrame struct { @@ -220,15 +316,24 @@ func (w write100ContinueHeadersFrame) writeFrame(ctx writeContext) error { }) } +func (w write100ContinueHeadersFrame) staysWithinBuffer(max int) bool { + // Sloppy but conservative: + return 9+2*(len(":status")+len("100")) <= max +} + type writeWindowUpdate struct { streamID uint32 // or 0 for conn-level n uint32 } +func (wu writeWindowUpdate) staysWithinBuffer(max int) bool { return frameHeaderLen+4 <= max } + func (wu writeWindowUpdate) writeFrame(ctx writeContext) error { return ctx.Framer().WriteWindowUpdate(wu.streamID, wu.n) } +// encodeHeaders encodes an http.Header. If keys is not nil, then (k, h[k]) +// is encoded only only if k is in keys. func encodeHeaders(enc *hpack.Encoder, h http.Header, keys []string) { if keys == nil { sorter := sorterPool.Get().(*sorter) diff --git a/vendor/golang.org/x/net/http2/writesched.go b/vendor/golang.org/x/net/http2/writesched.go index c24316ce7..caa77c7cb 100644 --- a/vendor/golang.org/x/net/http2/writesched.go +++ b/vendor/golang.org/x/net/http2/writesched.go @@ -6,14 +6,53 @@ package http2 import "fmt" -// frameWriteMsg is a request to write a frame. -type frameWriteMsg struct { +// WriteScheduler is the interface implemented by HTTP/2 write schedulers. +// Methods are never called concurrently. +type WriteScheduler interface { + // OpenStream opens a new stream in the write scheduler. + // It is illegal to call this with streamID=0 or with a streamID that is + // already open -- the call may panic. + OpenStream(streamID uint32, options OpenStreamOptions) + + // CloseStream closes a stream in the write scheduler. Any frames queued on + // this stream should be discarded. It is illegal to call this on a stream + // that is not open -- the call may panic. + CloseStream(streamID uint32) + + // AdjustStream adjusts the priority of the given stream. This may be called + // on a stream that has not yet been opened or has been closed. Note that + // RFC 7540 allows PRIORITY frames to be sent on streams in any state. See: + // https://tools.ietf.org/html/rfc7540#section-5.1 + AdjustStream(streamID uint32, priority PriorityParam) + + // Push queues a frame in the scheduler. In most cases, this will not be + // called with wr.StreamID()!=0 unless that stream is currently open. The one + // exception is RST_STREAM frames, which may be sent on idle or closed streams. + Push(wr FrameWriteRequest) + + // Pop dequeues the next frame to write. Returns false if no frames can + // be written. Frames with a given wr.StreamID() are Pop'd in the same + // order they are Push'd. + Pop() (wr FrameWriteRequest, ok bool) +} + +// OpenStreamOptions specifies extra options for WriteScheduler.OpenStream. +type OpenStreamOptions struct { + // PusherID is zero if the stream was initiated by the client. Otherwise, + // PusherID names the stream that pushed the newly opened stream. + PusherID uint32 +} + +// FrameWriteRequest is a request to write a frame. +type FrameWriteRequest struct { // write is the interface value that does the writing, once the - // writeScheduler (below) has decided to select this frame - // to write. The write functions are all defined in write.go. + // WriteScheduler has selected this frame to write. The write + // functions are all defined in write.go. write writeFramer - stream *stream // used for prioritization. nil for non-stream frames. + // stream is the stream on which this frame will be written. + // nil for non-stream frames like PING and SETTINGS. + stream *stream // done, if non-nil, must be a buffered channel with space for // 1 message and is sent the return value from write (or an @@ -21,263 +60,169 @@ type frameWriteMsg struct { done chan error } -// for debugging only: -func (wm frameWriteMsg) String() string { - var streamID uint32 - if wm.stream != nil { - streamID = wm.stream.id - } - var des string - if s, ok := wm.write.(fmt.Stringer); ok { - des = s.String() - } else { - des = fmt.Sprintf("%T", wm.write) - } - return fmt.Sprintf("[frameWriteMsg stream=%d, ch=%v, type: %v]", streamID, wm.done != nil, des) -} - -// writeScheduler tracks pending frames to write, priorities, and decides -// the next one to use. It is not thread-safe. -type writeScheduler struct { - // zero are frames not associated with a specific stream. - // They're sent before any stream-specific freams. - zero writeQueue - - // maxFrameSize is the maximum size of a DATA frame - // we'll write. Must be non-zero and between 16K-16M. - maxFrameSize uint32 - - // sq contains the stream-specific queues, keyed by stream ID. - // when a stream is idle, it's deleted from the map. - sq map[uint32]*writeQueue - - // canSend is a slice of memory that's reused between frame - // scheduling decisions to hold the list of writeQueues (from sq) - // which have enough flow control data to send. After canSend is - // built, the best is selected. - canSend []*writeQueue - - // pool of empty queues for reuse. - queuePool []*writeQueue -} - -func (ws *writeScheduler) putEmptyQueue(q *writeQueue) { - if len(q.s) != 0 { - panic("queue must be empty") - } - ws.queuePool = append(ws.queuePool, q) -} - -func (ws *writeScheduler) getEmptyQueue() *writeQueue { - ln := len(ws.queuePool) - if ln == 0 { - return new(writeQueue) - } - q := ws.queuePool[ln-1] - ws.queuePool = ws.queuePool[:ln-1] - return q -} - -func (ws *writeScheduler) empty() bool { return ws.zero.empty() && len(ws.sq) == 0 } - -func (ws *writeScheduler) add(wm frameWriteMsg) { - st := wm.stream - if st == nil { - ws.zero.push(wm) - } else { - ws.streamQueue(st.id).push(wm) - } -} - -func (ws *writeScheduler) streamQueue(streamID uint32) *writeQueue { - if q, ok := ws.sq[streamID]; ok { - return q - } - if ws.sq == nil { - ws.sq = make(map[uint32]*writeQueue) - } - q := ws.getEmptyQueue() - ws.sq[streamID] = q - return q -} - -// take returns the most important frame to write and removes it from the scheduler. -// It is illegal to call this if the scheduler is empty or if there are no connection-level -// flow control bytes available. -func (ws *writeScheduler) take() (wm frameWriteMsg, ok bool) { - if ws.maxFrameSize == 0 { - panic("internal error: ws.maxFrameSize not initialized or invalid") - } - - // If there any frames not associated with streams, prefer those first. - // These are usually SETTINGS, etc. - if !ws.zero.empty() { - return ws.zero.shift(), true - } - if len(ws.sq) == 0 { - return - } - - // Next, prioritize frames on streams that aren't DATA frames (no cost). - for id, q := range ws.sq { - if q.firstIsNoCost() { - return ws.takeFrom(id, q) +// StreamID returns the id of the stream this frame will be written to. +// 0 is used for non-stream frames such as PING and SETTINGS. +func (wr FrameWriteRequest) StreamID() uint32 { + if wr.stream == nil { + if se, ok := wr.write.(StreamError); ok { + // (*serverConn).resetStream doesn't set + // stream because it doesn't necessarily have + // one. So special case this type of write + // message. + return se.StreamID } - } - - // Now, all that remains are DATA frames with non-zero bytes to - // send. So pick the best one. - if len(ws.canSend) != 0 { - panic("should be empty") - } - for _, q := range ws.sq { - if n := ws.streamWritableBytes(q); n > 0 { - ws.canSend = append(ws.canSend, q) - } - } - if len(ws.canSend) == 0 { - return - } - defer ws.zeroCanSend() - - // TODO: find the best queue - q := ws.canSend[0] - - return ws.takeFrom(q.streamID(), q) -} - -// zeroCanSend is defered from take. -func (ws *writeScheduler) zeroCanSend() { - for i := range ws.canSend { - ws.canSend[i] = nil - } - ws.canSend = ws.canSend[:0] -} - -// streamWritableBytes returns the number of DATA bytes we could write -// from the given queue's stream, if this stream/queue were -// selected. It is an error to call this if q's head isn't a -// *writeData. -func (ws *writeScheduler) streamWritableBytes(q *writeQueue) int32 { - wm := q.head() - ret := wm.stream.flow.available() // max we can write - if ret == 0 { return 0 } - if int32(ws.maxFrameSize) < ret { - ret = int32(ws.maxFrameSize) - } - if ret == 0 { - panic("internal error: ws.maxFrameSize not initialized or invalid") - } - wd := wm.write.(*writeData) - if len(wd.p) < int(ret) { - ret = int32(len(wd.p)) - } - return ret -} - -func (ws *writeScheduler) takeFrom(id uint32, q *writeQueue) (wm frameWriteMsg, ok bool) { - wm = q.head() - // If the first item in this queue costs flow control tokens - // and we don't have enough, write as much as we can. - if wd, ok := wm.write.(*writeData); ok && len(wd.p) > 0 { - allowed := wm.stream.flow.available() // max we can write - if allowed == 0 { - // No quota available. Caller can try the next stream. - return frameWriteMsg{}, false + return wr.stream.id +} + +// DataSize returns the number of flow control bytes that must be consumed +// to write this entire frame. This is 0 for non-DATA frames. +func (wr FrameWriteRequest) DataSize() int { + if wd, ok := wr.write.(*writeData); ok { + return len(wd.p) + } + return 0 +} + +// Consume consumes min(n, available) bytes from this frame, where available +// is the number of flow control bytes available on the stream. Consume returns +// 0, 1, or 2 frames, where the integer return value gives the number of frames +// returned. +// +// If flow control prevents consuming any bytes, this returns (_, _, 0). If +// the entire frame was consumed, this returns (wr, _, 1). Otherwise, this +// returns (consumed, rest, 2), where 'consumed' contains the consumed bytes and +// 'rest' contains the remaining bytes. The consumed bytes are deducted from the +// underlying stream's flow control budget. +func (wr FrameWriteRequest) Consume(n int32) (FrameWriteRequest, FrameWriteRequest, int) { + var empty FrameWriteRequest + + // Non-DATA frames are always consumed whole. + wd, ok := wr.write.(*writeData) + if !ok || len(wd.p) == 0 { + return wr, empty, 1 + } + + // Might need to split after applying limits. + allowed := wr.stream.flow.available() + if n < allowed { + allowed = n + } + if wr.stream.sc.maxFrameSize < allowed { + allowed = wr.stream.sc.maxFrameSize + } + if allowed <= 0 { + return empty, empty, 0 + } + if len(wd.p) > int(allowed) { + wr.stream.flow.take(allowed) + consumed := FrameWriteRequest{ + stream: wr.stream, + write: &writeData{ + streamID: wd.streamID, + p: wd.p[:allowed], + // Even if the original had endStream set, there + // are bytes remaining because len(wd.p) > allowed, + // so we know endStream is false. + endStream: false, + }, + // Our caller is blocking on the final DATA frame, not + // this intermediate frame, so no need to wait. + done: nil, } - if int32(ws.maxFrameSize) < allowed { - allowed = int32(ws.maxFrameSize) + rest := FrameWriteRequest{ + stream: wr.stream, + write: &writeData{ + streamID: wd.streamID, + p: wd.p[allowed:], + endStream: wd.endStream, + }, + done: wr.done, } - // TODO: further restrict the allowed size, because even if - // the peer says it's okay to write 16MB data frames, we might - // want to write smaller ones to properly weight competing - // streams' priorities. - - if len(wd.p) > int(allowed) { - wm.stream.flow.take(allowed) - chunk := wd.p[:allowed] - wd.p = wd.p[allowed:] - // Make up a new write message of a valid size, rather - // than shifting one off the queue. - return frameWriteMsg{ - stream: wm.stream, - write: &writeData{ - streamID: wd.streamID, - p: chunk, - // even if the original had endStream set, there - // arebytes remaining because len(wd.p) > allowed, - // so we know endStream is false: - endStream: false, - }, - // our caller is blocking on the final DATA frame, not - // these intermediates, so no need to wait: - done: nil, - }, true - } - wm.stream.flow.take(int32(len(wd.p))) + return consumed, rest, 2 } - q.shift() - if q.empty() { - ws.putEmptyQueue(q) - delete(ws.sq, id) - } - return wm, true + // The frame is consumed whole. + // NB: This cast cannot overflow because allowed is <= math.MaxInt32. + wr.stream.flow.take(int32(len(wd.p))) + return wr, empty, 1 } -func (ws *writeScheduler) forgetStream(id uint32) { - q, ok := ws.sq[id] - if !ok { - return - } - delete(ws.sq, id) - - // But keep it for others later. - for i := range q.s { - q.s[i] = frameWriteMsg{} +// String is for debugging only. +func (wr FrameWriteRequest) String() string { + var des string + if s, ok := wr.write.(fmt.Stringer); ok { + des = s.String() + } else { + des = fmt.Sprintf("%T", wr.write) } - q.s = q.s[:0] - ws.putEmptyQueue(q) + return fmt.Sprintf("[FrameWriteRequest stream=%d, ch=%v, writer=%v]", wr.StreamID(), wr.done != nil, des) } +// writeQueue is used by implementations of WriteScheduler. type writeQueue struct { - s []frameWriteMsg + s []FrameWriteRequest } -// streamID returns the stream ID for a non-empty stream-specific queue. -func (q *writeQueue) streamID() uint32 { return q.s[0].stream.id } - func (q *writeQueue) empty() bool { return len(q.s) == 0 } -func (q *writeQueue) push(wm frameWriteMsg) { - q.s = append(q.s, wm) +func (q *writeQueue) push(wr FrameWriteRequest) { + q.s = append(q.s, wr) } -// head returns the next item that would be removed by shift. -func (q *writeQueue) head() frameWriteMsg { +func (q *writeQueue) shift() FrameWriteRequest { if len(q.s) == 0 { panic("invalid use of queue") } - return q.s[0] + wr := q.s[0] + // TODO: less copy-happy queue. + copy(q.s, q.s[1:]) + q.s[len(q.s)-1] = FrameWriteRequest{} + q.s = q.s[:len(q.s)-1] + return wr } -func (q *writeQueue) shift() frameWriteMsg { +// consume consumes up to n bytes from q.s[0]. If the frame is +// entirely consumed, it is removed from the queue. If the frame +// is partially consumed, the frame is kept with the consumed +// bytes removed. Returns true iff any bytes were consumed. +func (q *writeQueue) consume(n int32) (FrameWriteRequest, bool) { if len(q.s) == 0 { - panic("invalid use of queue") + return FrameWriteRequest{}, false } - wm := q.s[0] - // TODO: less copy-happy queue. - copy(q.s, q.s[1:]) - q.s[len(q.s)-1] = frameWriteMsg{} - q.s = q.s[:len(q.s)-1] - return wm + consumed, rest, numresult := q.s[0].Consume(n) + switch numresult { + case 0: + return FrameWriteRequest{}, false + case 1: + q.shift() + case 2: + q.s[0] = rest + } + return consumed, true +} + +type writeQueuePool []*writeQueue + +// put inserts an unused writeQueue into the pool. +func (p *writeQueuePool) put(q *writeQueue) { + for i := range q.s { + q.s[i] = FrameWriteRequest{} + } + q.s = q.s[:0] + *p = append(*p, q) } -func (q *writeQueue) firstIsNoCost() bool { - if df, ok := q.s[0].write.(*writeData); ok { - return len(df.p) == 0 +// get returns an empty writeQueue. +func (p *writeQueuePool) get() *writeQueue { + ln := len(*p) + if ln == 0 { + return new(writeQueue) } - return true + x := ln - 1 + q := (*p)[x] + (*p)[x] = nil + *p = (*p)[:x] + return q } diff --git a/vendor/golang.org/x/net/http2/writesched_priority.go b/vendor/golang.org/x/net/http2/writesched_priority.go new file mode 100644 index 000000000..01132721b --- /dev/null +++ b/vendor/golang.org/x/net/http2/writesched_priority.go @@ -0,0 +1,452 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http2 + +import ( + "fmt" + "math" + "sort" +) + +// RFC 7540, Section 5.3.5: the default weight is 16. +const priorityDefaultWeight = 15 // 16 = 15 + 1 + +// PriorityWriteSchedulerConfig configures a priorityWriteScheduler. +type PriorityWriteSchedulerConfig struct { + // MaxClosedNodesInTree controls the maximum number of closed streams to + // retain in the priority tree. Setting this to zero saves a small amount + // of memory at the cost of performance. + // + // See RFC 7540, Section 5.3.4: + // "It is possible for a stream to become closed while prioritization + // information ... is in transit. ... This potentially creates suboptimal + // prioritization, since the stream could be given a priority that is + // different from what is intended. To avoid these problems, an endpoint + // SHOULD retain stream prioritization state for a period after streams + // become closed. The longer state is retained, the lower the chance that + // streams are assigned incorrect or default priority values." + MaxClosedNodesInTree int + + // MaxIdleNodesInTree controls the maximum number of idle streams to + // retain in the priority tree. Setting this to zero saves a small amount + // of memory at the cost of performance. + // + // See RFC 7540, Section 5.3.4: + // Similarly, streams that are in the "idle" state can be assigned + // priority or become a parent of other streams. This allows for the + // creation of a grouping node in the dependency tree, which enables + // more flexible expressions of priority. Idle streams begin with a + // default priority (Section 5.3.5). + MaxIdleNodesInTree int + + // ThrottleOutOfOrderWrites enables write throttling to help ensure that + // data is delivered in priority order. This works around a race where + // stream B depends on stream A and both streams are about to call Write + // to queue DATA frames. If B wins the race, a naive scheduler would eagerly + // write as much data from B as possible, but this is suboptimal because A + // is a higher-priority stream. With throttling enabled, we write a small + // amount of data from B to minimize the amount of bandwidth that B can + // steal from A. + ThrottleOutOfOrderWrites bool +} + +// NewPriorityWriteScheduler constructs a WriteScheduler that schedules +// frames by following HTTP/2 priorities as described in RFC 7340 Section 5.3. +// If cfg is nil, default options are used. +func NewPriorityWriteScheduler(cfg *PriorityWriteSchedulerConfig) WriteScheduler { + if cfg == nil { + // For justification of these defaults, see: + // https://docs.google.com/document/d/1oLhNg1skaWD4_DtaoCxdSRN5erEXrH-KnLrMwEpOtFY + cfg = &PriorityWriteSchedulerConfig{ + MaxClosedNodesInTree: 10, + MaxIdleNodesInTree: 10, + ThrottleOutOfOrderWrites: false, + } + } + + ws := &priorityWriteScheduler{ + nodes: make(map[uint32]*priorityNode), + maxClosedNodesInTree: cfg.MaxClosedNodesInTree, + maxIdleNodesInTree: cfg.MaxIdleNodesInTree, + enableWriteThrottle: cfg.ThrottleOutOfOrderWrites, + } + ws.nodes[0] = &ws.root + if cfg.ThrottleOutOfOrderWrites { + ws.writeThrottleLimit = 1024 + } else { + ws.writeThrottleLimit = math.MaxInt32 + } + return ws +} + +type priorityNodeState int + +const ( + priorityNodeOpen priorityNodeState = iota + priorityNodeClosed + priorityNodeIdle +) + +// priorityNode is a node in an HTTP/2 priority tree. +// Each node is associated with a single stream ID. +// See RFC 7540, Section 5.3. +type priorityNode struct { + q writeQueue // queue of pending frames to write + id uint32 // id of the stream, or 0 for the root of the tree + weight uint8 // the actual weight is weight+1, so the value is in [1,256] + state priorityNodeState // open | closed | idle + bytes int64 // number of bytes written by this node, or 0 if closed + subtreeBytes int64 // sum(node.bytes) of all nodes in this subtree + + // These links form the priority tree. + parent *priorityNode + kids *priorityNode // start of the kids list + prev, next *priorityNode // doubly-linked list of siblings +} + +func (n *priorityNode) setParent(parent *priorityNode) { + if n == parent { + panic("setParent to self") + } + if n.parent == parent { + return + } + // Unlink from current parent. + if parent := n.parent; parent != nil { + if n.prev == nil { + parent.kids = n.next + } else { + n.prev.next = n.next + } + if n.next != nil { + n.next.prev = n.prev + } + } + // Link to new parent. + // If parent=nil, remove n from the tree. + // Always insert at the head of parent.kids (this is assumed by walkReadyInOrder). + n.parent = parent + if parent == nil { + n.next = nil + n.prev = nil + } else { + n.next = parent.kids + n.prev = nil + if n.next != nil { + n.next.prev = n + } + parent.kids = n + } +} + +func (n *priorityNode) addBytes(b int64) { + n.bytes += b + for ; n != nil; n = n.parent { + n.subtreeBytes += b + } +} + +// walkReadyInOrder iterates over the tree in priority order, calling f for each node +// with a non-empty write queue. When f returns true, this funcion returns true and the +// walk halts. tmp is used as scratch space for sorting. +// +// f(n, openParent) takes two arguments: the node to visit, n, and a bool that is true +// if any ancestor p of n is still open (ignoring the root node). +func (n *priorityNode) walkReadyInOrder(openParent bool, tmp *[]*priorityNode, f func(*priorityNode, bool) bool) bool { + if !n.q.empty() && f(n, openParent) { + return true + } + if n.kids == nil { + return false + } + + // Don't consider the root "open" when updating openParent since + // we can't send data frames on the root stream (only control frames). + if n.id != 0 { + openParent = openParent || (n.state == priorityNodeOpen) + } + + // Common case: only one kid or all kids have the same weight. + // Some clients don't use weights; other clients (like web browsers) + // use mostly-linear priority trees. + w := n.kids.weight + needSort := false + for k := n.kids.next; k != nil; k = k.next { + if k.weight != w { + needSort = true + break + } + } + if !needSort { + for k := n.kids; k != nil; k = k.next { + if k.walkReadyInOrder(openParent, tmp, f) { + return true + } + } + return false + } + + // Uncommon case: sort the child nodes. We remove the kids from the parent, + // then re-insert after sorting so we can reuse tmp for future sort calls. + *tmp = (*tmp)[:0] + for n.kids != nil { + *tmp = append(*tmp, n.kids) + n.kids.setParent(nil) + } + sort.Sort(sortPriorityNodeSiblings(*tmp)) + for i := len(*tmp) - 1; i >= 0; i-- { + (*tmp)[i].setParent(n) // setParent inserts at the head of n.kids + } + for k := n.kids; k != nil; k = k.next { + if k.walkReadyInOrder(openParent, tmp, f) { + return true + } + } + return false +} + +type sortPriorityNodeSiblings []*priorityNode + +func (z sortPriorityNodeSiblings) Len() int { return len(z) } +func (z sortPriorityNodeSiblings) Swap(i, k int) { z[i], z[k] = z[k], z[i] } +func (z sortPriorityNodeSiblings) Less(i, k int) bool { + // Prefer the subtree that has sent fewer bytes relative to its weight. + // See sections 5.3.2 and 5.3.4. + wi, bi := float64(z[i].weight+1), float64(z[i].subtreeBytes) + wk, bk := float64(z[k].weight+1), float64(z[k].subtreeBytes) + if bi == 0 && bk == 0 { + return wi >= wk + } + if bk == 0 { + return false + } + return bi/bk <= wi/wk +} + +type priorityWriteScheduler struct { + // root is the root of the priority tree, where root.id = 0. + // The root queues control frames that are not associated with any stream. + root priorityNode + + // nodes maps stream ids to priority tree nodes. + nodes map[uint32]*priorityNode + + // maxID is the maximum stream id in nodes. + maxID uint32 + + // lists of nodes that have been closed or are idle, but are kept in + // the tree for improved prioritization. When the lengths exceed either + // maxClosedNodesInTree or maxIdleNodesInTree, old nodes are discarded. + closedNodes, idleNodes []*priorityNode + + // From the config. + maxClosedNodesInTree int + maxIdleNodesInTree int + writeThrottleLimit int32 + enableWriteThrottle bool + + // tmp is scratch space for priorityNode.walkReadyInOrder to reduce allocations. + tmp []*priorityNode + + // pool of empty queues for reuse. + queuePool writeQueuePool +} + +func (ws *priorityWriteScheduler) OpenStream(streamID uint32, options OpenStreamOptions) { + // The stream may be currently idle but cannot be opened or closed. + if curr := ws.nodes[streamID]; curr != nil { + if curr.state != priorityNodeIdle { + panic(fmt.Sprintf("stream %d already opened", streamID)) + } + curr.state = priorityNodeOpen + return + } + + // RFC 7540, Section 5.3.5: + // "All streams are initially assigned a non-exclusive dependency on stream 0x0. + // Pushed streams initially depend on their associated stream. In both cases, + // streams are assigned a default weight of 16." + parent := ws.nodes[options.PusherID] + if parent == nil { + parent = &ws.root + } + n := &priorityNode{ + q: *ws.queuePool.get(), + id: streamID, + weight: priorityDefaultWeight, + state: priorityNodeOpen, + } + n.setParent(parent) + ws.nodes[streamID] = n + if streamID > ws.maxID { + ws.maxID = streamID + } +} + +func (ws *priorityWriteScheduler) CloseStream(streamID uint32) { + if streamID == 0 { + panic("violation of WriteScheduler interface: cannot close stream 0") + } + if ws.nodes[streamID] == nil { + panic(fmt.Sprintf("violation of WriteScheduler interface: unknown stream %d", streamID)) + } + if ws.nodes[streamID].state != priorityNodeOpen { + panic(fmt.Sprintf("violation of WriteScheduler interface: stream %d already closed", streamID)) + } + + n := ws.nodes[streamID] + n.state = priorityNodeClosed + n.addBytes(-n.bytes) + + q := n.q + ws.queuePool.put(&q) + n.q.s = nil + if ws.maxClosedNodesInTree > 0 { + ws.addClosedOrIdleNode(&ws.closedNodes, ws.maxClosedNodesInTree, n) + } else { + ws.removeNode(n) + } +} + +func (ws *priorityWriteScheduler) AdjustStream(streamID uint32, priority PriorityParam) { + if streamID == 0 { + panic("adjustPriority on root") + } + + // If streamID does not exist, there are two cases: + // - A closed stream that has been removed (this will have ID <= maxID) + // - An idle stream that is being used for "grouping" (this will have ID > maxID) + n := ws.nodes[streamID] + if n == nil { + if streamID <= ws.maxID || ws.maxIdleNodesInTree == 0 { + return + } + ws.maxID = streamID + n = &priorityNode{ + q: *ws.queuePool.get(), + id: streamID, + weight: priorityDefaultWeight, + state: priorityNodeIdle, + } + n.setParent(&ws.root) + ws.nodes[streamID] = n + ws.addClosedOrIdleNode(&ws.idleNodes, ws.maxIdleNodesInTree, n) + } + + // Section 5.3.1: A dependency on a stream that is not currently in the tree + // results in that stream being given a default priority (Section 5.3.5). + parent := ws.nodes[priority.StreamDep] + if parent == nil { + n.setParent(&ws.root) + n.weight = priorityDefaultWeight + return + } + + // Ignore if the client tries to make a node its own parent. + if n == parent { + return + } + + // Section 5.3.3: + // "If a stream is made dependent on one of its own dependencies, the + // formerly dependent stream is first moved to be dependent on the + // reprioritized stream's previous parent. The moved dependency retains + // its weight." + // + // That is: if parent depends on n, move parent to depend on n.parent. + for x := parent.parent; x != nil; x = x.parent { + if x == n { + parent.setParent(n.parent) + break + } + } + + // Section 5.3.3: The exclusive flag causes the stream to become the sole + // dependency of its parent stream, causing other dependencies to become + // dependent on the exclusive stream. + if priority.Exclusive { + k := parent.kids + for k != nil { + next := k.next + if k != n { + k.setParent(n) + } + k = next + } + } + + n.setParent(parent) + n.weight = priority.Weight +} + +func (ws *priorityWriteScheduler) Push(wr FrameWriteRequest) { + var n *priorityNode + if id := wr.StreamID(); id == 0 { + n = &ws.root + } else { + n = ws.nodes[id] + if n == nil { + // id is an idle or closed stream. wr should not be a HEADERS or + // DATA frame. However, wr can be a RST_STREAM. In this case, we + // push wr onto the root, rather than creating a new priorityNode, + // since RST_STREAM is tiny and the stream's priority is unknown + // anyway. See issue #17919. + if wr.DataSize() > 0 { + panic("add DATA on non-open stream") + } + n = &ws.root + } + } + n.q.push(wr) +} + +func (ws *priorityWriteScheduler) Pop() (wr FrameWriteRequest, ok bool) { + ws.root.walkReadyInOrder(false, &ws.tmp, func(n *priorityNode, openParent bool) bool { + limit := int32(math.MaxInt32) + if openParent { + limit = ws.writeThrottleLimit + } + wr, ok = n.q.consume(limit) + if !ok { + return false + } + n.addBytes(int64(wr.DataSize())) + // If B depends on A and B continuously has data available but A + // does not, gradually increase the throttling limit to allow B to + // steal more and more bandwidth from A. + if openParent { + ws.writeThrottleLimit += 1024 + if ws.writeThrottleLimit < 0 { + ws.writeThrottleLimit = math.MaxInt32 + } + } else if ws.enableWriteThrottle { + ws.writeThrottleLimit = 1024 + } + return true + }) + return wr, ok +} + +func (ws *priorityWriteScheduler) addClosedOrIdleNode(list *[]*priorityNode, maxSize int, n *priorityNode) { + if maxSize == 0 { + return + } + if len(*list) == maxSize { + // Remove the oldest node, then shift left. + ws.removeNode((*list)[0]) + x := (*list)[1:] + copy(*list, x) + *list = (*list)[:len(x)] + } + *list = append(*list, n) +} + +func (ws *priorityWriteScheduler) removeNode(n *priorityNode) { + for k := n.kids; k != nil; k = k.next { + k.setParent(n.parent) + } + n.setParent(nil) + delete(ws.nodes, n.id) +} diff --git a/vendor/golang.org/x/net/http2/writesched_priority_test.go b/vendor/golang.org/x/net/http2/writesched_priority_test.go new file mode 100644 index 000000000..2b232043c --- /dev/null +++ b/vendor/golang.org/x/net/http2/writesched_priority_test.go @@ -0,0 +1,541 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http2 + +import ( + "bytes" + "fmt" + "sort" + "testing" +) + +func defaultPriorityWriteScheduler() *priorityWriteScheduler { + return NewPriorityWriteScheduler(nil).(*priorityWriteScheduler) +} + +func checkPriorityWellFormed(ws *priorityWriteScheduler) error { + for id, n := range ws.nodes { + if id != n.id { + return fmt.Errorf("bad ws.nodes: ws.nodes[%d] = %d", id, n.id) + } + if n.parent == nil { + if n.next != nil || n.prev != nil { + return fmt.Errorf("bad node %d: nil parent but prev/next not nil", id) + } + continue + } + found := false + for k := n.parent.kids; k != nil; k = k.next { + if k.id == id { + found = true + break + } + } + if !found { + return fmt.Errorf("bad node %d: not found in parent %d kids list", id, n.parent.id) + } + } + return nil +} + +func fmtTree(ws *priorityWriteScheduler, fmtNode func(*priorityNode) string) string { + var ids []int + for _, n := range ws.nodes { + ids = append(ids, int(n.id)) + } + sort.Ints(ids) + + var buf bytes.Buffer + for _, id := range ids { + if buf.Len() != 0 { + buf.WriteString(" ") + } + if id == 0 { + buf.WriteString(fmtNode(&ws.root)) + } else { + buf.WriteString(fmtNode(ws.nodes[uint32(id)])) + } + } + return buf.String() +} + +func fmtNodeParentSkipRoot(n *priorityNode) string { + switch { + case n.id == 0: + return "" + case n.parent == nil: + return fmt.Sprintf("%d{parent:nil}", n.id) + default: + return fmt.Sprintf("%d{parent:%d}", n.id, n.parent.id) + } +} + +func fmtNodeWeightParentSkipRoot(n *priorityNode) string { + switch { + case n.id == 0: + return "" + case n.parent == nil: + return fmt.Sprintf("%d{weight:%d,parent:nil}", n.id, n.weight) + default: + return fmt.Sprintf("%d{weight:%d,parent:%d}", n.id, n.weight, n.parent.id) + } +} + +func TestPriorityTwoStreams(t *testing.T) { + ws := defaultPriorityWriteScheduler() + ws.OpenStream(1, OpenStreamOptions{}) + ws.OpenStream(2, OpenStreamOptions{}) + + want := "1{weight:15,parent:0} 2{weight:15,parent:0}" + if got := fmtTree(ws, fmtNodeWeightParentSkipRoot); got != want { + t.Errorf("After open\ngot %q\nwant %q", got, want) + } + + // Move 1's parent to 2. + ws.AdjustStream(1, PriorityParam{ + StreamDep: 2, + Weight: 32, + Exclusive: false, + }) + want = "1{weight:32,parent:2} 2{weight:15,parent:0}" + if got := fmtTree(ws, fmtNodeWeightParentSkipRoot); got != want { + t.Errorf("After adjust\ngot %q\nwant %q", got, want) + } + + if err := checkPriorityWellFormed(ws); err != nil { + t.Error(err) + } +} + +func TestPriorityAdjustExclusiveZero(t *testing.T) { + // 1, 2, and 3 are all children of the 0 stream. + // Exclusive reprioritization to any of the streams should bring + // the rest of the streams under the reprioritized stream. + ws := defaultPriorityWriteScheduler() + ws.OpenStream(1, OpenStreamOptions{}) + ws.OpenStream(2, OpenStreamOptions{}) + ws.OpenStream(3, OpenStreamOptions{}) + + want := "1{weight:15,parent:0} 2{weight:15,parent:0} 3{weight:15,parent:0}" + if got := fmtTree(ws, fmtNodeWeightParentSkipRoot); got != want { + t.Errorf("After open\ngot %q\nwant %q", got, want) + } + + ws.AdjustStream(2, PriorityParam{ + StreamDep: 0, + Weight: 20, + Exclusive: true, + }) + want = "1{weight:15,parent:2} 2{weight:20,parent:0} 3{weight:15,parent:2}" + if got := fmtTree(ws, fmtNodeWeightParentSkipRoot); got != want { + t.Errorf("After adjust\ngot %q\nwant %q", got, want) + } + + if err := checkPriorityWellFormed(ws); err != nil { + t.Error(err) + } +} + +func TestPriorityAdjustOwnParent(t *testing.T) { + // Assigning a node as its own parent should have no effect. + ws := defaultPriorityWriteScheduler() + ws.OpenStream(1, OpenStreamOptions{}) + ws.OpenStream(2, OpenStreamOptions{}) + ws.AdjustStream(2, PriorityParam{ + StreamDep: 2, + Weight: 20, + Exclusive: true, + }) + want := "1{weight:15,parent:0} 2{weight:15,parent:0}" + if got := fmtTree(ws, fmtNodeWeightParentSkipRoot); got != want { + t.Errorf("After adjust\ngot %q\nwant %q", got, want) + } + if err := checkPriorityWellFormed(ws); err != nil { + t.Error(err) + } +} + +func TestPriorityClosedStreams(t *testing.T) { + ws := NewPriorityWriteScheduler(&PriorityWriteSchedulerConfig{MaxClosedNodesInTree: 2}).(*priorityWriteScheduler) + ws.OpenStream(1, OpenStreamOptions{}) + ws.OpenStream(2, OpenStreamOptions{PusherID: 1}) + ws.OpenStream(3, OpenStreamOptions{PusherID: 2}) + ws.OpenStream(4, OpenStreamOptions{PusherID: 3}) + + // Close the first three streams. We lose 1, but keep 2 and 3. + ws.CloseStream(1) + ws.CloseStream(2) + ws.CloseStream(3) + + want := "2{weight:15,parent:0} 3{weight:15,parent:2} 4{weight:15,parent:3}" + if got := fmtTree(ws, fmtNodeWeightParentSkipRoot); got != want { + t.Errorf("After close\ngot %q\nwant %q", got, want) + } + if err := checkPriorityWellFormed(ws); err != nil { + t.Error(err) + } + + // Adding a stream as an exclusive child of 1 gives it default + // priorities, since 1 is gone. + ws.OpenStream(5, OpenStreamOptions{}) + ws.AdjustStream(5, PriorityParam{StreamDep: 1, Weight: 15, Exclusive: true}) + + // Adding a stream as an exclusive child of 2 should work, since 2 is not gone. + ws.OpenStream(6, OpenStreamOptions{}) + ws.AdjustStream(6, PriorityParam{StreamDep: 2, Weight: 15, Exclusive: true}) + + want = "2{weight:15,parent:0} 3{weight:15,parent:6} 4{weight:15,parent:3} 5{weight:15,parent:0} 6{weight:15,parent:2}" + if got := fmtTree(ws, fmtNodeWeightParentSkipRoot); got != want { + t.Errorf("After add streams\ngot %q\nwant %q", got, want) + } + if err := checkPriorityWellFormed(ws); err != nil { + t.Error(err) + } +} + +func TestPriorityClosedStreamsDisabled(t *testing.T) { + ws := NewPriorityWriteScheduler(&PriorityWriteSchedulerConfig{}).(*priorityWriteScheduler) + ws.OpenStream(1, OpenStreamOptions{}) + ws.OpenStream(2, OpenStreamOptions{PusherID: 1}) + ws.OpenStream(3, OpenStreamOptions{PusherID: 2}) + + // Close the first two streams. We keep only 3. + ws.CloseStream(1) + ws.CloseStream(2) + + want := "3{weight:15,parent:0}" + if got := fmtTree(ws, fmtNodeWeightParentSkipRoot); got != want { + t.Errorf("After close\ngot %q\nwant %q", got, want) + } + if err := checkPriorityWellFormed(ws); err != nil { + t.Error(err) + } +} + +func TestPriorityIdleStreams(t *testing.T) { + ws := NewPriorityWriteScheduler(&PriorityWriteSchedulerConfig{MaxIdleNodesInTree: 2}).(*priorityWriteScheduler) + ws.AdjustStream(1, PriorityParam{StreamDep: 0, Weight: 15}) // idle + ws.AdjustStream(2, PriorityParam{StreamDep: 0, Weight: 15}) // idle + ws.AdjustStream(3, PriorityParam{StreamDep: 2, Weight: 20}) // idle + ws.OpenStream(4, OpenStreamOptions{}) + ws.OpenStream(5, OpenStreamOptions{}) + ws.OpenStream(6, OpenStreamOptions{}) + ws.AdjustStream(4, PriorityParam{StreamDep: 1, Weight: 15}) + ws.AdjustStream(5, PriorityParam{StreamDep: 2, Weight: 15}) + ws.AdjustStream(6, PriorityParam{StreamDep: 3, Weight: 15}) + + want := "2{weight:15,parent:0} 3{weight:20,parent:2} 4{weight:15,parent:0} 5{weight:15,parent:2} 6{weight:15,parent:3}" + if got := fmtTree(ws, fmtNodeWeightParentSkipRoot); got != want { + t.Errorf("After open\ngot %q\nwant %q", got, want) + } + if err := checkPriorityWellFormed(ws); err != nil { + t.Error(err) + } +} + +func TestPriorityIdleStreamsDisabled(t *testing.T) { + ws := NewPriorityWriteScheduler(&PriorityWriteSchedulerConfig{}).(*priorityWriteScheduler) + ws.AdjustStream(1, PriorityParam{StreamDep: 0, Weight: 15}) // idle + ws.AdjustStream(2, PriorityParam{StreamDep: 0, Weight: 15}) // idle + ws.AdjustStream(3, PriorityParam{StreamDep: 2, Weight: 20}) // idle + ws.OpenStream(4, OpenStreamOptions{}) + + want := "4{weight:15,parent:0}" + if got := fmtTree(ws, fmtNodeWeightParentSkipRoot); got != want { + t.Errorf("After open\ngot %q\nwant %q", got, want) + } + if err := checkPriorityWellFormed(ws); err != nil { + t.Error(err) + } +} + +func TestPrioritySection531NonExclusive(t *testing.T) { + // Example from RFC 7540 Section 5.3.1. + // A,B,C,D = 1,2,3,4 + ws := defaultPriorityWriteScheduler() + ws.OpenStream(1, OpenStreamOptions{}) + ws.OpenStream(2, OpenStreamOptions{PusherID: 1}) + ws.OpenStream(3, OpenStreamOptions{PusherID: 1}) + ws.OpenStream(4, OpenStreamOptions{}) + ws.AdjustStream(4, PriorityParam{ + StreamDep: 1, + Weight: 15, + Exclusive: false, + }) + want := "1{parent:0} 2{parent:1} 3{parent:1} 4{parent:1}" + if got := fmtTree(ws, fmtNodeParentSkipRoot); got != want { + t.Errorf("After adjust\ngot %q\nwant %q", got, want) + } + if err := checkPriorityWellFormed(ws); err != nil { + t.Error(err) + } +} + +func TestPrioritySection531Exclusive(t *testing.T) { + // Example from RFC 7540 Section 5.3.1. + // A,B,C,D = 1,2,3,4 + ws := defaultPriorityWriteScheduler() + ws.OpenStream(1, OpenStreamOptions{}) + ws.OpenStream(2, OpenStreamOptions{PusherID: 1}) + ws.OpenStream(3, OpenStreamOptions{PusherID: 1}) + ws.OpenStream(4, OpenStreamOptions{}) + ws.AdjustStream(4, PriorityParam{ + StreamDep: 1, + Weight: 15, + Exclusive: true, + }) + want := "1{parent:0} 2{parent:4} 3{parent:4} 4{parent:1}" + if got := fmtTree(ws, fmtNodeParentSkipRoot); got != want { + t.Errorf("After adjust\ngot %q\nwant %q", got, want) + } + if err := checkPriorityWellFormed(ws); err != nil { + t.Error(err) + } +} + +func makeSection533Tree() *priorityWriteScheduler { + // Initial tree from RFC 7540 Section 5.3.3. + // A,B,C,D,E,F = 1,2,3,4,5,6 + ws := defaultPriorityWriteScheduler() + ws.OpenStream(1, OpenStreamOptions{}) + ws.OpenStream(2, OpenStreamOptions{PusherID: 1}) + ws.OpenStream(3, OpenStreamOptions{PusherID: 1}) + ws.OpenStream(4, OpenStreamOptions{PusherID: 3}) + ws.OpenStream(5, OpenStreamOptions{PusherID: 3}) + ws.OpenStream(6, OpenStreamOptions{PusherID: 4}) + return ws +} + +func TestPrioritySection533NonExclusive(t *testing.T) { + // Example from RFC 7540 Section 5.3.3. + // A,B,C,D,E,F = 1,2,3,4,5,6 + ws := defaultPriorityWriteScheduler() + ws.OpenStream(1, OpenStreamOptions{}) + ws.OpenStream(2, OpenStreamOptions{PusherID: 1}) + ws.OpenStream(3, OpenStreamOptions{PusherID: 1}) + ws.OpenStream(4, OpenStreamOptions{PusherID: 3}) + ws.OpenStream(5, OpenStreamOptions{PusherID: 3}) + ws.OpenStream(6, OpenStreamOptions{PusherID: 4}) + ws.AdjustStream(1, PriorityParam{ + StreamDep: 4, + Weight: 15, + Exclusive: false, + }) + want := "1{parent:4} 2{parent:1} 3{parent:1} 4{parent:0} 5{parent:3} 6{parent:4}" + if got := fmtTree(ws, fmtNodeParentSkipRoot); got != want { + t.Errorf("After adjust\ngot %q\nwant %q", got, want) + } + if err := checkPriorityWellFormed(ws); err != nil { + t.Error(err) + } +} + +func TestPrioritySection533Exclusive(t *testing.T) { + // Example from RFC 7540 Section 5.3.3. + // A,B,C,D,E,F = 1,2,3,4,5,6 + ws := defaultPriorityWriteScheduler() + ws.OpenStream(1, OpenStreamOptions{}) + ws.OpenStream(2, OpenStreamOptions{PusherID: 1}) + ws.OpenStream(3, OpenStreamOptions{PusherID: 1}) + ws.OpenStream(4, OpenStreamOptions{PusherID: 3}) + ws.OpenStream(5, OpenStreamOptions{PusherID: 3}) + ws.OpenStream(6, OpenStreamOptions{PusherID: 4}) + ws.AdjustStream(1, PriorityParam{ + StreamDep: 4, + Weight: 15, + Exclusive: true, + }) + want := "1{parent:4} 2{parent:1} 3{parent:1} 4{parent:0} 5{parent:3} 6{parent:1}" + if got := fmtTree(ws, fmtNodeParentSkipRoot); got != want { + t.Errorf("After adjust\ngot %q\nwant %q", got, want) + } + if err := checkPriorityWellFormed(ws); err != nil { + t.Error(err) + } +} + +func checkPopAll(ws WriteScheduler, order []uint32) error { + for k, id := range order { + wr, ok := ws.Pop() + if !ok { + return fmt.Errorf("Pop[%d]: got ok=false, want %d (order=%v)", k, id, order) + } + if got := wr.StreamID(); got != id { + return fmt.Errorf("Pop[%d]: got %v, want %d (order=%v)", k, got, id, order) + } + } + wr, ok := ws.Pop() + if ok { + return fmt.Errorf("Pop[%d]: got %v, want ok=false (order=%v)", len(order), wr.StreamID(), order) + } + return nil +} + +func TestPriorityPopFrom533Tree(t *testing.T) { + ws := makeSection533Tree() + + ws.Push(makeWriteHeadersRequest(3 /*C*/)) + ws.Push(makeWriteNonStreamRequest()) + ws.Push(makeWriteHeadersRequest(5 /*E*/)) + ws.Push(makeWriteHeadersRequest(1 /*A*/)) + t.Log("tree:", fmtTree(ws, fmtNodeParentSkipRoot)) + + if err := checkPopAll(ws, []uint32{0 /*NonStream*/, 1, 3, 5}); err != nil { + t.Error(err) + } +} + +func TestPriorityPopFromLinearTree(t *testing.T) { + ws := defaultPriorityWriteScheduler() + ws.OpenStream(1, OpenStreamOptions{}) + ws.OpenStream(2, OpenStreamOptions{PusherID: 1}) + ws.OpenStream(3, OpenStreamOptions{PusherID: 2}) + ws.OpenStream(4, OpenStreamOptions{PusherID: 3}) + + ws.Push(makeWriteHeadersRequest(3)) + ws.Push(makeWriteHeadersRequest(4)) + ws.Push(makeWriteHeadersRequest(1)) + ws.Push(makeWriteHeadersRequest(2)) + ws.Push(makeWriteNonStreamRequest()) + ws.Push(makeWriteNonStreamRequest()) + t.Log("tree:", fmtTree(ws, fmtNodeParentSkipRoot)) + + if err := checkPopAll(ws, []uint32{0, 0 /*NonStreams*/, 1, 2, 3, 4}); err != nil { + t.Error(err) + } +} + +func TestPriorityFlowControl(t *testing.T) { + ws := NewPriorityWriteScheduler(&PriorityWriteSchedulerConfig{ThrottleOutOfOrderWrites: false}) + ws.OpenStream(1, OpenStreamOptions{}) + ws.OpenStream(2, OpenStreamOptions{PusherID: 1}) + + sc := &serverConn{maxFrameSize: 16} + st1 := &stream{id: 1, sc: sc} + st2 := &stream{id: 2, sc: sc} + + ws.Push(FrameWriteRequest{&writeData{1, make([]byte, 16), false}, st1, nil}) + ws.Push(FrameWriteRequest{&writeData{2, make([]byte, 16), false}, st2, nil}) + ws.AdjustStream(2, PriorityParam{StreamDep: 1}) + + // No flow-control bytes available. + if wr, ok := ws.Pop(); ok { + t.Fatalf("Pop(limited by flow control)=%v,true, want false", wr) + } + + // Add enough flow-control bytes to write st2 in two Pop calls. + // Should write data from st2 even though it's lower priority than st1. + for i := 1; i <= 2; i++ { + st2.flow.add(8) + wr, ok := ws.Pop() + if !ok { + t.Fatalf("Pop(%d)=false, want true", i) + } + if got, want := wr.DataSize(), 8; got != want { + t.Fatalf("Pop(%d)=%d bytes, want %d bytes", got, want) + } + } +} + +func TestPriorityThrottleOutOfOrderWrites(t *testing.T) { + ws := NewPriorityWriteScheduler(&PriorityWriteSchedulerConfig{ThrottleOutOfOrderWrites: true}) + ws.OpenStream(1, OpenStreamOptions{}) + ws.OpenStream(2, OpenStreamOptions{PusherID: 1}) + + sc := &serverConn{maxFrameSize: 4096} + st1 := &stream{id: 1, sc: sc} + st2 := &stream{id: 2, sc: sc} + st1.flow.add(4096) + st2.flow.add(4096) + ws.Push(FrameWriteRequest{&writeData{2, make([]byte, 4096), false}, st2, nil}) + ws.AdjustStream(2, PriorityParam{StreamDep: 1}) + + // We have enough flow-control bytes to write st2 in a single Pop call. + // However, due to out-of-order write throttling, the first call should + // only write 1KB. + wr, ok := ws.Pop() + if !ok { + t.Fatalf("Pop(st2.first)=false, want true") + } + if got, want := wr.StreamID(), uint32(2); got != want { + t.Fatalf("Pop(st2.first)=stream %d, want stream %d", got, want) + } + if got, want := wr.DataSize(), 1024; got != want { + t.Fatalf("Pop(st2.first)=%d bytes, want %d bytes", got, want) + } + + // Now add data on st1. This should take precedence. + ws.Push(FrameWriteRequest{&writeData{1, make([]byte, 4096), false}, st1, nil}) + wr, ok = ws.Pop() + if !ok { + t.Fatalf("Pop(st1)=false, want true") + } + if got, want := wr.StreamID(), uint32(1); got != want { + t.Fatalf("Pop(st1)=stream %d, want stream %d", got, want) + } + if got, want := wr.DataSize(), 4096; got != want { + t.Fatalf("Pop(st1)=%d bytes, want %d bytes", got, want) + } + + // Should go back to writing 1KB from st2. + wr, ok = ws.Pop() + if !ok { + t.Fatalf("Pop(st2.last)=false, want true") + } + if got, want := wr.StreamID(), uint32(2); got != want { + t.Fatalf("Pop(st2.last)=stream %d, want stream %d", got, want) + } + if got, want := wr.DataSize(), 1024; got != want { + t.Fatalf("Pop(st2.last)=%d bytes, want %d bytes", got, want) + } +} + +func TestPriorityWeights(t *testing.T) { + ws := defaultPriorityWriteScheduler() + ws.OpenStream(1, OpenStreamOptions{}) + ws.OpenStream(2, OpenStreamOptions{}) + + sc := &serverConn{maxFrameSize: 8} + st1 := &stream{id: 1, sc: sc} + st2 := &stream{id: 2, sc: sc} + st1.flow.add(40) + st2.flow.add(40) + + ws.Push(FrameWriteRequest{&writeData{1, make([]byte, 40), false}, st1, nil}) + ws.Push(FrameWriteRequest{&writeData{2, make([]byte, 40), false}, st2, nil}) + ws.AdjustStream(1, PriorityParam{StreamDep: 0, Weight: 34}) + ws.AdjustStream(2, PriorityParam{StreamDep: 0, Weight: 9}) + + // st1 gets 3.5x the bandwidth of st2 (3.5 = (34+1)/(9+1)). + // The maximum frame size is 8 bytes. The write sequence should be: + // st1, total bytes so far is (st1=8, st=0) + // st2, total bytes so far is (st1=8, st=8) + // st1, total bytes so far is (st1=16, st=8) + // st1, total bytes so far is (st1=24, st=8) // 3x bandwidth + // st1, total bytes so far is (st1=32, st=8) // 4x bandwidth + // st2, total bytes so far is (st1=32, st=16) // 2x bandwidth + // st1, total bytes so far is (st1=40, st=16) + // st2, total bytes so far is (st1=40, st=24) + // st2, total bytes so far is (st1=40, st=32) + // st2, total bytes so far is (st1=40, st=40) + if err := checkPopAll(ws, []uint32{1, 2, 1, 1, 1, 2, 1, 2, 2, 2}); err != nil { + t.Error(err) + } +} + +func TestPriorityRstStreamOnNonOpenStreams(t *testing.T) { + ws := NewPriorityWriteScheduler(&PriorityWriteSchedulerConfig{ + MaxClosedNodesInTree: 0, + MaxIdleNodesInTree: 0, + }) + ws.OpenStream(1, OpenStreamOptions{}) + ws.CloseStream(1) + ws.Push(FrameWriteRequest{write: streamError(1, ErrCodeProtocol)}) + ws.Push(FrameWriteRequest{write: streamError(2, ErrCodeProtocol)}) + + if err := checkPopAll(ws, []uint32{1, 2}); err != nil { + t.Error(err) + } +} diff --git a/vendor/golang.org/x/net/http2/writesched_random.go b/vendor/golang.org/x/net/http2/writesched_random.go new file mode 100644 index 000000000..36d7919f1 --- /dev/null +++ b/vendor/golang.org/x/net/http2/writesched_random.go @@ -0,0 +1,72 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http2 + +import "math" + +// NewRandomWriteScheduler constructs a WriteScheduler that ignores HTTP/2 +// priorities. Control frames like SETTINGS and PING are written before DATA +// frames, but if no control frames are queued and multiple streams have queued +// HEADERS or DATA frames, Pop selects a ready stream arbitrarily. +func NewRandomWriteScheduler() WriteScheduler { + return &randomWriteScheduler{sq: make(map[uint32]*writeQueue)} +} + +type randomWriteScheduler struct { + // zero are frames not associated with a specific stream. + zero writeQueue + + // sq contains the stream-specific queues, keyed by stream ID. + // When a stream is idle or closed, it's deleted from the map. + sq map[uint32]*writeQueue + + // pool of empty queues for reuse. + queuePool writeQueuePool +} + +func (ws *randomWriteScheduler) OpenStream(streamID uint32, options OpenStreamOptions) { + // no-op: idle streams are not tracked +} + +func (ws *randomWriteScheduler) CloseStream(streamID uint32) { + q, ok := ws.sq[streamID] + if !ok { + return + } + delete(ws.sq, streamID) + ws.queuePool.put(q) +} + +func (ws *randomWriteScheduler) AdjustStream(streamID uint32, priority PriorityParam) { + // no-op: priorities are ignored +} + +func (ws *randomWriteScheduler) Push(wr FrameWriteRequest) { + id := wr.StreamID() + if id == 0 { + ws.zero.push(wr) + return + } + q, ok := ws.sq[id] + if !ok { + q = ws.queuePool.get() + ws.sq[id] = q + } + q.push(wr) +} + +func (ws *randomWriteScheduler) Pop() (FrameWriteRequest, bool) { + // Control frames first. + if !ws.zero.empty() { + return ws.zero.shift(), true + } + // Iterate over all non-idle streams until finding one that can be consumed. + for _, q := range ws.sq { + if wr, ok := q.consume(math.MaxInt32); ok { + return wr, true + } + } + return FrameWriteRequest{}, false +} diff --git a/vendor/golang.org/x/net/http2/writesched_random_test.go b/vendor/golang.org/x/net/http2/writesched_random_test.go new file mode 100644 index 000000000..97b0bcdbf --- /dev/null +++ b/vendor/golang.org/x/net/http2/writesched_random_test.go @@ -0,0 +1,44 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http2 + +import "testing" + +func TestRandomScheduler(t *testing.T) { + ws := NewRandomWriteScheduler() + ws.Push(makeWriteHeadersRequest(3)) + ws.Push(makeWriteHeadersRequest(4)) + ws.Push(makeWriteHeadersRequest(1)) + ws.Push(makeWriteHeadersRequest(2)) + ws.Push(makeWriteNonStreamRequest()) + ws.Push(makeWriteNonStreamRequest()) + + // Pop all frames. Should get the non-stream requests first, + // followed by the stream requests in any order. + var order []FrameWriteRequest + for { + wr, ok := ws.Pop() + if !ok { + break + } + order = append(order, wr) + } + t.Logf("got frames: %v", order) + if len(order) != 6 { + t.Fatalf("got %d frames, expected 6", len(order)) + } + if order[0].StreamID() != 0 || order[1].StreamID() != 0 { + t.Fatalf("expected non-stream frames first", order[0], order[1]) + } + got := make(map[uint32]bool) + for _, wr := range order[2:] { + got[wr.StreamID()] = true + } + for id := uint32(1); id <= 4; id++ { + if !got[id] { + t.Errorf("frame not found for stream %d", id) + } + } +} diff --git a/vendor/golang.org/x/net/http2/writesched_test.go b/vendor/golang.org/x/net/http2/writesched_test.go new file mode 100644 index 000000000..0807056bc --- /dev/null +++ b/vendor/golang.org/x/net/http2/writesched_test.go @@ -0,0 +1,125 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http2 + +import ( + "fmt" + "math" + "reflect" + "testing" +) + +func makeWriteNonStreamRequest() FrameWriteRequest { + return FrameWriteRequest{writeSettingsAck{}, nil, nil} +} + +func makeWriteHeadersRequest(streamID uint32) FrameWriteRequest { + st := &stream{id: streamID} + return FrameWriteRequest{&writeResHeaders{streamID: streamID, httpResCode: 200}, st, nil} +} + +func checkConsume(wr FrameWriteRequest, nbytes int32, want []FrameWriteRequest) error { + consumed, rest, n := wr.Consume(nbytes) + var wantConsumed, wantRest FrameWriteRequest + switch len(want) { + case 0: + case 1: + wantConsumed = want[0] + case 2: + wantConsumed = want[0] + wantRest = want[1] + } + if !reflect.DeepEqual(consumed, wantConsumed) || !reflect.DeepEqual(rest, wantRest) || n != len(want) { + return fmt.Errorf("got %v, %v, %v\nwant %v, %v, %v", consumed, rest, n, wantConsumed, wantRest, len(want)) + } + return nil +} + +func TestFrameWriteRequestNonData(t *testing.T) { + wr := makeWriteNonStreamRequest() + if got, want := wr.DataSize(), 0; got != want { + t.Errorf("DataSize: got %v, want %v", got, want) + } + + // Non-DATA frames are always consumed whole. + if err := checkConsume(wr, 0, []FrameWriteRequest{wr}); err != nil { + t.Errorf("Consume:\n%v", err) + } +} + +func TestFrameWriteRequestData(t *testing.T) { + st := &stream{ + id: 1, + sc: &serverConn{maxFrameSize: 16}, + } + const size = 32 + wr := FrameWriteRequest{&writeData{st.id, make([]byte, size), true}, st, make(chan error)} + if got, want := wr.DataSize(), size; got != want { + t.Errorf("DataSize: got %v, want %v", got, want) + } + + // No flow-control bytes available: cannot consume anything. + if err := checkConsume(wr, math.MaxInt32, []FrameWriteRequest{}); err != nil { + t.Errorf("Consume(limited by flow control):\n%v", err) + } + + // Add enough flow-control bytes to consume the entire frame, + // but we're now restricted by st.sc.maxFrameSize. + st.flow.add(size) + want := []FrameWriteRequest{ + { + write: &writeData{st.id, make([]byte, st.sc.maxFrameSize), false}, + stream: st, + done: nil, + }, + { + write: &writeData{st.id, make([]byte, size-st.sc.maxFrameSize), true}, + stream: st, + done: wr.done, + }, + } + if err := checkConsume(wr, math.MaxInt32, want); err != nil { + t.Errorf("Consume(limited by maxFrameSize):\n%v", err) + } + rest := want[1] + + // Consume 8 bytes from the remaining frame. + want = []FrameWriteRequest{ + { + write: &writeData{st.id, make([]byte, 8), false}, + stream: st, + done: nil, + }, + { + write: &writeData{st.id, make([]byte, size-st.sc.maxFrameSize-8), true}, + stream: st, + done: wr.done, + }, + } + if err := checkConsume(rest, 8, want); err != nil { + t.Errorf("Consume(8):\n%v", err) + } + rest = want[1] + + // Consume all remaining bytes. + want = []FrameWriteRequest{ + { + write: &writeData{st.id, make([]byte, size-st.sc.maxFrameSize-8), true}, + stream: st, + done: wr.done, + }, + } + if err := checkConsume(rest, math.MaxInt32, want); err != nil { + t.Errorf("Consume(remainder):\n%v", err) + } +} + +func TestFrameWriteRequest_StreamID(t *testing.T) { + const streamID = 123 + wr := FrameWriteRequest{write: streamError(streamID, ErrCodeNo)} + if got := wr.StreamID(); got != streamID { + t.Errorf("FrameWriteRequest(StreamError) = %v; want %v", got, streamID) + } +} |