From 63d1a5100f828dc9a13255721322c46e87f8eca6 Mon Sep 17 00:00:00 2001 From: Arjan Bal Date: Thu, 9 Oct 2025 14:23:05 +0530 Subject: [PATCH] http2: Allow reading frame header and body separately This change exports two new methods on the Framer, ReadFrameHeader and ReadFrameBodyForHeader, which split the functionality of the existing ReadFrame method. This provides more granular control, allowing callers to inspect the frame header before deciding whether or how to read the frame body. This is useful for applications that may need to make decisions based on frame type. Fixes golang/go#73560 Change-Id: I60b42d2889095fac8e243022886740bc6dd94012 Reviewed-on: https://go-review.googlesource.com/c/net/+/710515 LUCI-TryBot-Result: Go LUCI Reviewed-by: Damien Neil Auto-Submit: Damien Neil --- http2/frame.go | 74 ++++++++++++++++++++---------- http2/frame_test.go | 109 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 159 insertions(+), 24 deletions(-) diff --git a/http2/frame.go b/http2/frame.go index 93bcaab0..a7345a65 100644 --- a/http2/frame.go +++ b/http2/frame.go @@ -280,6 +280,8 @@ type Framer struct { // lastHeaderStream is non-zero if the last frame was an // unfinished HEADERS/CONTINUATION. lastHeaderStream uint32 + // lastFrameType holds the type of the last frame for verifying frame order. + lastFrameType FrameType maxReadSize uint32 headerBuf [frameHeaderLen]byte @@ -488,30 +490,41 @@ func terminalReadFrameError(err error) bool { return err != nil } -// ReadFrame reads a single frame. The returned Frame is only valid -// until the next call to ReadFrame. +// ReadFrameHeader reads the header of the next frame. +// It reads the 9-byte fixed frame header, and does not read any portion of the +// frame payload. The caller is responsible for consuming the payload, either +// with ReadFrameForHeader or directly from the Framer's io.Reader. // -// If the frame is larger than previously set with SetMaxReadFrameSize, the -// returned error is ErrFrameTooLarge. Other errors may be of type -// ConnectionError, StreamError, or anything else from the underlying -// reader. +// If the frame is larger than previously set with SetMaxReadFrameSize, it +// returns the frame header and ErrFrameTooLarge. // -// If ReadFrame returns an error and a non-nil Frame, the Frame's StreamID -// indicates the stream responsible for the error. -func (fr *Framer) ReadFrame() (Frame, error) { +// If the returned FrameHeader.StreamID is non-zero, it indicates the stream +// responsible for the error. +func (fr *Framer) ReadFrameHeader() (FrameHeader, error) { fr.errDetail = nil - if fr.lastFrame != nil { - fr.lastFrame.invalidate() - } fh, err := readFrameHeader(fr.headerBuf[:], fr.r) if err != nil { - return nil, err + return fh, err } if fh.Length > fr.maxReadSize { if fh == invalidHTTP1LookingFrameHeader() { - return nil, fmt.Errorf("http2: failed reading the frame payload: %w, note that the frame header looked like an HTTP/1.1 header", ErrFrameTooLarge) + return fh, fmt.Errorf("http2: failed reading the frame payload: %w, note that the frame header looked like an HTTP/1.1 header", ErrFrameTooLarge) } - return nil, ErrFrameTooLarge + return fh, ErrFrameTooLarge + } + if err := fr.checkFrameOrder(fh); err != nil { + return fh, err + } + return fh, nil +} + +// ReadFrameForHeader reads the payload for the frame with the given FrameHeader. +// +// It behaves identically to ReadFrame, other than not checking the maximum +// frame size. +func (fr *Framer) ReadFrameForHeader(fh FrameHeader) (Frame, error) { + if fr.lastFrame != nil { + fr.lastFrame.invalidate() } payload := fr.getReadBuf(fh.Length) if _, err := io.ReadFull(fr.r, payload); err != nil { @@ -527,9 +540,7 @@ func (fr *Framer) ReadFrame() (Frame, error) { } return nil, err } - if err := fr.checkFrameOrder(f); err != nil { - return nil, err - } + fr.lastFrame = f if fr.logReads { fr.debugReadLoggerf("http2: Framer %p: read %v", fr, summarizeFrame(f)) } @@ -539,6 +550,24 @@ func (fr *Framer) ReadFrame() (Frame, error) { return f, nil } +// ReadFrame reads a single frame. The returned Frame is only valid +// until the next call to ReadFrame or ReadFrameBodyForHeader. +// +// If the frame is larger than previously set with SetMaxReadFrameSize, the +// returned error is ErrFrameTooLarge. Other errors may be of type +// ConnectionError, StreamError, or anything else from the underlying +// reader. +// +// If ReadFrame returns an error and a non-nil Frame, the Frame's StreamID +// indicates the stream responsible for the error. +func (fr *Framer) ReadFrame() (Frame, error) { + fh, err := fr.ReadFrameHeader() + if err != nil { + return nil, err + } + return fr.ReadFrameForHeader(fh) +} + // connError returns ConnectionError(code) but first // stashes away a public reason to the caller can optionally relay it // to the peer before hanging up on them. This might help others debug @@ -551,20 +580,19 @@ func (fr *Framer) connError(code ErrCode, reason string) error { // checkFrameOrder reports an error if f is an invalid frame to return // next from ReadFrame. Mostly it checks whether HEADERS and // CONTINUATION frames are contiguous. -func (fr *Framer) checkFrameOrder(f Frame) error { - last := fr.lastFrame - fr.lastFrame = f +func (fr *Framer) checkFrameOrder(fh FrameHeader) error { + lastType := fr.lastFrameType + fr.lastFrameType = fh.Type if fr.AllowIllegalReads { return nil } - fh := f.Header() if fr.lastHeaderStream != 0 { if fh.Type != FrameContinuation { return fr.connError(ErrCodeProtocol, fmt.Sprintf("got %s for stream %d; expected CONTINUATION following %s for stream %d", fh.Type, fh.StreamID, - last.Header().Type, fr.lastHeaderStream)) + lastType, fr.lastHeaderStream)) } if fh.StreamID != fr.lastHeaderStream { return fr.connError(ErrCodeProtocol, diff --git a/http2/frame_test.go b/http2/frame_test.go index 6bf0026a..a2b136d1 100644 --- a/http2/frame_test.go +++ b/http2/frame_test.go @@ -825,7 +825,7 @@ func TestReadFrameOrder(t *testing.T) { }, }, 9: { - wantErr: "CONTINUATION frame with stream ID 0", + wantErr: "unexpected CONTINUATION for stream 0", w: func(f *Framer) { cont(f, 0, true) }, @@ -1278,3 +1278,110 @@ func TestTypeFrameParser(t *testing.T) { t.Errorf("expected UnknownFrame, got %T", frame) } } + +func TestReadFrameHeaderAndBody(t *testing.T) { + fr, _ := testFramer() + var streamID uint32 = 1 + data := []byte("ABC") + if err := fr.WriteData(streamID, true, data); err != nil { + t.Fatalf("WriteData(%d, true, %q) failed: %v", streamID, data, err) + } + + fh, err := fr.ReadFrameHeader() + if err != nil { + t.Fatalf("ReadFrameHeader failed: %v", err) + } + wantHeader := FrameHeader{ + Type: FrameData, + Flags: FlagDataEndStream, + Length: 3, + StreamID: 1, + valid: true, + } + if !fh.Equal(wantHeader) { + t.Fatalf("ReadFrameHeader = %+v; want %+v", fh, wantHeader) + } + + f, err := fr.ReadFrameForHeader(fh) + if err != nil { + t.Fatalf("ReadFrameForHeader failed: %v", err) + } + + if !fh.Equal(f.Header()) { + t.Fatalf("Frame.Header() = %+v; want %+v", f.Header(), fh) + } + + df, ok := f.(*DataFrame) + if !ok { + t.Fatalf("got %T; want *DataFrame", f) + } + if got, want := df.Data(), data; !bytes.Equal(got, want) { + t.Errorf("DataFrame.Data() = %q; want %q", string(got), string(want)) + } + if got, want := df.StreamEnded(), true; got != want { + t.Errorf("DataFrame.StreamEnded() = %v; want %v", got, want) + } +} + +func TestReadFrameHeaderFrameTooLarge(t *testing.T) { + fr, _ := testFramer() + fr.SetMaxReadFrameSize(2) + if err := fr.WriteData(1, true, []byte("ABC")); err != nil { + t.Fatalf("WriteData failed: %v", err) + } + fh, err := fr.ReadFrameHeader() + if gotErr, wantErr := err, ErrFrameTooLarge; gotErr != wantErr { + t.Fatalf("ReadFrameHeader returned error %v; want %v", gotErr, wantErr) + } + if fh.StreamID != 1 { + t.Errorf("ReadFrameHeader = %v, %v; want StreamID 1", fh, err) + } +} + +func TestReadFrameHeaderBadFrameOrder(t *testing.T) { + fr, _ := testFramer() + if err := fr.WriteHeaders(HeadersFrameParam{ + StreamID: 1, + BlockFragment: []byte("foo"), // unused, but non-empty + EndHeaders: false, + }); err != nil { + t.Fatalf("WriteHeaders failed: %v", err) + } + + // Write a CONTINUATION frame for stream 2 without first finishing the headers for stream 1. + if err := fr.WriteContinuation(2, true, []byte("foo")); err != nil { + t.Fatalf("WriteContinuation failed: %v", err) + } + + fh, err := fr.ReadFrameHeader() + if err != nil { + t.Fatalf("ReadFrameHeader failed: %v", err) + } + if _, err = fr.ReadFrameForHeader(fh); err != nil { + t.Fatalf("ReadFrameForHeader failed: %v", err) + } + + if _, err := fr.ReadFrameHeader(); err != ConnectionError(ErrCodeProtocol) { + t.Fatalf("ReadFrameHeader returned error %v; want ConnectionError(ErrCodeProtocol)", err) + } +} + +func TestReadFrameForHeaderUnexpectedEOF(t *testing.T) { + fr, b := testFramer() + if err := fr.WriteData(1, true, []byte("ABC")); err != nil { + t.Fatalf("WriteData failed: %v", err) + } + + fh, err := fr.ReadFrameHeader() + if err != nil { + t.Fatalf("ReadFrameHeader failed: %v", err) + } + + // Remove one byte from the body, corrupting the frame body. + b.Truncate(b.Len() - 1) + + _, err = fr.ReadFrameForHeader(fh) + if err != io.ErrUnexpectedEOF { + t.Fatalf("ReadFrameForHeader with short body = %v; want io.ErrUnexpectedEOF", err) + } +}