mirror of
https://github.com/golang/net.git
synced 2026-03-31 18:37:08 +09:00
http2/hpack: push down max string length checking further, improve docs
Change-Id: I875835875f8f97158f2dc88e508a075929af931e Reviewed-on: https://go-review.googlesource.com/15827 Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
@@ -95,8 +95,8 @@ func NewDecoder(maxDynamicTableSize uint32, emitFunc func(f HeaderField)) *Decod
|
||||
var ErrStringLength = errors.New("hpack: string too long")
|
||||
|
||||
// SetMaxStringLength sets the maximum size of a HeaderField name or
|
||||
// value string, after compression. If a string exceeds this length,
|
||||
// Write will return ErrStringLength.
|
||||
// value string. If a string exceeds this length (even after any
|
||||
// decompression), Write will return ErrStringLength.
|
||||
// A value of 0 means unlimited and is the default from NewDecoder.
|
||||
func (d *Decoder) SetMaxStringLength(n int) {
|
||||
d.maxStrLen = n
|
||||
@@ -281,16 +281,20 @@ func (d *Decoder) Write(p []byte) (n int, err error) {
|
||||
|
||||
for len(d.buf) > 0 {
|
||||
err = d.parseHeaderFieldRepr()
|
||||
if err != nil {
|
||||
if err == errNeedMore {
|
||||
err = nil
|
||||
const varIntOverhead = 8 // conservative
|
||||
if d.maxStrLen != 0 &&
|
||||
int64(len(d.buf))+int64(d.saveBuf.Len()) > 2*(int64(d.maxStrLen)+varIntOverhead) {
|
||||
return 0, ErrStringLength
|
||||
}
|
||||
d.saveBuf.Write(d.buf)
|
||||
if err == errNeedMore {
|
||||
// Extra paranoia, making sure saveBuf won't
|
||||
// get too large. All the varint and string
|
||||
// reading code earlier should already catch
|
||||
// overlong things and return ErrStringLength,
|
||||
// but keep this as a last resort.
|
||||
const varIntOverhead = 8 // conservative
|
||||
if d.maxStrLen != 0 && int64(len(d.buf)) > 2*(int64(d.maxStrLen)+varIntOverhead) {
|
||||
return 0, ErrStringLength
|
||||
}
|
||||
d.saveBuf.Write(d.buf)
|
||||
return len(p), nil
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -382,12 +386,12 @@ func (d *Decoder) parseFieldLiteral(n uint8, it indexType) error {
|
||||
}
|
||||
hf.Name = ihf.Name
|
||||
} else {
|
||||
hf.Name, buf, err = readString(buf, wantStr)
|
||||
hf.Name, buf, err = d.readString(buf, wantStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
hf.Value, buf, err = readString(buf, wantStr)
|
||||
hf.Value, buf, err = d.readString(buf, wantStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -477,7 +481,7 @@ func readVarInt(n byte, p []byte) (i uint64, remain []byte, err error) {
|
||||
// strings past the MAX_HEADER_LIST_SIZE are ignored, but the server
|
||||
// is returning an error anyway, and because they're not indexed, the error
|
||||
// won't affect the decoding state.
|
||||
func readString(p []byte, wantStr bool) (s string, remain []byte, err error) {
|
||||
func (d *Decoder) readString(p []byte, wantStr bool) (s string, remain []byte, err error) {
|
||||
if len(p) == 0 {
|
||||
return "", p, errNeedMore
|
||||
}
|
||||
@@ -486,6 +490,9 @@ func readString(p []byte, wantStr bool) (s string, remain []byte, err error) {
|
||||
if err != nil {
|
||||
return "", p, err
|
||||
}
|
||||
if d.maxStrLen != 0 && strLen > uint64(d.maxStrLen) {
|
||||
return "", nil, ErrStringLength
|
||||
}
|
||||
if uint64(len(p)) < strLen {
|
||||
return "", p, errNeedMore
|
||||
}
|
||||
@@ -497,10 +504,15 @@ func readString(p []byte, wantStr bool) (s string, remain []byte, err error) {
|
||||
}
|
||||
|
||||
if wantStr {
|
||||
s, err = HuffmanDecodeToString(p[:strLen])
|
||||
if err != nil {
|
||||
buf := bufPool.Get().(*bytes.Buffer)
|
||||
buf.Reset() // don't trust others
|
||||
defer bufPool.Put(buf)
|
||||
if err := huffmanDecode(buf, d.maxStrLen, p[:strLen]); err != nil {
|
||||
buf.Reset()
|
||||
return "", nil, err
|
||||
}
|
||||
s = buf.String()
|
||||
buf.Reset() // be nice to GC
|
||||
}
|
||||
return s, p[strLen:], nil
|
||||
}
|
||||
|
||||
@@ -583,6 +583,29 @@ func TestAppendHuffmanString(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHuffmanMaxStrLen(t *testing.T) {
|
||||
const msg = "Some string"
|
||||
huff := AppendHuffmanString(nil, msg)
|
||||
|
||||
testGood := func(max int) {
|
||||
var out bytes.Buffer
|
||||
if err := huffmanDecode(&out, max, huff); err != nil {
|
||||
t.Errorf("For maxLen=%d, unexpected error: %v", max, err)
|
||||
}
|
||||
if out.String() != msg {
|
||||
t.Errorf("For maxLen=%d, out = %q; want %q", max, out.String(), msg)
|
||||
}
|
||||
}
|
||||
testGood(0)
|
||||
testGood(len(msg))
|
||||
testGood(len(msg) + 1)
|
||||
|
||||
var out bytes.Buffer
|
||||
if err := huffmanDecode(&out, len(msg)-1, huff); err != ErrStringLength {
|
||||
t.Errorf("err = %v; want ErrStringLength", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHuffmanRoundtripStress(t *testing.T) {
|
||||
const Len = 50 // of uncompressed string
|
||||
input := make([]byte, Len)
|
||||
@@ -604,7 +627,7 @@ func TestHuffmanRoundtripStress(t *testing.T) {
|
||||
huff = AppendHuffmanString(huff[:0], string(input))
|
||||
encSize += int64(len(huff))
|
||||
output.Reset()
|
||||
if err := huffmanDecode(&output, huff); err != nil {
|
||||
if err := huffmanDecode(&output, 0, huff); err != nil {
|
||||
t.Errorf("Failed to decode %q -> %q -> error %v", input, huff, err)
|
||||
continue
|
||||
}
|
||||
@@ -639,7 +662,7 @@ func TestHuffmanDecodeFuzz(t *testing.T) {
|
||||
}
|
||||
|
||||
buf.Reset()
|
||||
if err := huffmanDecode(&buf, zbuf.Bytes()); err != nil {
|
||||
if err := huffmanDecode(&buf, 0, zbuf.Bytes()); err != nil {
|
||||
if err == ErrInvalidHuffman {
|
||||
numFail++
|
||||
continue
|
||||
|
||||
@@ -22,7 +22,7 @@ func HuffmanDecode(w io.Writer, v []byte) (int, error) {
|
||||
buf := bufPool.Get().(*bytes.Buffer)
|
||||
buf.Reset()
|
||||
defer bufPool.Put(buf)
|
||||
if err := huffmanDecode(buf, v); err != nil {
|
||||
if err := huffmanDecode(buf, 0, v); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return w.Write(buf.Bytes())
|
||||
@@ -33,7 +33,7 @@ func HuffmanDecodeToString(v []byte) (string, error) {
|
||||
buf := bufPool.Get().(*bytes.Buffer)
|
||||
buf.Reset()
|
||||
defer bufPool.Put(buf)
|
||||
if err := huffmanDecode(buf, v); err != nil {
|
||||
if err := huffmanDecode(buf, 0, v); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
@@ -43,7 +43,10 @@ func HuffmanDecodeToString(v []byte) (string, error) {
|
||||
// Huffman-encoded strings.
|
||||
var ErrInvalidHuffman = errors.New("hpack: invalid Huffman-encoded data")
|
||||
|
||||
func huffmanDecode(buf *bytes.Buffer, v []byte) error {
|
||||
// huffmanDecode decodes v to buf.
|
||||
// If maxLen is greater than 0, attempts to write more to buf than
|
||||
// maxLen bytes will return ErrStringLength.
|
||||
func huffmanDecode(buf *bytes.Buffer, maxLen int, v []byte) error {
|
||||
n := rootHuffmanNode
|
||||
cur, nbits := uint(0), uint8(0)
|
||||
for _, b := range v {
|
||||
@@ -56,6 +59,9 @@ func huffmanDecode(buf *bytes.Buffer, v []byte) error {
|
||||
return ErrInvalidHuffman
|
||||
}
|
||||
if n.children == nil {
|
||||
if maxLen != 0 && buf.Len() == maxLen {
|
||||
return ErrStringLength
|
||||
}
|
||||
buf.WriteByte(n.sym)
|
||||
nbits -= n.codeLen
|
||||
n = rootHuffmanNode
|
||||
|
||||
Reference in New Issue
Block a user