From 00ed5e97ea3a5ac46658b98e50259941947cec04 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Tue, 15 Nov 2016 00:44:39 +0000 Subject: [PATCH] http2: schedule RSTStream writes onto its stream's queue Fixes golang/go#17243 Change-Id: I76f972f908757b103e2ab8d9b1701312308d66e5 Reviewed-on: https://go-review.googlesource.com/33238 Reviewed-by: Tom Bergan --- http2/writesched.go | 13 ++++++++----- http2/writesched_test.go | 8 ++++++++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/http2/writesched.go b/http2/writesched.go index 9f3e1b32..fb5da35a 100644 --- a/http2/writesched.go +++ b/http2/writesched.go @@ -62,6 +62,13 @@ type FrameWriteRequest struct { // 0 is used for non-stream frames such as PING and SETTINGS. func (wr FrameWriteRequest) StreamID() uint32 { if wr.stream == nil { + if se, ok := wr.write.(StreamError); ok { + // (*serverConn).resetStream doesn't set + // stream because it doesn't necessarily have + // one. So special case this type of write + // message. + return se.StreamID + } return 0 } return wr.stream.id @@ -142,17 +149,13 @@ func (wr FrameWriteRequest) Consume(n int32) (FrameWriteRequest, FrameWriteReque // String is for debugging only. func (wr FrameWriteRequest) String() string { - var streamID uint32 - if wr.stream != nil { - streamID = wr.stream.id - } var des string if s, ok := wr.write.(fmt.Stringer); ok { des = s.String() } else { des = fmt.Sprintf("%T", wr.write) } - return fmt.Sprintf("[FrameWriteRequest stream=%d, ch=%v, writer=%v]", streamID, wr.done != nil, des) + return fmt.Sprintf("[FrameWriteRequest stream=%d, ch=%v, writer=%v]", wr.StreamID(), wr.done != nil, des) } // writeQueue is used by implementations of WriteScheduler. diff --git a/http2/writesched_test.go b/http2/writesched_test.go index 10d2362a..0807056b 100644 --- a/http2/writesched_test.go +++ b/http2/writesched_test.go @@ -115,3 +115,11 @@ func TestFrameWriteRequestData(t *testing.T) { t.Errorf("Consume(remainder):\n%v", err) } } + +func TestFrameWriteRequest_StreamID(t *testing.T) { + const streamID = 123 + wr := FrameWriteRequest{write: streamError(streamID, ErrCodeNo)} + if got := wr.StreamID(); got != streamID { + t.Errorf("FrameWriteRequest(StreamError) = %v; want %v", got, streamID) + } +}